|
@@ -279,7 +279,7 @@ func TestGetModelfileName(t *testing.T) {
|
|
|
name: "no modelfile specified, no modelfile exists",
|
|
|
modelfileName: "",
|
|
|
fileExists: false,
|
|
|
- expectedName: "",
|
|
|
+ expectedName: "Modelfile",
|
|
|
expectedErr: os.ErrNotExist,
|
|
|
},
|
|
|
{
|
|
@@ -338,8 +338,8 @@ func TestGetModelfileName(t *testing.T) {
|
|
|
t.Fatalf("couldn't set file flag: %v", err)
|
|
|
}
|
|
|
} else {
|
|
|
+ expectedFilename = tt.expectedName
|
|
|
if tt.modelfileName != "" {
|
|
|
- expectedFilename = tt.modelfileName
|
|
|
err := cmd.Flags().Set("file", tt.modelfileName)
|
|
|
if err != nil {
|
|
|
t.Fatalf("couldn't set file flag: %v", err)
|
|
@@ -489,3 +489,130 @@ func TestPushHandler(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestCreateHandler(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ modelName string
|
|
|
+ modelFile string
|
|
|
+ serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
|
|
+ expectedError string
|
|
|
+ expectedOutput string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "successful create",
|
|
|
+ modelName: "test-model",
|
|
|
+ modelFile: "FROM foo",
|
|
|
+ serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
|
|
+ "/api/create": func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if r.Method != http.MethodPost {
|
|
|
+ t.Errorf("expected POST request, got %s", r.Method)
|
|
|
+ }
|
|
|
+
|
|
|
+ req := api.CreateRequest{}
|
|
|
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
|
+ http.Error(w, err.Error(), http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.Name != "test-model" {
|
|
|
+ t.Errorf("expected model name 'test-model', got %s", req.Name)
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.From != "foo" {
|
|
|
+ t.Errorf("expected from 'foo', got %s", req.From)
|
|
|
+ }
|
|
|
+
|
|
|
+ responses := []api.ProgressResponse{
|
|
|
+ {Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
|
|
|
+ {Status: "writing manifest"},
|
|
|
+ {Status: "success"},
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, resp := range responses {
|
|
|
+ if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
|
+ http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ w.(http.Flusher).Flush()
|
|
|
+ }
|
|
|
+ },
|
|
|
+ },
|
|
|
+ expectedOutput: "",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ handler, ok := tt.serverResponse[r.URL.Path]
|
|
|
+ if !ok {
|
|
|
+ t.Errorf("unexpected request to %s", r.URL.Path)
|
|
|
+ http.Error(w, "not found", http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ handler(w, r)
|
|
|
+ }))
|
|
|
+ t.Setenv("OLLAMA_HOST", mockServer.URL)
|
|
|
+ t.Cleanup(mockServer.Close)
|
|
|
+ tempFile, err := os.CreateTemp("", "modelfile")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ defer os.Remove(tempFile.Name())
|
|
|
+
|
|
|
+ if _, err := tempFile.WriteString(tt.modelFile); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ if err := tempFile.Close(); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ cmd := &cobra.Command{}
|
|
|
+ cmd.Flags().String("file", "", "")
|
|
|
+ if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ cmd.Flags().Bool("insecure", false, "")
|
|
|
+ cmd.SetContext(context.TODO())
|
|
|
+
|
|
|
+ // Redirect stderr to capture progress output
|
|
|
+ oldStderr := os.Stderr
|
|
|
+ r, w, _ := os.Pipe()
|
|
|
+ os.Stderr = w
|
|
|
+
|
|
|
+ // Capture stdout for the "Model pushed" message
|
|
|
+ oldStdout := os.Stdout
|
|
|
+ outR, outW, _ := os.Pipe()
|
|
|
+ os.Stdout = outW
|
|
|
+
|
|
|
+ err = CreateHandler(cmd, []string{tt.modelName})
|
|
|
+
|
|
|
+ // Restore stderr
|
|
|
+ w.Close()
|
|
|
+ os.Stderr = oldStderr
|
|
|
+ // drain the pipe
|
|
|
+ if _, err := io.ReadAll(r); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Restore stdout and get output
|
|
|
+ outW.Close()
|
|
|
+ os.Stdout = oldStdout
|
|
|
+ stdout, _ := io.ReadAll(outR)
|
|
|
+
|
|
|
+ if tt.expectedError == "" {
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("expected no error, got %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if tt.expectedOutput != "" {
|
|
|
+ if got := string(stdout); got != tt.expectedOutput {
|
|
|
+ t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|