|
@@ -10,6 +10,7 @@ import (
|
|
|
"os"
|
|
|
"strings"
|
|
|
"testing"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
"github.com/spf13/cobra"
|
|
@@ -490,6 +491,96 @@ func TestPushHandler(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestListHandler(t *testing.T) {
|
|
|
+ tests := []struct {
|
|
|
+ name string
|
|
|
+ args []string
|
|
|
+ serverResponse []api.ListModelResponse
|
|
|
+ expectedError string
|
|
|
+ expectedOutput string
|
|
|
+ }{
|
|
|
+ {
|
|
|
+ name: "list all models",
|
|
|
+ args: []string{},
|
|
|
+ serverResponse: []api.ListModelResponse{
|
|
|
+ {Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
|
|
+ {Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-48 * time.Hour)},
|
|
|
+ },
|
|
|
+ expectedOutput: "NAME ID SIZE MODIFIED \n" +
|
|
|
+ "model1 sha256:abc12 1.0 KB 24 hours ago \n" +
|
|
|
+ "model2 sha256:def45 2.0 KB 2 days ago \n",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "filter models by prefix",
|
|
|
+ args: []string{"model1"},
|
|
|
+ serverResponse: []api.ListModelResponse{
|
|
|
+ {Name: "model1", Digest: "sha256:abc123", Size: 1024, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
|
|
+ {Name: "model2", Digest: "sha256:def456", Size: 2048, ModifiedAt: time.Now().Add(-24 * time.Hour)},
|
|
|
+ },
|
|
|
+ expectedOutput: "NAME ID SIZE MODIFIED \n" +
|
|
|
+ "model1 sha256:abc12 1.0 KB 24 hours ago \n",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "server error",
|
|
|
+ args: []string{},
|
|
|
+ expectedError: "server error",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
+ mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
+ if r.URL.Path != "/api/tags" || r.Method != http.MethodGet {
|
|
|
+ t.Errorf("unexpected request to %s %s", r.Method, r.URL.Path)
|
|
|
+ http.Error(w, "not found", http.StatusNotFound)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if tt.expectedError != "" {
|
|
|
+ http.Error(w, tt.expectedError, http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ response := api.ListResponse{Models: tt.serverResponse}
|
|
|
+ if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }))
|
|
|
+ defer mockServer.Close()
|
|
|
+
|
|
|
+ t.Setenv("OLLAMA_HOST", mockServer.URL)
|
|
|
+
|
|
|
+ cmd := &cobra.Command{}
|
|
|
+ cmd.SetContext(context.TODO())
|
|
|
+
|
|
|
+ // Capture stdout
|
|
|
+ oldStdout := os.Stdout
|
|
|
+ r, w, _ := os.Pipe()
|
|
|
+ os.Stdout = w
|
|
|
+
|
|
|
+ err := ListHandler(cmd, tt.args)
|
|
|
+
|
|
|
+ // Restore stdout and get output
|
|
|
+ w.Close()
|
|
|
+ os.Stdout = oldStdout
|
|
|
+ output, _ := io.ReadAll(r)
|
|
|
+
|
|
|
+ if tt.expectedError == "" {
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("expected no error, got %v", err)
|
|
|
+ }
|
|
|
+ if got := string(output); got != tt.expectedOutput {
|
|
|
+ t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
|
|
+ t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func TestCreateHandler(t *testing.T) {
|
|
|
tests := []struct {
|
|
|
name string
|