Browse Source

simplify api client

Michael Yang 1 year ago
parent
commit
b0e63bfb4c
2 changed files with 21 additions and 25 deletions
  1. 19 11
      api/client.go
  2. 2 14
      cmd/cmd.go

+ 19 - 11
api/client.go

@@ -5,14 +5,24 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
-	"fmt"
 	"io"
 	"net/http"
+	"net/url"
 )
 
 type Client struct {
-	URL  string
-	HTTP http.Client
+	base url.URL
+}
+
+func NewClient(hosts ...string) *Client {
+	host := "127.0.0.1:11434"
+	if len(hosts) > 0 {
+		host = hosts[0]
+	}
+
+	return &Client{
+		base: url.URL{Scheme: "http", Host: host},
+	}
 }
 
 func (c *Client) stream(ctx context.Context, method string, path string, reqData any, fn func(bts []byte) error) error {
@@ -27,23 +37,21 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData
 		reqBody = bytes.NewReader(data)
 	}
 
-	url := fmt.Sprintf("%s%s", c.URL, path)
-
-	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
+	request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), reqBody)
 	if err != nil {
 		return err
 	}
 
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Accept", "application/json")
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Accept", "application/json")
 
-	res, err := c.HTTP.Do(req)
+	response, err := http.DefaultClient.Do(request)
 	if err != nil {
 		return err
 	}
-	defer res.Body.Close()
+	defer response.Body.Close()
 
-	scanner := bufio.NewScanner(res.Body)
+	scanner := bufio.NewScanner(response.Body)
 	for scanner.Scan() {
 		if err := fn(scanner.Bytes()); err != nil {
 			return err

+ 2 - 14
cmd/cmd.go

@@ -36,10 +36,7 @@ func RunRun(cmd *cobra.Command, args []string) error {
 }
 
 func pull(model string) error {
-	client, err := NewAPIClient()
-	if err != nil {
-		return err
-	}
+	client := api.NewClient()
 
 	var bar *progressbar.ProgressBar
 	return client.Pull(
@@ -68,10 +65,7 @@ func RunGenerate(_ *cobra.Command, args []string) error {
 }
 
 func generate(model string, prompts ...string) error {
-	client, err := NewAPIClient()
-	if err != nil {
-		return err
-	}
+	client := api.NewClient()
 
 	for _, prompt := range prompts {
 		client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
@@ -121,12 +115,6 @@ func RunServer(_ *cobra.Command, _ []string) error {
 	return server.Serve(ln)
 }
 
-func NewAPIClient() (*api.Client, error) {
-	return &api.Client{
-		URL: "http://localhost:11434",
-	}, nil
-}
-
 func NewCLI() *cobra.Command {
 	log.SetFlags(log.LstdFlags | log.Lshortfile)