Browse Source

pull models

Bruce MacDonald 1 year ago
parent
commit
a6494f8211
5 changed files with 202 additions and 9 deletions
  1. 12 1
      api/client.go
  2. 8 0
      api/types.go
  3. 17 1
      cmd/cmd.go
  4. 140 0
      server/models.go
  5. 25 7
      server/routes.go

+ 12 - 1
api/client.go

@@ -35,7 +35,7 @@ func checkError(resp *http.Response, body []byte) error {
 	return apiError
 }
 
-func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func (data []byte)) error {
+func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
 	var reqBody io.Reader
 	var data []byte
 	var err error
@@ -140,3 +140,14 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback fu
 
 	return &res, nil
 }
+
+func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(token string)) (*PullResponse, error) {
+	var res PullResponse
+	if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(token []byte) {
+		callback(string(token))
+	}); err != nil {
+		return nil, err
+	}
+
+	return &res, nil
+}

+ 8 - 0
api/types.go

@@ -18,6 +18,14 @@ func (e Error) Error() string {
 	return e.Message
 }
 
+type PullRequest struct {
+	Model string `json:"model"`
+}
+
+type PullResponse struct {
+	Response string `json:"response"`
+}
+
 type GenerateRequest struct {
 	Model  string `json:"model"`
 	Prompt string `json:"prompt"`

+ 17 - 1
cmd/cmd.go

@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"context"
+	"fmt"
 	"log"
 	"net"
 	"net/http"
@@ -23,6 +24,21 @@ func cacheDir() string {
 	return path.Join(home, ".ollama")
 }
 
+func run(model string) error {
+	client, err := NewAPIClient()
+	if err != nil {
+		return err
+	}
+	pr := api.PullRequest{
+		Model: model,
+	}
+	callback := func(progress string) {
+		fmt.Println(progress)
+	}
+	_, err = client.Pull(context.Background(), &pr, callback)
+	return err
+}
+
 func serve() error {
 	sp := path.Join(cacheDir(), "ollama.sock")
 
@@ -94,7 +110,7 @@ func NewCLI() *cobra.Command {
 		Short: "Run a model",
 		Args:  cobra.ExactArgs(1),
 		RunE: func(cmd *cobra.Command, args []string) error {
-			return nil
+			return run(args[0])
 		},
 	}
 

+ 140 - 0
server/models.go

@@ -0,0 +1,140 @@
+package server
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"os"
+	"path"
+	"strconv"
+)
+
+// const directoryURL = "https://ollama.ai/api/models"
+const directoryURL = "https://raw.githubusercontent.com/jmorganca/ollama/go/models.json"
+
+type directoryCtxKey string
+
+var dirCtx directoryCtxKey = "directory"
+
+type Model struct {
+	Name             string `json:"name"`
+	DisplayName      string `json:"display_name"`
+	Parameters       string `json:"parameters"`
+	URL              string `json:"url"`
+	ShortDescription string `json:"short_description"`
+	Description      string `json:"description"`
+	PublishedBy      string `json:"published_by"`
+	OriginalAuthor   string `json:"original_author"`
+	OriginalURL      string `json:"original_url"`
+	License          string `json:"license"`
+}
+
+func pull(model string, progressCh chan<- string) error {
+	remote, err := getRemote(model)
+	if err != nil {
+		return fmt.Errorf("failed to pull model: %w", err)
+	}
+
+	return saveModel(remote, progressCh)
+}
+
+func getRemote(model string) (*Model, error) {
+	// resolve the model download from our directory
+	resp, err := http.Get(directoryURL)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get directory: %w", err)
+	}
+	defer resp.Body.Close()
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read directory: %w", err)
+	}
+	var models []Model
+	err = json.Unmarshal(body, &models)
+	if err != nil {
+		return nil, fmt.Errorf("failed to parse directory: %w", err)
+	}
+	for _, m := range models {
+		if m.Name == model {
+			return &m, nil
+		}
+	}
+	return nil, fmt.Errorf("model not found in directory: %s", model)
+}
+
+func saveModel(model *Model, progressCh chan<- string) error {
+	// this models cache directory is created by the server on startup
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return fmt.Errorf("failed to get home directory: %w", err)
+	}
+	modelsCache := path.Join(home, ".ollama", "models")
+
+	fileName := path.Join(modelsCache, model.Name+".bin")
+
+	client := &http.Client{}
+	req, err := http.NewRequest("GET", model.URL, nil)
+	if err != nil {
+		panic(err)
+	}
+	// check for resume
+	fileInfo, err := os.Stat(fileName)
+	if err != nil {
+		if !os.IsNotExist(err) {
+			return fmt.Errorf("failed to check resume model file: %w", err)
+		}
+		// file doesn't exist, create it now
+	} else {
+		req.Header.Add("Range", "bytes="+strconv.FormatInt(fileInfo.Size(), 10)+"-")
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to download model: %w", err)
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("failed to download model: %s", resp.Status)
+	}
+
+	out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
+	if err != nil {
+		panic(err)
+	}
+	defer out.Close()
+
+	totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length"))
+
+	buf := make([]byte, 1024)
+	totalBytes := 0
+
+	for {
+		n, err := resp.Body.Read(buf)
+
+		if err != nil && err != io.EOF {
+			return err
+		}
+
+		if n == 0 {
+			break
+		}
+
+		if _, err := out.Write(buf[:n]); err != nil {
+			return err
+		}
+
+		totalBytes += n
+
+		// send progress updates
+		progressCh <- fmt.Sprintf("Downloaded %d out of %d bytes (%.2f%%)", totalBytes, totalSize, float64(totalBytes)/float64(totalSize)*100)
+	}
+
+	// send completion message
+	progressCh <- "Download complete!"
+
+	return nil
+}

+ 25 - 7
server/routes.go

@@ -14,12 +14,6 @@ import (
 	"github.com/jmorganca/ollama/api"
 )
 
-func pull(c *gin.Context) {
-	// TODO
-
-	c.JSON(http.StatusOK, gin.H{"message": "ok"})
-}
-
 func generate(c *gin.Context) {
 	// TODO: these should be request parameters
 	gpulayers := 1
@@ -65,7 +59,31 @@ func generate(c *gin.Context) {
 func Serve(ln net.Listener) error {
 	r := gin.Default()
 
-	r.POST("api/pull", pull)
+	r.POST("api/pull", func(c *gin.Context) {
+		var req api.PullRequest
+		if err := c.ShouldBindJSON(&req); err != nil {
+			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+			return
+		}
+
+		progressCh := make(chan string)
+		go func() {
+			defer close(progressCh)
+			if err := pull(req.Model, progressCh); err != nil {
+				c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+				return
+			}
+		}()
+
+		c.Stream(func(w io.Writer) bool {
+			progress, ok := <-progressCh
+			if !ok {
+				return false
+			}
+			c.SSEvent("progress", progress)
+			return true
+		})
+	})
 
 	r.POST("/api/generate", generate)