Browse Source

Merge pull request #70 from jmorganca/offline-fixes

offline fixes
Michael Yang 1 year ago
parent
commit
7226980fb6
6 changed files with 143 additions and 144 deletions
  1. 23 5
      api/client.go
  2. 14 5
      cmd/cmd.go
  3. 0 1
      go.mod
  4. 0 2
      go.sum
  5. 45 56
      server/models.go
  6. 61 75
      server/routes.go

+ 23 - 5
api/client.go

@@ -10,6 +10,20 @@ import (
 	"net/url"
 )
 
+type StatusError struct {
+	StatusCode int
+	Status     string
+	Message    string
+}
+
+func (e StatusError) Error() string {
+	if e.Message != "" {
+		return fmt.Sprintf("%s: %s", e.Status, e.Message)
+	}
+
+	return e.Status
+}
+
 type Client struct {
 	base url.URL
 }
@@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client {
 	}
 }
 
-func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error {
+func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	var buf *bytes.Buffer
 	if data != nil {
 		bts, err := json.Marshal(data)
@@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
 	scanner := bufio.NewScanner(response.Body)
 	for scanner.Scan() {
 		var errorResponse struct {
-			Error string `json:"error"`
+			Error string `json:"error,omitempty"`
 		}
 
 		bts := scanner.Bytes()
@@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
 			return fmt.Errorf("unmarshal: %w", err)
 		}
 
-		if len(errorResponse.Error) > 0 {
-			return fmt.Errorf("stream: %s", errorResponse.Error)
+		if response.StatusCode >= 400 {
+			return StatusError{
+				StatusCode: response.StatusCode,
+				Status:     response.Status,
+				Message:    errorResponse.Error,
+			}
 		}
 
-		if err := callback(bts); err != nil {
+		if err := fn(bts); err != nil {
 			return err
 		}
 	}

+ 14 - 5
cmd/cmd.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"net/http"
 	"os"
 	"path"
 	"strings"
@@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 		if err := pull(args[0]); err != nil {
-			return err
+			var apiStatusError api.StatusError
+			if !errors.As(err, &apiStatusError) {
+				return err
+			}
+
+			if apiStatusError.StatusCode != http.StatusBadGateway {
+				return err
+			}
 		}
 	case err != nil:
 		return err
@@ -50,11 +58,12 @@ func pull(model string) error {
 		context.Background(),
 		&api.PullRequest{Model: model},
 		func(progress api.PullProgress) error {
-			if bar == nil && progress.Percent == 100 {
-				// already downloaded
-				return nil
-			}
 			if bar == nil {
+				if progress.Percent == 100 {
+					// already downloaded
+					return nil
+				}
+
 				bar = progressbar.DefaultBytes(progress.Total)
 			}
 

+ 0 - 1
go.mod

@@ -39,7 +39,6 @@ require (
 	golang.org/x/arch v0.3.0 // indirect
 	golang.org/x/crypto v0.10.0 // indirect
 	golang.org/x/net v0.10.0 // indirect
-	golang.org/x/sync v0.3.0
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect

+ 0 - 2
go.sum

@@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
-golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

+ 45 - 56
server/models.go

@@ -2,14 +2,13 @@ package server
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"os"
 	"path"
 	"strconv"
-
-	"github.com/jmorganca/ollama/api"
 )
 
 const directoryURL = "https://ollama.ai/api/models"
@@ -36,12 +35,12 @@ func (m *Model) FullName() string {
 	return path.Join(home, ".ollama", "models", m.Name+".bin")
 }
 
-func pull(model string, progressCh chan<- api.PullProgress) error {
-	remote, err := getRemote(model)
-	if err != nil {
-		return fmt.Errorf("failed to pull model: %w", err)
-	}
-	return saveModel(remote, progressCh)
+func (m *Model) TempFile() string {
+	fullName := m.FullName()
+	return path.Join(
+		path.Dir(fullName),
+		fmt.Sprintf(".%s.part", path.Base(fullName)),
+	)
 }
 
 func getRemote(model string) (*Model, error) {
@@ -68,7 +67,7 @@ func getRemote(model string) (*Model, error) {
 	return nil, fmt.Errorf("model not found in directory: %s", model)
 }
 
-func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
+func saveModel(model *Model, fn func(total, completed int64)) error {
 	// this models cache directory is created by the server on startup
 
 	client := &http.Client{}
@@ -76,41 +75,45 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 	if err != nil {
 		return fmt.Errorf("failed to download model: %w", err)
 	}
-	// check for resume
-	alreadyDownloaded := int64(0)
-	fileInfo, err := os.Stat(model.FullName())
-	if err != nil {
-		if !os.IsNotExist(err) {
-			return fmt.Errorf("failed to check resume model file: %w", err)
-		}
-		// file doesn't exist, create it now
-	} else {
-		alreadyDownloaded = fileInfo.Size()
-		req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
+
+	// check if completed file exists
+	fi, err := os.Stat(model.FullName())
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		fn(fi.Size(), fi.Size())
+		return nil
+	}
+
+	var size int64
+
+	// completed file doesn't exist, check partial file
+	fi, err = os.Stat(model.TempFile())
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		size = fi.Size()
 	}
 
+	req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
+
 	resp, err := client.Do(req)
 	if err != nil {
 		return fmt.Errorf("failed to download model: %w", err)
 	}
-
 	defer resp.Body.Close()
 
-	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
-		// already downloaded
-		progressCh <- api.PullProgress{
-			Total:     alreadyDownloaded,
-			Completed: alreadyDownloaded,
-			Percent:   100,
-		}
-		return nil
-	}
-
-	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
+	if resp.StatusCode >= 400 {
 		return fmt.Errorf("failed to download model: %s", resp.Status)
 	}
 
-	out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
+	out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 	if err != nil {
 		panic(err)
 	}
@@ -118,37 +121,23 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 
 	totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
-	buf := make([]byte, 1024)
-	totalBytes := alreadyDownloaded
-	totalSize += alreadyDownloaded
+	totalBytes := size
+	totalSize += size
 
 	for {
-		n, err := resp.Body.Read(buf)
-		if err != nil && err != io.EOF {
+		n, err := io.CopyN(out, resp.Body, 8192)
+		if err != nil && !errors.Is(err, io.EOF) {
 			return err
 		}
+
 		if n == 0 {
 			break
 		}
-		if _, err := out.Write(buf[:n]); err != nil {
-			return err
-		}
-
-		totalBytes += int64(n)
-
-		// send progress updates
-		progressCh <- api.PullProgress{
-			Total:     totalSize,
-			Completed: totalBytes,
-			Percent:   float64(totalBytes) / float64(totalSize) * 100,
-		}
-	}
 
-	progressCh <- api.PullProgress{
-		Total:     totalSize,
-		Completed: totalSize,
-		Percent:   100,
+		totalBytes += n
+		fn(totalSize, totalBytes)
 	}
 
-	return nil
+	fn(totalSize, totalSize)
+	return os.Rename(model.TempFile(), model.FullName())
 }

+ 61 - 75
server/routes.go

@@ -16,7 +16,6 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/lithammer/fuzzysearch/fuzzy"
-	"golang.org/x/sync/errgroup"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
@@ -56,12 +55,8 @@ func generate(c *gin.Context) {
 		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
 	}
 
-	llm, err := llama.New(req.Model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-	defer llm.Close()
+	ch := make(chan any)
+	go stream(c, ch)
 
 	templateNames := make([]string, 0, len(templates.Templates()))
 	for _, template := range templates.Templates() {
@@ -79,39 +74,49 @@ func generate(c *gin.Context) {
 		req.Prompt = sb.String()
 	}
 
-	ch := make(chan string)
-	g, _ := errgroup.WithContext(c.Request.Context())
-	g.Go(func() error {
-		defer close(ch)
-		return llm.Predict(req.Prompt, func(s string) {
-			ch <- s
-		})
-	})
+	llm, err := llama.New(req.Model, req.Options)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer llm.Close()
 
-	g.Go(func() error {
-		c.Stream(func(w io.Writer) bool {
-			s, ok := <-ch
-			if !ok {
-				return false
-			}
+	fn := func(s string) {
+		ch <- api.GenerateResponse{Response: s}
+	}
 
-			bts, err := json.Marshal(api.GenerateResponse{Response: s})
-			if err != nil {
-				return false
-			}
+	if err := llm.Predict(req.Prompt, fn); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
 
-			bts = append(bts, '\n')
-			if _, err := w.Write(bts); err != nil {
-				return false
-			}
+}
 
-			return true
-		})
+func pull(c *gin.Context) {
+	var req api.PullRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
 
-		return nil
-	})
+	remote, err := getRemote(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
+		return
+	}
 
-	if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
+	ch := make(chan any)
+	go stream(c, ch)
+
+	fn := func(total, completed int64) {
+		ch <- api.PullProgress{
+			Total:     total,
+			Completed: completed,
+			Percent:   float64(total) / float64(completed) * 100,
+		}
+	}
+
+	if err := saveModel(remote, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -124,47 +129,7 @@ func Serve(ln net.Listener) error {
 		c.String(http.StatusOK, "Ollama is running")
 	})
 
-	r.POST("api/pull", func(c *gin.Context) {
-		var req api.PullRequest
-		if err := c.ShouldBindJSON(&req); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-
-		progressCh := make(chan api.PullProgress)
-		go func() {
-			defer close(progressCh)
-			if err := pull(req.Model, progressCh); err != nil {
-				var opError *net.OpError
-				if errors.As(err, &opError) {
-					c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
-					return
-				}
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-				return
-			}
-		}()
-
-		c.Stream(func(w io.Writer) bool {
-			progress, ok := <-progressCh
-			if !ok {
-				return false
-			}
-
-			bts, err := json.Marshal(progress)
-			if err != nil {
-				return false
-			}
-
-			bts = append(bts, '\n')
-			if _, err := w.Write(bts); err != nil {
-				return false
-			}
-
-			return true
-		})
-	})
-
+	r.POST("api/pull", pull)
 	r.POST("/api/generate", generate)
 
 	log.Printf("Listening on %s", ln.Addr())
@@ -186,3 +151,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
 
 	return
 }
+
+func stream(c *gin.Context, ch chan any) {
+	c.Stream(func(w io.Writer) bool {
+		val, ok := <-ch
+		if !ok {
+			return false
+		}
+
+		bts, err := json.Marshal(val)
+		if err != nil {
+			return false
+		}
+
+		bts = append(bts, '\n')
+		if _, err := w.Write(bts); err != nil {
+			return false
+		}
+
+		return true
+	})
+}