浏览代码

default to "FROM ." if a Modelfile isn't present (#7250)

Patrick Devine 6 月之前
父节点
当前提交
d78fb62056
共有 2 个文件被更改,包括 146 次插入12 次删除
  1. 47 12
      cmd/cmd.go
  2. 99 0
      cmd/cmd_test.go

+ 47 - 12
cmd/cmd.go

@@ -46,28 +46,58 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-func CreateHandler(cmd *cobra.Command, args []string) error {
-	filename, _ := cmd.Flags().GetString("file")
-	filename, err := filepath.Abs(filename)
+var (
+	errModelNotFound     = errors.New("no Modelfile or safetensors files found")
+	errModelfileNotFound = errors.New("specified Modelfile wasn't found")
+)
+
+func getModelfileName(cmd *cobra.Command) (string, error) {
+	fn, _ := cmd.Flags().GetString("file")
+
+	filename := fn
+	if filename == "" {
+		filename = "Modelfile"
+	}
+
+	absName, err := filepath.Abs(filename)
 	if err != nil {
-		return err
+		return "", err
 	}
 
-	client, err := api.ClientFromEnvironment()
+	_, err = os.Stat(absName)
 	if err != nil {
-		return err
+		return fn, err
 	}
 
+	return absName, nil
+}
+
+func CreateHandler(cmd *cobra.Command, args []string) error {
 	p := progress.NewProgress(os.Stderr)
 	defer p.Stop()
 
-	f, err := os.Open(filename)
-	if err != nil {
+	var reader io.Reader
+
+	filename, err := getModelfileName(cmd)
+	if os.IsNotExist(err) {
+		if filename == "" {
+			reader = strings.NewReader("FROM .\n")
+		} else {
+			return errModelfileNotFound
+		}
+	} else if err != nil {
 		return err
+	} else {
+		f, err := os.Open(filename)
+		if err != nil {
+			return err
+		}
+
+		reader = f
+		defer f.Close()
 	}
-	defer f.Close()
 
-	modelfile, err := parser.ParseFile(f)
+	modelfile, err := parser.ParseFile(reader)
 	if err != nil {
 		return err
 	}
@@ -82,6 +112,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	p.Add(status, spinner)
 	defer p.Stop()
 
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
 	for i := range modelfile.Commands {
 		switch modelfile.Commands[i].Name {
 		case "model", "adapter":
@@ -220,7 +255,7 @@ func tempZipFiles(path string) (string, error) {
 		// covers consolidated.x.pth, consolidated.pth
 		files = append(files, pt...)
 	} else {
-		return "", errors.New("no safetensors or torch files found")
+		return "", errModelNotFound
 	}
 
 	// add configuration files, json files are detected as text/plain
@@ -1315,7 +1350,7 @@ func NewCLI() *cobra.Command {
 		RunE:    CreateHandler,
 	}
 
-	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
+	createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
 	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
 
 	showCmd := &cobra.Command{

+ 99 - 0
cmd/cmd_test.go

@@ -270,3 +270,102 @@ func TestDeleteHandler(t *testing.T) {
 		t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
 	}
 }
+
+func TestGetModelfileName(t *testing.T) {
+	tests := []struct {
+		name          string
+		modelfileName string
+		fileExists    bool
+		expectedName  string
+		expectedErr   error
+	}{
+		{
+			name:          "no modelfile specified, no modelfile exists",
+			modelfileName: "",
+			fileExists:    false,
+			expectedName:  "",
+			expectedErr:   os.ErrNotExist,
+		},
+		{
+			name:          "no modelfile specified, modelfile exists",
+			modelfileName: "",
+			fileExists:    true,
+			expectedName:  "Modelfile",
+			expectedErr:   nil,
+		},
+		{
+			name:          "modelfile specified, no modelfile exists",
+			modelfileName: "crazyfile",
+			fileExists:    false,
+			expectedName:  "crazyfile",
+			expectedErr:   os.ErrNotExist,
+		},
+		{
+			name:          "modelfile specified, modelfile exists",
+			modelfileName: "anotherfile",
+			fileExists:    true,
+			expectedName:  "anotherfile",
+			expectedErr:   nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			cmd := &cobra.Command{
+				Use: "fakecmd",
+			}
+			cmd.Flags().String("file", "", "path to modelfile")
+
+			var expectedFilename string
+
+			if tt.fileExists {
+				tempDir, err := os.MkdirTemp("", "modelfiledir")
+				defer os.RemoveAll(tempDir)
+				if err != nil {
+					t.Fatalf("temp modelfile dir creation failed: %v", err)
+				}
+				var fn string
+				if tt.modelfileName != "" {
+					fn = tt.modelfileName
+				} else {
+					fn = "Modelfile"
+				}
+
+				tempFile, err := os.CreateTemp(tempDir, fn)
+				if err != nil {
+					t.Fatalf("temp modelfile creation failed: %v", err)
+				}
+
+				expectedFilename = tempFile.Name()
+				err = cmd.Flags().Set("file", expectedFilename)
+				if err != nil {
+					t.Fatalf("couldn't set file flag: %v", err)
+				}
+			} else {
+				if tt.modelfileName != "" {
+					expectedFilename = tt.modelfileName
+					err := cmd.Flags().Set("file", tt.modelfileName)
+					if err != nil {
+						t.Fatalf("couldn't set file flag: %v", err)
+					}
+				}
+			}
+
+			actualFilename, actualErr := getModelfileName(cmd)
+
+			if actualFilename != expectedFilename {
+				t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
+			}
+
+			if tt.expectedErr != os.ErrNotExist {
+				if actualErr != tt.expectedErr {
+					t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
+				}
+			} else {
+				if !os.IsNotExist(actualErr) {
+					t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
+				}
+			}
+		})
+	}
+}