Browse Source

client: fix trailing slash

Michael Yang 1 năm trước cách đây
mục cha
commit
28c3f288e2
2 tập tin đã thay đổi với 55 bổ sung2 xóa
  1. 12 2
      api/client.go
  2. 43 0
      api/client_test.go

+ 12 - 2
api/client.go

@@ -44,14 +44,24 @@ func checkError(resp *http.Response, body []byte) error {
 }
 
 func ClientFromEnvironment() (*Client, error) {
+	defaultPort := "11434"
+
 	scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
-	if !ok {
+	switch {
+	case !ok:
 		scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
+	case scheme == "http":
+		defaultPort = "80"
+	case scheme == "https":
+		defaultPort = "443"
 	}
 
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
 	host, port, err := net.SplitHostPort(hostport)
 	if err != nil {
-		host, port = "127.0.0.1", "11434"
+		host, port = "127.0.0.1", defaultPort
 		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
 			host = ip.String()
 		} else if hostport != "" {

+ 43 - 0
api/client_test.go

@@ -0,0 +1,43 @@
+package api
+
+import "testing"
+
+func TestClientFromEnvironment(t *testing.T) {
+	type testCase struct {
+		value  string
+		expect string
+		err    error
+	}
+
+	testCases := map[string]*testCase{
+		"empty":                      {value: "", expect: "http://127.0.0.1:11434"},
+		"only address":               {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
+		"only port":                  {value: ":1234", expect: "http://:1234"},
+		"address and port":           {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
+		"scheme http and address":    {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
+		"scheme https and address":   {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
+		"scheme, address, and port":  {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
+		"hostname":                   {value: "example.com", expect: "http://example.com:11434"},
+		"hostname and port":          {value: "example.com:1234", expect: "http://example.com:1234"},
+		"scheme http and hostname":   {value: "http://example.com", expect: "http://example.com:80"},
+		"scheme https and hostname":  {value: "https://example.com", expect: "https://example.com:443"},
+		"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
+		"trailing slash":             {value: "example.com/", expect: "http://example.com:11434"},
+		"trailing slash port":        {value: "example.com:1234/", expect: "http://example.com:1234"},
+	}
+
+	for k, v := range testCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+
+			client, err := ClientFromEnvironment()
+			if err != v.err {
+				t.Fatalf("expected %s, got %s", v.err, err)
+			}
+
+			if client.base.String() != v.expect {
+				t.Fatalf("expected %s, got %s", v.expect, client.base.String())
+			}
+		})
+	}
+}