瀏覽代碼

start tests

Josh Yan 10 月之前
父節點
當前提交
f7d64856d5
共有 3 個文件被更改,包括 121 次插入0 次删除
  1. 20 0
      api/client.go
  2. 79 0
      api/client_test.go
  3. 22 0
      cmd/cmd.go

+ 20 - 0
api/client.go

@@ -383,3 +383,23 @@ func (c *Client) Version(ctx context.Context) (string, error) {
 
 	return version.Version, nil
 }
+
+// IsLocal checks whether the client is connecting to a local server.
+func (c *Client) IsLocal() bool {
+	// Resolve the host to an IP address and check if the IP is local
+	// Currently, only checks if it is localhost or loopback
+	host, _, err := net.SplitHostPort(c.base.Host)
+	if err != nil {
+		host = c.base.Host
+	}
+
+	if host == "" || host == "localhost" {
+		return true
+	}
+
+	if ip := net.ParseIP(host); ip != nil {
+		return ip.IsLoopback()
+	}
+
+	return false
+}

+ 79 - 0
api/client_test.go

@@ -1,6 +1,8 @@
 package api
 
 import (
+	"net/http"
+	"net/url"
 	"testing"
 
 	"github.com/ollama/ollama/envconfig"
@@ -46,3 +48,80 @@ func TestClientFromEnvironment(t *testing.T) {
 		})
 	}
 }
+
+// Test function
+func TestIsLocal(t *testing.T) {
+	type test struct {
+		client *Client
+		want   bool
+		err    error
+	}
+
+	tests := map[string]test{
+		"localhost": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://localhost:1234")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: true,
+			err:  nil,
+		},
+		"127.0.0.1": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://127.0.0.1:1234")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: true,
+			err:  nil,
+		},
+		"example.com": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://example.com:1111")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: false,
+			err:  nil,
+		},
+		"8.8.8.8": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://8.8.8.8:1234")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: false,
+			err:  nil,
+		},
+		"empty host with port": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://:1234")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: true,
+			err:  nil,
+		},
+		"empty host without port": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: true,
+			err:  nil,
+		},
+		"remote host without port": {
+			client: func() *Client {
+				baseURL, _ := url.Parse("http://example.com")
+				return &Client{base: baseURL, http: &http.Client{}}
+			}(),
+			want: false,
+			err:  nil,
+		},
+	}
+
+	for name, tc := range tests {
+		t.Run(name, func(t *testing.T) {
+			got := tc.client.IsLocal()
+			if got != tc.want {
+				t.Errorf("test %s failed: got %v, want %v", name, got, tc.want)
+			}
+		})
+	}
+}

+ 22 - 0
cmd/cmd.go

@@ -277,12 +277,34 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 	}
 
 	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
+
+	// Here, we want to check if the server is local
+	// If true, call, createBlobLocal
+	// This should find the model directory, copy blob over, and return the digest
+	// If this fails, just upload it
+	// If this is successful, return the digest
+
+	// Resolve server to IP
+	// Check if server is local
+	if client.IsLocal() {
+		err := createBlobLocal(cmd, client, digest)
+		if err == nil {
+			return digest, nil
+		}
+	}
+
 	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
 		return "", err
 	}
 	return digest, nil
 }
 
+func createBlobLocal(cmd *cobra.Command, client *api.Client, digest string) error {
+	// This function should be called if the server is local
+	// It should find the model directory, copy the blob over, and return the digest
+
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true