Browse Source

handle client proxy

Michael Yang 1 year ago
parent
commit
2cfffea02e
2 changed files with 46 additions and 44 deletions
  1. 36 34
      api/client.go
  2. 10 10
      cmd/cmd.go

+ 36 - 34
api/client.go

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -16,14 +17,9 @@ import (
 	"github.com/jmorganca/ollama/version"
 )
 
-const DefaultHost = "127.0.0.1:11434"
-
-var envHost = os.Getenv("OLLAMA_HOST")
-
 type Client struct {
-	Base    url.URL
-	HTTP    http.Client
-	Headers http.Header
+	base *url.URL
+	http http.Client
 }
 
 func checkError(resp *http.Response, body []byte) error {
@@ -42,34 +38,44 @@ func checkError(resp *http.Response, body []byte) error {
 	return apiError
 }
 
-// Host returns the default host to use for the client. It is determined in the following order:
-// 1. The OLLAMA_HOST environment variable
-// 2. The default host (localhost:11434)
-func Host() string {
-	if envHost != "" {
-		return envHost
+func ClientFromEnvironment() (*Client, error) {
+	scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
+	if !ok {
+		scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
+	}
+
+	host, port, err := net.SplitHostPort(hostport)
+	if err != nil {
+		host, port = "127.0.0.1", "11434"
+		if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
+			host = ip.String()
+		}
 	}
-	return DefaultHost
-}
 
-// FromEnv creates a new client using Host() as the host. An error is returns
-// if the host is invalid.
-func FromEnv() (*Client, error) {
-	h := Host()
-	if !strings.HasPrefix(h, "http://") && !strings.HasPrefix(h, "https://") {
-		h = "http://" + h
+	client := Client{
+		base: &url.URL{
+			Scheme: scheme,
+			Host:   net.JoinHostPort(host, port),
+		},
 	}
 
-	u, err := url.Parse(h)
+	mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil)
 	if err != nil {
-		return nil, fmt.Errorf("could not parse host: %w", err)
+		return nil, err
 	}
 
-	if u.Port() == "" {
-		u.Host += ":11434"
+	proxyURL, err := http.ProxyFromEnvironment(mockRequest)
+	if err != nil {
+		return nil, err
 	}
 
-	return &Client{Base: *u, HTTP: http.Client{}}, nil
+	client.http = http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyURL),
+		},
+	}
+
+	return &client, nil
 }
 
 func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
@@ -84,7 +90,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 		reqBody = bytes.NewReader(data)
 	}
 
-	requestURL := c.Base.JoinPath(path)
+	requestURL := c.base.JoinPath(path)
 	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
 	if err != nil {
 		return err
@@ -94,11 +100,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	request.Header.Set("Accept", "application/json")
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 
-	for k, v := range c.Headers {
-		request.Header[k] = v
-	}
-
-	respObj, err := c.HTTP.Do(request)
+	respObj, err := c.http.Do(request)
 	if err != nil {
 		return err
 	}
@@ -134,7 +136,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 		buf = bytes.NewBuffer(bts)
 	}
 
-	requestURL := c.Base.JoinPath(path)
+	requestURL := c.base.JoinPath(path)
 	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
 	if err != nil {
 		return err
@@ -144,7 +146,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 	request.Header.Set("Accept", "application/json")
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 
-	response, err := http.DefaultClient.Do(request)
+	response, err := c.http.Do(request)
 	if err != nil {
 		return err
 	}

+ 10 - 10
cmd/cmd.go

@@ -61,7 +61,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -119,7 +119,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 }
 
 func RunHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -144,7 +144,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 }
 
 func PushHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -188,7 +188,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 }
 
 func ListHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -221,7 +221,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
 }
 
 func DeleteHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -237,7 +237,7 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 }
 
 func ShowHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -315,7 +315,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 }
 
 func CopyHandler(cmd *cobra.Command, args []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -338,7 +338,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 }
 
 func pull(model string, insecure bool) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -406,7 +406,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 type generateContextKey string
 
 func generate(cmd *cobra.Command, model, prompt string, wordWrap bool) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}
@@ -906,7 +906,7 @@ func startMacApp(client *api.Client) error {
 }
 
 func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
-	client, err := api.FromEnv()
+	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
 	}