Browse Source

cmd: print location of model after pushing (#7695)

After a user pushes their model it is not clear what to do next. Add a link
to the output of `ollama push` that tells the user where their model can now
be found.
Bruce MacDonald 5 months ago
parent
commit
a210ec74d2
2 changed files with 137 additions and 0 deletions
  1. 12 0
      cmd/cmd.go
  2. 125 0
      cmd/cmd_test.go

+ 12 - 0
cmd/cmd.go

@@ -39,6 +39,7 @@ import (
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 
@@ -558,6 +559,8 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 	}
 
 	request := api.PushRequest{Name: args[0], Insecure: insecure}
+
+	n := model.ParseName(args[0])
 	if err := client.Push(cmd.Context(), &request, fn); err != nil {
 		if spinner != nil {
 			spinner.Stop()
@@ -568,7 +571,16 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
+	p.Stop()
 	spinner.Stop()
+
+	destination := n.String()
+	if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
+		destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
+	}
+	fmt.Printf("\nYou can find your model at:\n\n")
+	fmt.Printf("\t%s\n", destination)
+
 	return nil
 }
 

+ 125 - 0
cmd/cmd_test.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -369,3 +370,127 @@ func TestGetModelfileName(t *testing.T) {
 		})
 	}
 }
+
+func TestPushHandler(t *testing.T) {
+	tests := []struct {
+		name           string
+		modelName      string
+		serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
+		expectedError  string
+		expectedOutput string
+	}{
+		{
+			name:      "successful push",
+			modelName: "test-model",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/push": func(w http.ResponseWriter, r *http.Request) {
+					if r.Method != http.MethodPost {
+						t.Errorf("expected POST request, got %s", r.Method)
+					}
+
+					var req api.PushRequest
+					if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+						http.Error(w, err.Error(), http.StatusBadRequest)
+						return
+					}
+
+					if req.Name != "test-model" {
+						t.Errorf("expected model name 'test-model', got %s", req.Name)
+					}
+
+					// Simulate progress updates
+					responses := []api.ProgressResponse{
+						{Status: "preparing manifest"},
+						{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
+						{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
+					}
+
+					for _, resp := range responses {
+						if err := json.NewEncoder(w).Encode(resp); err != nil {
+							http.Error(w, err.Error(), http.StatusInternalServerError)
+							return
+						}
+						w.(http.Flusher).Flush()
+					}
+				},
+			},
+			expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
+		},
+		{
+			name:      "unauthorized push",
+			modelName: "unauthorized-model",
+			serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
+				"/api/push": func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Set("Content-Type", "application/json")
+					w.WriteHeader(http.StatusUnauthorized)
+					err := json.NewEncoder(w).Encode(map[string]string{
+						"error": "access denied",
+					})
+					if err != nil {
+						t.Fatal(err)
+					}
+				},
+			},
+			expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if handler, ok := tt.serverResponse[r.URL.Path]; ok {
+					handler(w, r)
+					return
+				}
+				http.Error(w, "not found", http.StatusNotFound)
+			}))
+			defer mockServer.Close()
+
+			t.Setenv("OLLAMA_HOST", mockServer.URL)
+
+			cmd := &cobra.Command{}
+			cmd.Flags().Bool("insecure", false, "")
+			cmd.SetContext(context.TODO())
+
+			// Redirect stderr to capture progress output
+			oldStderr := os.Stderr
+			r, w, _ := os.Pipe()
+			os.Stderr = w
+
+			// Capture stdout for the "Model pushed" message
+			oldStdout := os.Stdout
+			outR, outW, _ := os.Pipe()
+			os.Stdout = outW
+
+			err := PushHandler(cmd, []string{tt.modelName})
+
+			// Restore stderr
+			w.Close()
+			os.Stderr = oldStderr
+			// drain the pipe
+			if _, err := io.ReadAll(r); err != nil {
+				t.Fatal(err)
+			}
+
+			// Restore stdout and get output
+			outW.Close()
+			os.Stdout = oldStdout
+			stdout, _ := io.ReadAll(outR)
+
+			if tt.expectedError == "" {
+				if err != nil {
+					t.Errorf("expected no error, got %v", err)
+				}
+				if tt.expectedOutput != "" {
+					if got := string(stdout); got != tt.expectedOutput {
+						t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
+					}
+				}
+			} else {
+				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
+				}
+			}
+		})
+	}
+}