فهرست منبع

Merge pull request #81 from jmorganca/fix-race-2

fix race
Michael Yang 1 سال پیش
والد
کامیت
567e74e7d7
1فایلهای تغییر یافته به همراه26 افزوده شده و 31 حذف شده
  1. 26 31
      server/routes.go

+ 26 - 31
server/routes.go

@@ -58,9 +58,6 @@ func generate(c *gin.Context) {
 		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
 	}
 
-	ch := make(chan any)
-	go stream(c, ch)
-
 	templateNames := make([]string, 0, len(templates.Templates()))
 	for _, template := range templates.Templates() {
 		templateNames = append(templateNames, template.Name())
@@ -84,21 +81,21 @@ func generate(c *gin.Context) {
 	}
 	defer llm.Close()
 
-	fn := func(r api.GenerateResponse) {
-		r.Model = req.Model
-		r.CreatedAt = time.Now().UTC()
-		if r.Done {
-			r.TotalDuration = time.Since(start)
-		}
-
-		ch <- r
-	}
-
-	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
+			r.Model = req.Model
+			r.CreatedAt = time.Now().UTC()
+			if r.Done {
+				r.TotalDuration = time.Since(start)
+			}
+
+			ch <- r
+		})
+	}()
 
+	streamResponse(c, ch)
 }
 
 func pull(c *gin.Context) {
@@ -133,20 +130,18 @@ func pull(c *gin.Context) {
 	}
 
 	ch := make(chan any)
-	go stream(c, ch)
-
-	fn := func(total, completed int64) {
-		ch <- api.PullProgress{
-			Total:     total,
-			Completed: completed,
-			Percent:   float64(completed) / float64(total) * 100,
-		}
-	}
+	go func() {
+		defer close(ch)
+		saveModel(remote, func(total, completed int64) {
+			ch <- api.PullProgress{
+				Total:     total,
+				Completed: completed,
+				Percent:   float64(completed) / float64(total) * 100,
+			}
+		})
+	}()
 
-	if err := saveModel(remote, fn); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	streamResponse(c, ch)
 }
 
 func Serve(ln net.Listener) error {
@@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
 	return
 }
 
-func stream(c *gin.Context, ch chan any) {
+func streamResponse(c *gin.Context, ch chan any) {
 	c.Stream(func(w io.Writer) bool {
 		val, ok := <-ch
 		if !ok {