소스 검색

fix race

block on write which only returns when the channel is closed. this is
contrary to the previous arrangement where the handler may return but
the stream hasn't finished writing. it can lead to the client receiving
unexpected responses (since the request has been handled) or worst case
a nil-pointer dereference as the stream tries to flush a nil writer
Michael Yang 1 년 전
부모
커밋
5ade3db040
1개의 변경된 파일26개의 추가작업 그리고 31개의 파일을 삭제
  1. 26 31
      server/routes.go

+ 26 - 31
server/routes.go

@@ -58,9 +58,6 @@ func generate(c *gin.Context) {
 		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
 	}
 
-	ch := make(chan any)
-	go stream(c, ch)
-
 	templateNames := make([]string, 0, len(templates.Templates()))
 	for _, template := range templates.Templates() {
 		templateNames = append(templateNames, template.Name())
@@ -84,21 +81,21 @@ func generate(c *gin.Context) {
 	}
 	defer llm.Close()
 
-	fn := func(r api.GenerateResponse) {
-		r.Model = req.Model
-		r.CreatedAt = time.Now().UTC()
-		if r.Done {
-			r.TotalDuration = time.Since(start)
-		}
-
-		ch <- r
-	}
-
-	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
+			r.Model = req.Model
+			r.CreatedAt = time.Now().UTC()
+			if r.Done {
+				r.TotalDuration = time.Since(start)
+			}
+
+			ch <- r
+		})
+	}()
 
+	streamResponse(c, ch)
 }
 
 func pull(c *gin.Context) {
@@ -133,20 +130,18 @@ func pull(c *gin.Context) {
 	}
 
 	ch := make(chan any)
-	go stream(c, ch)
-
-	fn := func(total, completed int64) {
-		ch <- api.PullProgress{
-			Total:     total,
-			Completed: completed,
-			Percent:   float64(completed) / float64(total) * 100,
-		}
-	}
+	go func() {
+		defer close(ch)
+		saveModel(remote, func(total, completed int64) {
+			ch <- api.PullProgress{
+				Total:     total,
+				Completed: completed,
+				Percent:   float64(completed) / float64(total) * 100,
+			}
+		})
+	}()
 
-	if err := saveModel(remote, fn); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	streamResponse(c, ch)
 }
 
 func Serve(ln net.Listener) error {
@@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
 	return
 }
 
-func stream(c *gin.Context, ch chan any) {
+func streamResponse(c *gin.Context, ch chan any) {
 	c.Stream(func(w io.Writer) bool {
 		val, ok := <-ch
 		if !ok {