瀏覽代碼

make the modelfile path relative for `ollama create` (#8380)

Patrick Devine 3 月之前
父節點
當前提交
32bd37adf8
共有 5 個文件被更改,包括 155 次插入25 次删除
  1. 3 4
      cmd/cmd.go
  2. 129 2
      cmd/cmd_test.go
  3. 13 11
      parser/expandpath_test.go
  4. 8 6
      parser/parser.go
  5. 2 2
      parser/parser_test.go

+ 3 - 4
cmd/cmd.go

@@ -46,9 +46,8 @@ import (
 var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
 
 func getModelfileName(cmd *cobra.Command) (string, error) {
-	fn, _ := cmd.Flags().GetString("file")
+	filename, _ := cmd.Flags().GetString("file")
 
-	filename := fn
 	if filename == "" {
 		filename = "Modelfile"
 	}
@@ -60,7 +59,7 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
 
 	_, err = os.Stat(absName)
 	if err != nil {
-		return fn, err
+		return filename, err
 	}
 
 	return absName, nil
@@ -100,7 +99,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	spinner := progress.NewSpinner(status)
 	p.Add(status, spinner)
 
-	req, err := modelfile.CreateRequest()
+	req, err := modelfile.CreateRequest(filepath.Dir(filename))
 	if err != nil {
 		return err
 	}

+ 129 - 2
cmd/cmd_test.go

@@ -279,7 +279,7 @@ func TestGetModelfileName(t *testing.T) {
 			name:          "no modelfile specified, no modelfile exists",
 			modelfileName: "",
 			fileExists:    false,
-			expectedName:  "",
+			expectedName:  "Modelfile",
 			expectedErr:   os.ErrNotExist,
 		},
 		{
@@ -338,8 +338,8 @@ func TestGetModelfileName(t *testing.T) {
 					t.Fatalf("couldn't set file flag: %v", err)
 				}
 			} else {
+				expectedFilename = tt.expectedName
 				if tt.modelfileName != "" {
-					expectedFilename = tt.modelfileName
 					err := cmd.Flags().Set("file", tt.modelfileName)
 					if err != nil {
 						t.Fatalf("couldn't set file flag: %v", err)
@@ -489,3 +489,130 @@ func TestPushHandler(t *testing.T) {
 		})
 	}
 }
+
+func TestCreateHandler(t *testing.T) {
+	tests := []struct {
+		name           string
+		modelName      string
+		modelFile      string
+		serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
+		expectedError  string
+		expectedOutput string
+	}{
+		{
+			name:      "successful create",
+			modelName: "test-model",
+			modelFile: "FROM foo",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/create": func(w http.ResponseWriter, r *http.Request) {
+					if r.Method != http.MethodPost {
+						t.Errorf("expected POST request, got %s", r.Method)
+					}
+
+					req := api.CreateRequest{}
+					if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+						http.Error(w, err.Error(), http.StatusBadRequest)
+						return
+					}
+
+					if req.Name != "test-model" {
+						t.Errorf("expected model name 'test-model', got %s", req.Name)
+					}
+
+					if req.From != "foo" {
+						t.Errorf("expected from 'foo', got %s", req.From)
+					}
+
+					responses := []api.ProgressResponse{
+						{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
+						{Status: "writing manifest"},
+						{Status: "success"},
+					}
+
+					for _, resp := range responses {
+						if err := json.NewEncoder(w).Encode(resp); err != nil {
+							http.Error(w, err.Error(), http.StatusInternalServerError)
+							return
+						}
+						w.(http.Flusher).Flush()
+					}
+				},
+			},
+			expectedOutput: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				handler, ok := tt.serverResponse[r.URL.Path]
+				if !ok {
+					t.Errorf("unexpected request to %s", r.URL.Path)
+					http.Error(w, "not found", http.StatusNotFound)
+					return
+				}
+				handler(w, r)
+			}))
+			t.Setenv("OLLAMA_HOST", mockServer.URL)
+			t.Cleanup(mockServer.Close)
+			tempFile, err := os.CreateTemp("", "modelfile")
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer os.Remove(tempFile.Name())
+
+			if _, err := tempFile.WriteString(tt.modelFile); err != nil {
+				t.Fatal(err)
+			}
+			if err := tempFile.Close(); err != nil {
+				t.Fatal(err)
+			}
+
+			cmd := &cobra.Command{}
+			cmd.Flags().String("file", "", "")
+			if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
+				t.Fatal(err)
+			}
+
+			cmd.Flags().Bool("insecure", false, "")
+			cmd.SetContext(context.TODO())
+
+			// Redirect stderr to capture progress output
+			oldStderr := os.Stderr
+			r, w, _ := os.Pipe()
+			os.Stderr = w
+
+			// Capture stdout for the "Model pushed" message
+			oldStdout := os.Stdout
+			outR, outW, _ := os.Pipe()
+			os.Stdout = outW
+
+			err = CreateHandler(cmd, []string{tt.modelName})
+
+			// Restore stderr
+			w.Close()
+			os.Stderr = oldStderr
+			// drain the pipe
+			if _, err := io.ReadAll(r); err != nil {
+				t.Fatal(err)
+			}
+
+			// Restore stdout and get output
+			outW.Close()
+			os.Stdout = oldStdout
+			stdout, _ := io.ReadAll(outR)
+
+			if tt.expectedError == "" {
+				if err != nil {
+					t.Errorf("expected no error, got %v", err)
+				}
+
+				if tt.expectedOutput != "" {
+					if got := string(stdout); got != tt.expectedOutput {
+						t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
+					}
+				}
+			}
+		})
+	}
+}

+ 13 - 11
parser/expandpath_test.go

@@ -31,27 +31,29 @@ func TestExpandPath(t *testing.T) {
 	}
 
 	tests := []struct {
-		input           string
+		path            string
+		relativeDir     string
 		expected        string
 		windowsExpected string
 		shouldErr       bool
 	}{
-		{"~", "/home/testuser", "D:\\home\\testuser", false},
-		{"~/myfolder/myfile.txt", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false},
-		{"~anotheruser/docs/file.txt", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false},
-		{"~nonexistentuser/file.txt", "", "", true},
-		{"relative/path/to/file", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false},
-		{"/absolute/path/to/file", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false},
-		{".", os.Getenv("PWD"), os.Getenv("PWD"), false},
+		{"~", "", "/home/testuser", "D:\\home\\testuser", false},
+		{"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", "D:\\home\\testuser\\myfolder\\myfile.txt", false},
+		{"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", "D:\\home\\anotheruser\\docs\\file.txt", false},
+		{"~nonexistentuser/file.txt", "", "", "", true},
+		{"relative/path/to/file", "", filepath.Join(os.Getenv("PWD"), "relative/path/to/file"), "relative\\path\\to\\file", false},
+		{"/absolute/path/to/file", "", "/absolute/path/to/file", "D:\\absolute\\path\\to\\file", false},
+		{".", os.Getenv("PWD"), "", os.Getenv("PWD"), false},
+		{"somefile", "somedir", filepath.Join(os.Getenv("PWD"), "somedir", "somefile"), "somedir\\somefile", false},
 	}
 
 	for _, test := range tests {
-		result, err := expandPathImpl(test.input, mockCurrentUser, mockLookupUser)
+		result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser)
 		if (err != nil) != test.shouldErr {
-			t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.input, err != nil, test.shouldErr)
+			t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr)
 		}
 		if result != test.expected && result != test.windowsExpected && !test.shouldErr {
-			t.Errorf("expandPathImpl(%q) = %q, want %q", test.input, result, test.expected)
+			t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected)
 		}
 	}
 }

+ 8 - 6
parser/parser.go

@@ -39,7 +39,7 @@ func (f Modelfile) String() string {
 var deprecatedParameters = []string{"penalize_newline"}
 
 // CreateRequest creates a new *api.CreateRequest from an existing Modelfile
-func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
+func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
 	req := &api.CreateRequest{}
 
 	var messages []api.Message
@@ -49,7 +49,7 @@ func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
 	for _, c := range f.Commands {
 		switch c.Name {
 		case "model":
-			path, err := expandPath(c.Args)
+			path, err := expandPath(c.Args, relativeDir)
 			if err != nil {
 				return nil, err
 			}
@@ -64,7 +64,7 @@ func (f Modelfile) CreateRequest() (*api.CreateRequest, error) {
 
 			req.Files = digestMap
 		case "adapter":
-			path, err := expandPath(c.Args)
+			path, err := expandPath(c.Args, relativeDir)
 			if err != nil {
 				return nil, err
 			}
@@ -563,7 +563,7 @@ func isValidCommand(cmd string) bool {
 	}
 }
 
-func expandPathImpl(path string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) {
+func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) {
 	if strings.HasPrefix(path, "~") {
 		var homeDir string
 
@@ -591,11 +591,13 @@ func expandPathImpl(path string, currentUserFunc func() (*user.User, error), loo
 		}
 
 		path = filepath.Join(homeDir, path)
+	} else {
+		path = filepath.Join(relativeDir, path)
 	}
 
 	return filepath.Abs(path)
 }
 
-func expandPath(path string) (string, error) {
-	return expandPathImpl(path, user.Current, user.Lookup)
+func expandPath(path, relativeDir string) (string, error) {
+	return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
 }

+ 2 - 2
parser/parser_test.go

@@ -747,7 +747,7 @@ MESSAGE assistant Hi! How are you?
 			t.Error(err)
 		}
 
-		actual, err := p.CreateRequest()
+		actual, err := p.CreateRequest("")
 		if err != nil {
 			t.Error(err)
 		}
@@ -816,7 +816,7 @@ func TestCreateRequestFiles(t *testing.T) {
 			t.Error(err)
 		}
 
-		actual, err := p.CreateRequest()
+		actual, err := p.CreateRequest("")
 		if err != nil {
 			t.Error(err)
 		}