Michael Yang пре 1 година
родитељ
комит
a806b03f62
3 измењених фајлова са 22 додато и 39 уклоњено
  1. 0 1
      go.mod
  2. 0 2
      go.sum
  3. 22 36
      server/routes.go

+ 0 - 1
go.mod

@@ -39,7 +39,6 @@ require (
 	golang.org/x/arch v0.3.0 // indirect
 	golang.org/x/crypto v0.10.0 // indirect
 	golang.org/x/net v0.10.0 // indirect
-	golang.org/x/sync v0.3.0
 	golang.org/x/sys v0.10.0 // indirect
 	golang.org/x/term v0.10.0
 	golang.org/x/text v0.10.0 // indirect

+ 0 - 2
go.sum

@@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
-golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

+ 22 - 36
server/routes.go

@@ -16,7 +16,6 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/lithammer/fuzzysearch/fuzzy"
-	"golang.org/x/sync/errgroup"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
@@ -56,12 +55,8 @@ func generate(c *gin.Context) {
 		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
 	}
 
-	llm, err := llama.New(req.Model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-	defer llm.Close()
+	ch := make(chan any)
+	go stream(c, ch)
 
 	templateNames := make([]string, 0, len(templates.Templates()))
 	for _, template := range templates.Templates() {
@@ -79,24 +74,22 @@ func generate(c *gin.Context) {
 		req.Prompt = sb.String()
 	}
 
-	ch := make(chan any)
-	g, _ := errgroup.WithContext(c.Request.Context())
-	g.Go(func() error {
-		defer close(ch)
-		return llm.Predict(req.Prompt, func(s string) {
-			ch <- api.GenerateResponse{Response: s}
-		})
-	})
+	llm, err := llama.New(req.Model, req.Options)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer llm.Close()
 
-	g.Go(func() error {
-		stream(c, ch)
-		return nil
-	})
+	fn := func(s string) {
+		ch <- api.GenerateResponse{Response: s}
+	}
 
-	if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
+	if err := llm.Predict(req.Prompt, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
+
 }
 
 func pull(c *gin.Context) {
@@ -113,24 +106,17 @@ func pull(c *gin.Context) {
 	}
 
 	ch := make(chan any)
-	g, _ := errgroup.WithContext(c.Request.Context())
-	g.Go(func() error {
-		defer close(ch)
-		return saveModel(remote, func(total, completed int64) {
-			ch <- api.PullProgress{
-				Total:     total,
-				Completed: completed,
-				Percent:   float64(total) / float64(completed) * 100,
-			}
-		})
-	})
+	go stream(c, ch)
 
-	g.Go(func() error {
-		stream(c, ch)
-		return nil
-	})
+	fn := func(total, completed int64) {
+		ch <- api.PullProgress{
+			Total:     total,
+			Completed: completed,
+			Percent:   float64(total) / float64(completed) * 100,
+		}
+	}
 
-	if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
+	if err := saveModel(remote, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}