فهرست منبع

common stream producer

Michael Yang 1 سال پیش
والد
کامیت
2a66a1164a
2فایلهای تغییر یافته به همراه61 افزوده شده و 85 حذف شده
  1. 4 28
      server/models.go
  2. 57 57
      server/routes.go

+ 4 - 28
server/models.go

@@ -8,8 +8,6 @@ import (
 	"os"
 	"path"
 	"strconv"
-
-	"github.com/jmorganca/ollama/api"
 )
 
 const directoryURL = "https://ollama.ai/api/models"
@@ -36,14 +34,6 @@ func (m *Model) FullName() string {
 	return path.Join(home, ".ollama", "models", m.Name+".bin")
 }
 
-func pull(model string, progressCh chan<- api.PullProgress) error {
-	remote, err := getRemote(model)
-	if err != nil {
-		return fmt.Errorf("failed to pull model: %w", err)
-	}
-	return saveModel(remote, progressCh)
-}
-
 func getRemote(model string) (*Model, error) {
 	// resolve the model download from our directory
 	resp, err := http.Get(directoryURL)
@@ -68,7 +58,7 @@ func getRemote(model string) (*Model, error) {
 	return nil, fmt.Errorf("model not found in directory: %s", model)
 }
 
-func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
+func saveModel(model *Model, fn func(total, completed int64)) error {
 	// this models cache directory is created by the server on startup
 
 	client := &http.Client{}
@@ -98,11 +88,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 
 	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
 		// already downloaded
-		progressCh <- api.PullProgress{
-			Total:     alreadyDownloaded,
-			Completed: alreadyDownloaded,
-			Percent:   100,
-		}
+		fn(alreadyDownloaded, alreadyDownloaded)
 		return nil
 	}
 
@@ -136,19 +122,9 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 
 		totalBytes += int64(n)
 
-		// send progress updates
-		progressCh <- api.PullProgress{
-			Total:     totalSize,
-			Completed: totalBytes,
-			Percent:   float64(totalBytes) / float64(totalSize) * 100,
-		}
-	}
-
-	progressCh <- api.PullProgress{
-		Total:     totalSize,
-		Completed: totalSize,
-		Percent:   100,
+		fn(totalSize, totalBytes)
 	}
 
+	fn(totalSize, totalSize)
 	return nil
 }

+ 57 - 57
server/routes.go

@@ -79,35 +79,54 @@ func generate(c *gin.Context) {
 		req.Prompt = sb.String()
 	}
 
-	ch := make(chan string)
+	ch := make(chan any)
 	g, _ := errgroup.WithContext(c.Request.Context())
 	g.Go(func() error {
 		defer close(ch)
 		return llm.Predict(req.Prompt, func(s string) {
-			ch <- s
+			ch <- api.GenerateResponse{Response: s}
 		})
 	})
 
 	g.Go(func() error {
-		c.Stream(func(w io.Writer) bool {
-			s, ok := <-ch
-			if !ok {
-				return false
-			}
+		stream(c, ch)
+		return nil
+	})
 
-			bts, err := json.Marshal(api.GenerateResponse{Response: s})
-			if err != nil {
-				return false
-			}
+	if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+}
 
-			bts = append(bts, '\n')
-			if _, err := w.Write(bts); err != nil {
-				return false
-			}
+func pull(c *gin.Context) {
+	var req api.PullRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	remote, err := getRemote(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
 
-			return true
+	ch := make(chan any)
+	g, _ := errgroup.WithContext(c.Request.Context())
+	g.Go(func() error {
+		defer close(ch)
+		return saveModel(remote, func(total, completed int64) {
+			ch <- api.PullProgress{
+				Total:     total,
+				Completed: completed,
+				Percent:   float64(total) / float64(completed) * 100,
+			}
 		})
+	})
 
+	g.Go(func() error {
+		stream(c, ch)
 		return nil
 	})
 
@@ -124,47 +143,7 @@ func Serve(ln net.Listener) error {
 		c.String(http.StatusOK, "Ollama is running")
 	})
 
-	r.POST("api/pull", func(c *gin.Context) {
-		var req api.PullRequest
-		if err := c.ShouldBindJSON(&req); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-
-		progressCh := make(chan api.PullProgress)
-		go func() {
-			defer close(progressCh)
-			if err := pull(req.Model, progressCh); err != nil {
-				var opError *net.OpError
-				if errors.As(err, &opError) {
-					c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
-					return
-				}
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-				return
-			}
-		}()
-
-		c.Stream(func(w io.Writer) bool {
-			progress, ok := <-progressCh
-			if !ok {
-				return false
-			}
-
-			bts, err := json.Marshal(progress)
-			if err != nil {
-				return false
-			}
-
-			bts = append(bts, '\n')
-			if _, err := w.Write(bts); err != nil {
-				return false
-			}
-
-			return true
-		})
-	})
-
+	r.POST("api/pull", pull)
 	r.POST("/api/generate", generate)
 
 	log.Printf("Listening on %s", ln.Addr())
@@ -186,3 +165,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
 
 	return
 }
+
+func stream(c *gin.Context, ch chan any) {
+	c.Stream(func(w io.Writer) bool {
+		val, ok := <-ch
+		if !ok {
+			return false
+		}
+
+		bts, err := json.Marshal(val)
+		if err != nil {
+			return false
+		}
+
+		bts = append(bts, '\n')
+		if _, err := w.Write(bts); err != nil {
+			return false
+		}
+
+		return true
+	})
+}