Prechádzať zdrojové kódy

if directory cannot be resolved, do not fail

Bruce MacDonald 1 rok pred
rodič
commit
61dd87bd90
3 zmenil súbory, kde vykonal 43 pridanie a 2 odobranie
  1. 1 0
      api/types.go
  2. 5 0
      cmd/cmd.go
  3. 37 2
      server/routes.go

+ 1 - 0
api/types.go

@@ -26,6 +26,7 @@ type PullProgress struct {
 	Total     int64   `json:"total"`
 	Completed int64   `json:"completed"`
 	Percent   float64 `json:"percent"`
+	Error     Error   `json:"error"`
 }
 
 type GenerateRequest struct {

+ 5 - 0
cmd/cmd.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"net/http"
 	"os"
 	"path"
 	"strings"
@@ -50,6 +51,10 @@ func pull(model string) error {
 		context.Background(),
 		&api.PullRequest{Model: model},
 		func(progress api.PullProgress) error {
+			if progress.Error.Code == http.StatusBadGateway {
+				// couldn't pull the model from the directory, proceed in offline mode
+				return nil
+			}
 			if bar == nil && progress.Percent == 100 {
 				// already downloaded
 				return nil

+ 37 - 2
server/routes.go

@@ -3,12 +3,14 @@ package server
 import (
 	"embed"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"log"
 	"math"
 	"net"
 	"net/http"
+	"os"
 	"path"
 	"runtime"
 	"strings"
@@ -25,6 +27,15 @@ import (
 var templatesFS embed.FS
 var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
 
+func cacheDir() string {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		panic(err)
+	}
+
+	return path.Join(home, ".ollama")
+}
+
 func generate(c *gin.Context) {
 	var req api.GenerateRequest
 	req.ModelOptions = api.DefaultModelOptions
@@ -34,12 +45,25 @@ func generate(c *gin.Context) {
 		return
 	}
 
-	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
+	remoteModel, err := getRemote(req.Model)
+	if err != nil {
+		// couldn't check the directory, proceed in offline mode
+		_, err := os.Stat(req.Model)
+		if err != nil {
+			if !os.IsNotExist(err) {
+				c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
+				return
+			}
+			// couldn't find the model file, try setting the model to the cache directory
+			req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
+		}
+	}
+	if remoteModel != nil {
 		req.Model = remoteModel.FullName()
 	}
 
 	modelOpts := getModelOpts(req)
-	modelOpts.NGPULayers = 1  // hard-code this for now
+	modelOpts.NGPULayers = 1 // hard-code this for now
 
 	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
@@ -118,6 +142,17 @@ func Serve(ln net.Listener) error {
 		go func() {
 			defer close(progressCh)
 			if err := pull(req.Model, progressCh); err != nil {
+				var opError *net.OpError
+				if errors.As(err, &opError) {
+					result := api.PullProgress{
+						Error: api.Error{
+							Code:    http.StatusBadGateway,
+							Message: "failed to get models from directory",
+						},
+					}
+					c.JSON(http.StatusBadGateway, result)
+					return
+				}
 				c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 				return
 			}