Bladeren bron

cmd: support OLLAMA_CLIENT_HOST environment variable (#262)

* cmd: support OLLAMA_HOST environment variable

This commit adds support for the OLLAMA_HOST environment
variable. This variable can be used to specify the host to which
the client should connect. This is useful when the client is
running somewhere other than the host where the server is running.

The new api.FromEnv function is used to read configure clients from the
environment. Clients wishing to use the environment variable being
consistent with the Ollama CLI can use this new function.

* Update api/client.go

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

* Update api/client.go

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Blake Mizerany 1 jaar geleden
bovenliggende
commit
67e593e355
2 gewijzigde bestanden met toevoegingen van 64 en 13 verwijderingen
  1. 32 5
      api/client.go
  2. 32 8
      cmd/cmd.go

+ 32 - 5
api/client.go

@@ -9,10 +9,17 @@ import (
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
+	"os"
+)
+
+const DefaultHost = "localhost:11434"
+
+var (
+	envHost = os.Getenv("OLLAMA_HOST")
 )
 )
 
 
 type Client struct {
 type Client struct {
-	base    url.URL
+	Base    url.URL
 	HTTP    http.Client
 	HTTP    http.Client
 	Headers http.Header
 	Headers http.Header
 }
 }
@@ -33,14 +40,34 @@ func checkError(resp *http.Response, body []byte) error {
 	return apiError
 	return apiError
 }
 }
 
 
+// Host returns the default host to use for the client. It is determined in the following order:
+// 1. The OLLAMA_HOST environment variable
+// 2. The default host (localhost:11434)
+func Host() string {
+	if envHost != "" {
+		return envHost
+	}
+	return DefaultHost
+}
+
+// FromEnv creates a new client using Host() as the host. An error is returns
+// if the host is invalid.
+func FromEnv() (*Client, error) {
+	u, err := url.Parse(Host())
+	if err != nil {
+		return nil, err
+	}
+	return &Client{Base: *u}, nil
+}
+
 func NewClient(hosts ...string) *Client {
 func NewClient(hosts ...string) *Client {
-	host := "127.0.0.1:11434"
+	host := DefaultHost
 	if len(hosts) > 0 {
 	if len(hosts) > 0 {
 		host = hosts[0]
 		host = hosts[0]
 	}
 	}
 
 
 	return &Client{
 	return &Client{
-		base: url.URL{Scheme: "http", Host: host},
+		Base: url.URL{Scheme: "http", Host: host},
 		HTTP: http.Client{},
 		HTTP: http.Client{},
 	}
 	}
 }
 }
@@ -57,7 +84,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 		reqBody = bytes.NewReader(data)
 		reqBody = bytes.NewReader(data)
 	}
 	}
 
 
-	url := c.base.JoinPath(path).String()
+	url := c.Base.JoinPath(path).String()
 
 
 	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
 	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
 	if err != nil {
 	if err != nil {
@@ -105,7 +132,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 		buf = bytes.NewBuffer(bts)
 		buf = bytes.NewBuffer(bts)
 	}
 	}
 
 
-	request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), buf)
+	request, err := http.NewRequestWithContext(ctx, method, c.Base.JoinPath(path).String(), buf)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 32 - 8
cmd/cmd.go

@@ -39,7 +39,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 		return err
 	}
 	}
 
 
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
 
 	var spinner *Spinner
 	var spinner *Spinner
 
 
@@ -117,7 +120,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 }
 }
 
 
 func PushHandler(cmd *cobra.Command, args []string) error {
 func PushHandler(cmd *cobra.Command, args []string) error {
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
 
 	insecure, err := cmd.Flags().GetBool("insecure")
 	insecure, err := cmd.Flags().GetBool("insecure")
 	if err != nil {
 	if err != nil {
@@ -153,7 +159,10 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 }
 }
 
 
 func ListHandler(cmd *cobra.Command, args []string) error {
 func ListHandler(cmd *cobra.Command, args []string) error {
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
 
 	models, err := client.List(context.Background())
 	models, err := client.List(context.Background())
 	if err != nil {
 	if err != nil {
@@ -183,7 +192,10 @@ func ListHandler(cmd *cobra.Command, args []string) error {
 }
 }
 
 
 func DeleteHandler(cmd *cobra.Command, args []string) error {
 func DeleteHandler(cmd *cobra.Command, args []string) error {
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
 
 	req := api.DeleteRequest{Name: args[0]}
 	req := api.DeleteRequest{Name: args[0]}
 	if err := client.Delete(context.Background(), &req); err != nil {
 	if err := client.Delete(context.Background(), &req); err != nil {
@@ -194,7 +206,10 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
 }
 }
 
 
 func CopyHandler(cmd *cobra.Command, args []string) error {
 func CopyHandler(cmd *cobra.Command, args []string) error {
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
 
 	req := api.CopyRequest{Source: args[0], Destination: args[1]}
 	req := api.CopyRequest{Source: args[0], Destination: args[1]}
 	if err := client.Copy(context.Background(), &req); err != nil {
 	if err := client.Copy(context.Background(), &req); err != nil {
@@ -214,7 +229,10 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 }
 }
 
 
 func pull(model string, insecure bool) error {
 func pull(model string, insecure bool) error {
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 
 
 	var currentDigest string
 	var currentDigest string
 	var bar *progressbar.ProgressBar
 	var bar *progressbar.ProgressBar
@@ -261,7 +279,10 @@ type generateContextKey string
 
 
 func generate(cmd *cobra.Command, model, prompt string) error {
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 	if len(strings.TrimSpace(prompt)) > 0 {
-		client := api.NewClient()
+		client, err := api.FromEnv()
+		if err != nil {
+			return err
+		}
 
 
 		spinner := NewSpinner("")
 		spinner := NewSpinner("")
 		go spinner.Spin(60 * time.Millisecond)
 		go spinner.Spin(60 * time.Millisecond)
@@ -644,7 +665,10 @@ func startMacApp(client *api.Client) error {
 }
 }
 
 
 func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
 func checkServerHeartbeat(_ *cobra.Command, _ []string) error {
-	client := api.NewClient()
+	client, err := api.FromEnv()
+	if err != nil {
+		return err
+	}
 	if err := client.Heartbeat(context.Background()); err != nil {
 	if err := client.Heartbeat(context.Background()); err != nil {
 		if !strings.Contains(err.Error(), "connection refused") {
 		if !strings.Contains(err.Error(), "connection refused") {
 			return err
 			return err