浏览代码

test: add test cases for ListHandler (#9146)

yuiseki 2 月之前
父节点
当前提交
d721a02e7d
共有 1 个文件被更改,包括 91 次插入0 次删除
  1. 91 0
      cmd/cmd_test.go

+ 91 - 0
cmd/cmd_test.go

@@ -10,6 +10,7 @@ import (
 	"os"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/spf13/cobra"
@@ -490,6 +491,96 @@ func TestPushHandler(t *testing.T) {
 	}
 }
 
+func TestListHandler(t *testing.T) {
+	tests := []struct {
+		name           string
+		args           []string
+		serverResponse []api.ListModelResponse
+		expectedError  string
+		expectedOutput string
+	}{
+		{
+			name: "list all models",
+			args: []string{},
+			serverResponse: []api.ListModelResponse{
+				{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
+				{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
+			},
+			expectedOutput: "NAME      ID              SIZE      MODIFIED     \n" +
+				"model1    sha256:abc12    1.0 KB    24 hours ago    \n" +
+				"model2    sha256:def45    2.0 KB    2 days ago      \n",
+		},
+		{
+			name: "filter models by prefix",
+			args: []string{"model1"},
+			serverResponse: []api.ListModelResponse{
+				{Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
+				{Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
+			},
+			expectedOutput: "NAME      ID              SIZE      MODIFIED     \n" +
+				"model1    sha256:abc12    1.0 KB    24 hours ago    \n",
+		},
+		{
+			name:          "server error",
+			args:          []string{},
+			expectedError: "server error",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
+					t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
+					http.Error(w, "not found", http.StatusNotFound)
+					return
+				}
+
+				if tt.expectedError != "" {
+					http.Error(w, tt.expectedError, http.StatusInternalServerError)
+					return
+				}
+
+				response := api.ListResponse{Models: tt.serverResponse}
+				if err := json.NewEncoder(w).Encode(response); err != nil {
+					t.Fatal(err)
+				}
+			}))
+			defer mockServer.Close()
+
+			t.Setenv("OLLAMA_HOST", mockServer.URL)
+
+			cmd := &cobra.Command{}
+			cmd.SetContext(context.TODO())
+
+			// Capture stdout
+			oldStdout := os.Stdout
+			r, w, _ := os.Pipe()
+			os.Stdout = w
+
+			err := ListHandler(cmd, tt.args)
+
+			// Restore stdout and get output
+			w.Close()
+			os.Stdout = oldStdout
+			output, _ := io.ReadAll(r)
+
+			if tt.expectedError == "" {
+				if err != nil {
+					t.Errorf("expected no error, got %v", err)
+				}
+				if got := string(output); got != tt.expectedOutput {
+					t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
+				}
+			} else {
+				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
+				}
+			}
+		})
+	}
+}
+
 func TestCreateHandler(t *testing.T) {
 	tests := []struct {
 		name           string