|
@@ -4,6 +4,7 @@ import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
+ "io"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
"os"
|
|
@@ -369,3 +370,127 @@ func TestGetModelfileName(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestPushHandler(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ modelName string
|
|
|
+ serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
|
|
+ expectedError string
|
|
|
+ expectedOutput string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "successful push",
|
|
|
+ modelName: "test-model",
|
|
|
+ serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
|
|
+ "/api/push": func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if r.Method != http.MethodPost {
|
|
|
+ t.Errorf("expected POST request, got %s", r.Method)
|
|
|
+ }
|
|
|
+
|
|
|
+ var req api.PushRequest
|
|
|
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
|
+ http.Error(w, err.Error(), http.StatusBadRequest)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.Name != "test-model" {
|
|
|
+ t.Errorf("expected model name 'test-model', got %s", req.Name)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Simulate progress updates
|
|
|
+ responses := []api.ProgressResponse{
|
|
|
+ {Status: "preparing manifest"},
|
|
|
+ {Digest: "sha256:abc123456789", Total: 100, Completed: 50},
|
|
|
+ {Digest: "sha256:abc123456789", Total: 100, Completed: 100},
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, resp := range responses {
|
|
|
+ if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
|
+ http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ w.(http.Flusher).Flush()
|
|
|
+ }
|
|
|
+ },
|
|
|
+ },
|
|
|
+ expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "unauthorized push",
|
|
|
+ modelName: "unauthorized-model",
|
|
|
+ serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
|
|
+ "/api/push": func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ w.Header().Set("Content-Type", "application/json")
|
|
|
+ w.WriteHeader(http.StatusUnauthorized)
|
|
|
+ err := json.NewEncoder(w).Encode(map[string]string{
|
|
|
+ "error": "access denied",
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ },
|
|
|
+ },
|
|
|
+ expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
|
|
+ handler(w, r)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ http.Error(w, "not found", http.StatusNotFound)
|
|
|
+ }))
|
|
|
+ defer mockServer.Close()
|
|
|
+
|
|
|
+ t.Setenv("OLLAMA_HOST", mockServer.URL)
|
|
|
+
|
|
|
+ cmd := &cobra.Command{}
|
|
|
+ cmd.Flags().Bool("insecure", false, "")
|
|
|
+ cmd.SetContext(context.TODO())
|
|
|
+
|
|
|
+ // Redirect stderr to capture progress output
|
|
|
+ oldStderr := os.Stderr
|
|
|
+ r, w, _ := os.Pipe()
|
|
|
+ os.Stderr = w
|
|
|
+
|
|
|
+ // Capture stdout for the "Model pushed" message
|
|
|
+ oldStdout := os.Stdout
|
|
|
+ outR, outW, _ := os.Pipe()
|
|
|
+ os.Stdout = outW
|
|
|
+
|
|
|
+ err := PushHandler(cmd, []string{tt.modelName})
|
|
|
+
|
|
|
+ // Restore stderr
|
|
|
+ w.Close()
|
|
|
+ os.Stderr = oldStderr
|
|
|
+ // drain the pipe
|
|
|
+ if _, err := io.ReadAll(r); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Restore stdout and get output
|
|
|
+ outW.Close()
|
|
|
+ os.Stdout = oldStdout
|
|
|
+ stdout, _ := io.ReadAll(outR)
|
|
|
+
|
|
|
+ if tt.expectedError == "" {
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("expected no error, got %v", err)
|
|
|
+ }
|
|
|
+ if tt.expectedOutput != "" {
|
|
|
+ if got := string(stdout); got != tt.expectedOutput {
|
|
|
+ t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
|
|
+ t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|