Bläddra i källkod

if directory cannot be resolved, do not fail

Bruce MacDonald 1 år sedan
förälder
incheckning
0bee4a8c07
3 ändrade filer med 36 tillägg och 1 borttagningar
  1. 5 0
      api/client.go
  2. 1 0
      api/types.go
  3. 30 1
      server/routes.go

+ 5 - 0
api/client.go

@@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 				return err
 			}
 
+			if resp.Error.Message != "" {
+				// couldn't pull the model from the directory, proceed anyway
+				return nil
+			}
+
 			return fn(resp)
 		}),
 	)

+ 1 - 0
api/types.go

@@ -26,6 +26,7 @@ type PullProgress struct {
 	Total     int64   `json:"total"`
 	Completed int64   `json:"completed"`
 	Percent   float64 `json:"percent"`
+	Error     Error   `json:"error"`
 }
 
 type GenerateRequest struct {

+ 30 - 1
server/routes.go

@@ -3,12 +3,14 @@ package server
 import (
 	"embed"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log"
 	"math"
 	"net"
 	"net/http"
+	"os"
 	"path"
 	"runtime"
 	"strings"
@@ -25,6 +27,15 @@ import (
 var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
 
+func cacheDir() string {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		panic(err)
+	}
+
+	return path.Join(home, ".ollama")
+}
+
 func generate(c *gin.Context) {
 	var req api.GenerateRequest
 	req.ModelOptions = api.DefaultModelOptions
@@ -37,9 +48,16 @@ func generate(c *gin.Context) {
 	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
 		req.Model = remoteModel.FullName()
 	}
+	if _, err := os.Stat(req.Model); err != nil {
+		if !errors.Is(err, os.ErrNotExist) {
+			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+			return
+		}
+		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
+	}
 
 	modelOpts := getModelOpts(req)
-	modelOpts.NGPULayers = 1  // hard-code this for now
+	modelOpts.NGPULayers = 1 // hard-code this for now
 
 	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
@@ -118,6 +136,17 @@ func Serve(ln net.Listener) error {
 		go func() {
 			defer close(progressCh)
 			if err := pull(req.Model, progressCh); err != nil {
+				var opError *net.OpError
+				if errors.As(err, &opError) {
+					result := api.PullProgress{
+						Error: api.Error{
+							Code:    http.StatusBadGateway,
+							Message: "failed to get models from directory",
+						},
+					}
+					c.JSON(http.StatusBadGateway, result)
+					return
+				}
 				c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 				return
 			}