Browse Source

Merge pull request #49 from jmorganca/go-run

Go run
Michael Yang 1 year ago
parent
commit
962cc9ca49
7 changed files with 235 additions and 187 deletions
  1. 63 108
      api/client.go
  2. 2 2
      api/types.go
  3. 126 55
      cmd/cmd.go
  4. 8 2
      go.mod
  5. 14 4
      go.sum
  6. 18 16
      server/models.go
  7. 4 0
      server/routes.go

+ 63 - 108
api/client.go

@@ -5,153 +5,108 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
-	"fmt"
 	"io"
 	"net/http"
-	"sync"
+	"net/url"
 )
 
 type Client struct {
-	URL  string
-	HTTP http.Client
+	base url.URL
 }
 
-func checkError(resp *http.Response, body []byte) error {
-	if resp.StatusCode >= 200 && resp.StatusCode < 400 {
-		return nil
+func NewClient(hosts ...string) *Client {
+	host := "127.0.0.1:11434"
+	if len(hosts) > 0 {
+		host = hosts[0]
 	}
 
-	apiError := Error{Code: int32(resp.StatusCode)}
-
-	if err := json.Unmarshal(body, &apiError); err != nil {
-		// Use the full body as the message if we fail to decode a response.
-		apiError.Message = string(body)
+	return &Client{
+		base: url.URL{Scheme: "http", Host: host},
 	}
-
-	return apiError
 }
 
-func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
-	var reqBody io.Reader
-	var data []byte
-	var err error
-	if reqData != nil {
-		data, err = json.Marshal(reqData)
-		if err != nil {
-			return err
-		}
-		reqBody = bytes.NewReader(data)
-	}
-
-	url := fmt.Sprintf("%s%s", c.URL, path)
+type options struct {
+	requestBody  io.Reader
+	responseFunc func(bts []byte) error
+}
 
-	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
+func OptionRequestBody(data any) func(*options) {
+	bts, err := json.Marshal(data)
 	if err != nil {
-		return err
+		panic(err)
 	}
 
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Accept", "application/json")
-
-	res, err := c.HTTP.Do(req)
-	if err != nil {
-		return err
+	return func(opts *options) {
+		opts.requestBody = bytes.NewReader(bts)
 	}
-	defer res.Body.Close()
-
-	reader := bufio.NewReader(res.Body)
+}
 
-	for {
-		line, err := reader.ReadBytes('\n')
-		if err != nil {
-			if err == io.EOF {
-				break
-			} else {
-				return err // Handle other errors
-			}
-		}
-		if err := checkError(res, line); err != nil {
-			return err
-		}
-		callback(bytes.TrimSuffix(line, []byte("\n")))
+func OptionResponseFunc(fn func([]byte) error) func(*options) {
+	return func(opts *options) {
+		opts.responseFunc = fn
 	}
-
-	return nil
 }
 
-func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
-	var reqBody io.Reader
-	var data []byte
-	var err error
-	if reqData != nil {
-		data, err = json.Marshal(reqData)
-		if err != nil {
-			return err
-		}
-		reqBody = bytes.NewReader(data)
+func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*options)) error {
+	var opts options
+	for _, fn := range fns {
+		fn(&opts)
 	}
 
-	url := fmt.Sprintf("%s%s", c.URL, path)
-
-	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
+	request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), opts.requestBody)
 	if err != nil {
 		return err
 	}
 
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Accept", "application/json")
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Accept", "application/json")
 
-	respObj, err := c.HTTP.Do(req)
+	response, err := http.DefaultClient.Do(request)
 	if err != nil {
 		return err
 	}
-	defer respObj.Body.Close()
+	defer response.Body.Close()
 
-	respBody, err := io.ReadAll(respObj.Body)
-	if err != nil {
-		return err
-	}
-
-	if err := checkError(respObj, respBody); err != nil {
-		return err
-	}
-
-	if len(respBody) > 0 && respData != nil {
-		if err := json.Unmarshal(respBody, respData); err != nil {
-			return err
+	if opts.responseFunc != nil {
+		scanner := bufio.NewScanner(response.Body)
+		for scanner.Scan() {
+			if err := opts.responseFunc(scanner.Bytes()); err != nil {
+				return err
+			}
 		}
 	}
+
 	return nil
 }
 
-func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) {
-	var res GenerateResponse
-	if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) {
-		callback(string(token))
-	}); err != nil {
-		return nil, err
-	}
+type GenerateResponseFunc func(GenerateResponse) error
 
-	return &res, nil
+func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/generate",
+		OptionRequestBody(req),
+		OptionResponseFunc(func(bts []byte) error {
+			var resp GenerateResponse
+			if err := json.Unmarshal(bts, &resp); err != nil {
+				return err
+			}
+
+			return fn(resp)
+		}),
+	)
 }
 
-func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
-	var wg sync.WaitGroup
-	wg.Add(1)
-	if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
-		var progress PullProgress
-		if err := json.Unmarshal(progressBytes, &progress); err != nil {
-			fmt.Println(err)
-			return
-		}
-		if progress.Completed >= progress.Total {
-			wg.Done()
-		}
-		callback(progress)
-	}); err != nil {
-		return err
-	}
+type PullProgressFunc func(PullProgress) error
 
-	wg.Wait()
-	return nil
+func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/pull",
+		OptionRequestBody(req),
+		OptionResponseFunc(func(bts []byte) error {
+			var resp PullProgress
+			if err := json.Unmarshal(bts, &resp); err != nil {
+				return err
+			}
+
+			return fn(resp)
+		}),
+	)
 }

+ 2 - 2
api/types.go

@@ -23,8 +23,8 @@ type PullRequest struct {
 }
 
 type PullProgress struct {
-	Total     int     `json:"total"`
-	Completed int     `json:"completed"`
+	Total     int64   `json:"total"`
+	Completed int64   `json:"completed"`
 	Percent   float64 `json:"percent"`
 }
 

+ 126 - 55
cmd/cmd.go

@@ -1,18 +1,22 @@
 package cmd
 
 import (
+	"bufio"
 	"context"
+	"errors"
 	"fmt"
 	"log"
 	"net"
 	"os"
 	"path"
-	"sync"
+	"time"
+
+	"github.com/schollz/progressbar/v3"
+	"github.com/spf13/cobra"
+	"golang.org/x/term"
 
-	"github.com/gosuri/uiprogress"
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/server"
-	"github.com/spf13/cobra"
 )
 
 func cacheDir() string {
@@ -24,46 +28,126 @@ func cacheDir() string {
 	return path.Join(home, ".ollama")
 }
 
-func bytesToGB(bytes int) float64 {
-	return float64(bytes) / float64(1<<30)
-}
+func RunRun(cmd *cobra.Command, args []string) error {
+	_, err := os.Stat(args[0])
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		if err := pull(args[0]); err != nil {
+			return err
+		}
 
-func run(model string) error {
-	client, err := NewAPIClient()
-	if err != nil {
+		fmt.Println("Up to date.")
+	case err != nil:
 		return err
 	}
-	pr := api.PullRequest{
-		Model: model,
+
+	return RunGenerate(cmd, args)
+}
+
+func pull(model string) error {
+	client := api.NewClient()
+
+	var bar *progressbar.ProgressBar
+	return client.Pull(
+		context.Background(),
+		&api.PullRequest{Model: model},
+		func(progress api.PullProgress) error {
+			if bar == nil {
+				bar = progressbar.DefaultBytes(progress.Total)
+			}
+
+			return bar.Set64(progress.Completed)
+		},
+	)
+}
+
+func RunGenerate(_ *cobra.Command, args []string) error {
+	if len(args) > 1 {
+		return generateOneshot(args[0], args[1:]...)
+	}
+
+	if term.IsTerminal(int(os.Stdin.Fd())) {
+		return generateInteractive(args[0])
 	}
-	var bar *uiprogress.Bar
-	mutex := &sync.Mutex{}
-	var progressData api.PullProgress
-
-	pullCallback := func(progress api.PullProgress) {
-		mutex.Lock()
-		progressData = progress
-		if bar == nil {
-			uiprogress.Start()
-			bar = uiprogress.AddBar(int(progress.Total))
-			bar.PrependFunc(func(b *uiprogress.Bar) string {
-				return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total))
-			})
-			bar.AppendFunc(func(b *uiprogress.Bar) string {
-				return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100))
-			})
+
+	return generateBatch(args[0])
+}
+
+func generate(model, prompt string) error {
+	client := api.NewClient()
+
+	spinner := progressbar.NewOptions(-1,
+		progressbar.OptionSetWriter(os.Stderr),
+		progressbar.OptionThrottle(60*time.Millisecond),
+		progressbar.OptionSpinnerType(14),
+		progressbar.OptionSetRenderBlankState(true),
+		progressbar.OptionSetElapsedTime(false),
+		progressbar.OptionClearOnFinish(),
+	)
+
+	go func() {
+		for range time.Tick(60 * time.Millisecond) {
+			if spinner.IsFinished() {
+				break
+			}
+
+			spinner.Add(1)
+		}
+	}()
+
+	client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
+		if !spinner.IsFinished() {
+			spinner.Finish()
+		}
+
+		fmt.Print(resp.Response)
+		return nil
+	})
+
+	fmt.Println()
+	fmt.Println()
+	return nil
+}
+
+func generateOneshot(model string, prompts ...string) error {
+	for _, prompt := range prompts {
+		fmt.Printf(">>> %s\n", prompt)
+		if err := generate(model, prompt); err != nil {
+			return err
 		}
-		bar.Set(int(progress.Completed))
-		mutex.Unlock()
 	}
-	if err := client.Pull(context.Background(), &pr, pullCallback); err != nil {
-		return err
+
+	return nil
+}
+
+func generateInteractive(model string) error {
+	fmt.Print(">>> ")
+	scanner := bufio.NewScanner(os.Stdin)
+	for scanner.Scan() {
+		if err := generate(model, scanner.Text()); err != nil {
+			return err
+		}
+
+		fmt.Print(">>> ")
 	}
-	fmt.Println("Up to date.")
+
 	return nil
 }
 
-func serve() error {
+func generateBatch(model string) error {
+	scanner := bufio.NewScanner(os.Stdin)
+	for scanner.Scan() {
+		prompt := scanner.Text()
+		fmt.Printf(">>> %s\n", prompt)
+		if err := generate(model, prompt); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func RunServer(_ *cobra.Command, _ []string) error {
 	ln, err := net.Listen("tcp", "127.0.0.1:11434")
 	if err != nil {
 		return err
@@ -72,49 +156,36 @@ func serve() error {
 	return server.Serve(ln)
 }
 
-func NewAPIClient() (*api.Client, error) {
-	return &api.Client{
-		URL: "http://localhost:11434",
-	}, nil
-}
-
 func NewCLI() *cobra.Command {
 	log.SetFlags(log.LstdFlags | log.Lshortfile)
 
 	rootCmd := &cobra.Command{
-		Use:   "ollama",
-		Short: "Large language model runner",
+		Use:          "ollama",
+		Short:        "Large language model runner",
+		SilenceUsage: true,
 		CompletionOptions: cobra.CompletionOptions{
 			DisableDefaultCmd: true,
 		},
-		PersistentPreRun: func(cmd *cobra.Command, args []string) {
-			// Disable usage printing on errors
-			cmd.SilenceUsage = true
+		PersistentPreRunE: func(_ *cobra.Command, args []string) error {
 			// create the models directory and its parent
-			if err := os.MkdirAll(path.Join(cacheDir(), "models"), 0o700); err != nil {
-				panic(err)
-			}
+			return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700)
 		},
 	}
 
 	cobra.EnableCommandSorting = false
 
 	runCmd := &cobra.Command{
-		Use:   "run MODEL",
+		Use:   "run MODEL [PROMPT]",
 		Short: "Run a model",
-		Args:  cobra.ExactArgs(1),
-		RunE: func(cmd *cobra.Command, args []string) error {
-			return run(args[0])
-		},
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  RunRun,
 	}
 
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
 		Short:   "Start ollama",
-		RunE: func(cmd *cobra.Command, args []string) error {
-			return serve()
-		},
+		RunE:    RunServer,
 	}
 
 	rootCmd.AddCommand(

+ 8 - 2
go.mod

@@ -4,10 +4,15 @@ go 1.20
 
 require (
 	github.com/gin-gonic/gin v1.9.1
-	github.com/gosuri/uiprogress v0.0.1
 	github.com/spf13/cobra v1.7.0
 )
 
+require (
+	github.com/mattn/go-runewidth v0.0.14 // indirect
+	github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
+	github.com/rivo/uniseg v0.2.0 // indirect
+)
+
 require (
 	github.com/bytedance/sonic v1.9.1 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
@@ -18,7 +23,6 @@ require (
 	github.com/go-playground/validator/v10 v10.14.0 // indirect
 	github.com/goccy/go-json v0.10.2 // indirect
 	github.com/google/go-cmp v0.5.9 // indirect
-	github.com/gosuri/uilive v0.0.4 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
@@ -28,6 +32,7 @@ require (
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
+	github.com/schollz/progressbar/v3 v3.13.1
 	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.11 // indirect
@@ -35,6 +40,7 @@ require (
 	golang.org/x/crypto v0.10.0 // indirect
 	golang.org/x/net v0.10.0 // indirect
 	golang.org/x/sys v0.10.0 // indirect
+	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect
 	google.golang.org/protobuf v1.30.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect

+ 14 - 4
go.sum

@@ -28,14 +28,11 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY=
-github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI=
-github.com/gosuri/uiprogress v0.0.1 h1:0kpv/XY/qTmFWl/SkaJykZXrBBzwwadmW8fRb7RJSxw=
-github.com/gosuri/uiprogress v0.0.1/go.mod h1:C1RTYn4Sc7iEyf6j8ft5dyoZ4212h8G1ol9QQluh5+0=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw=
 github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
 github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
 github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@@ -43,8 +40,13 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
 github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
 github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4=
 github.com/lithammer/fuzzysearch v1.1.8/go.mod h1:IdqeyBClc3FFqSzYq/MXESsS4S0FsZ5ajtkr5xPLts4=
+github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
 github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
 github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
+github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
+github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ=
+github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
 github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -54,7 +56,11 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
 github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE=
+github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ=
 github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
 github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
 github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
@@ -99,6 +105,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
@@ -106,6 +113,9 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
+golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
+golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
+golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=

+ 18 - 16
server/models.go

@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"net/http"
 	"os"
 	"path"
@@ -30,6 +29,15 @@ type Model struct {
 	License          string `json:"license"`
 }
 
+func (m *Model) FullName() string {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		panic(err)
+	}
+
+	return path.Join(home, ".ollama", "models", m.Name+".bin")
+}
+
 func pull(model string, progressCh chan<- api.PullProgress) error {
 	remote, err := getRemote(model)
 	if err != nil {
@@ -45,7 +53,7 @@ func getRemote(model string) (*Model, error) {
 		return nil, fmt.Errorf("failed to get directory: %w", err)
 	}
 	defer resp.Body.Close()
-	body, err := ioutil.ReadAll(resp.Body)
+	body, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return nil, fmt.Errorf("failed to read directory: %w", err)
 	}
@@ -64,13 +72,6 @@ func getRemote(model string) (*Model, error) {
 
 func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 	// this models cache directory is created by the server on startup
-	home, err := os.UserHomeDir()
-	if err != nil {
-		return fmt.Errorf("failed to get home directory: %w", err)
-	}
-	modelsCache := path.Join(home, ".ollama", "models")
-
-	fileName := path.Join(modelsCache, model.Name+".bin")
 
 	client := &http.Client{}
 	req, err := http.NewRequest("GET", model.URL, nil)
@@ -78,16 +79,16 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 		return fmt.Errorf("failed to download model: %w", err)
 	}
 	// check for resume
-	alreadyDownloaded := 0
-	fileInfo, err := os.Stat(fileName)
+	alreadyDownloaded := int64(0)
+	fileInfo, err := os.Stat(model.FullName())
 	if err != nil {
 		if !os.IsNotExist(err) {
 			return fmt.Errorf("failed to check resume model file: %w", err)
 		}
 		// file doesn't exist, create it now
 	} else {
-		alreadyDownloaded = int(fileInfo.Size())
-		req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-")
+		alreadyDownloaded = fileInfo.Size()
+		req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
 	}
 
 	resp, err := client.Do(req)
@@ -111,13 +112,13 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 		return fmt.Errorf("failed to download model: %s", resp.Status)
 	}
 
-	out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
+	out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 	if err != nil {
 		panic(err)
 	}
 	defer out.Close()
 
-	totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
+	totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
 	buf := make([]byte, 1024)
 	totalBytes := alreadyDownloaded
@@ -134,7 +135,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 		if _, err := out.Write(buf[:n]); err != nil {
 			return err
 		}
-		totalBytes += n
+
+		totalBytes += int64(n)
 
 		// send progress updates
 		progressCh <- api.PullProgress{

+ 4 - 0
server/routes.go

@@ -37,6 +37,10 @@ func generate(c *gin.Context) {
 		return
 	}
 
+	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
+		req.Model = remoteModel.FullName()
+	}
+
 	model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers))
 	if err != nil {
 		fmt.Println("Loading the model failed:", err.Error())