|
@@ -79,35 +79,54 @@ func generate(c *gin.Context) {
|
|
req.Prompt = sb.String()
|
|
req.Prompt = sb.String()
|
|
}
|
|
}
|
|
|
|
|
|
- ch := make(chan string)
|
|
|
|
|
|
+ ch := make(chan any)
|
|
g, _ := errgroup.WithContext(c.Request.Context())
|
|
g, _ := errgroup.WithContext(c.Request.Context())
|
|
g.Go(func() error {
|
|
g.Go(func() error {
|
|
defer close(ch)
|
|
defer close(ch)
|
|
return llm.Predict(req.Prompt, func(s string) {
|
|
return llm.Predict(req.Prompt, func(s string) {
|
|
- ch <- s
|
|
|
|
|
|
+ ch <- api.GenerateResponse{Response: s}
|
|
})
|
|
})
|
|
})
|
|
})
|
|
|
|
|
|
g.Go(func() error {
|
|
g.Go(func() error {
|
|
- c.Stream(func(w io.Writer) bool {
|
|
|
|
- s, ok := <-ch
|
|
|
|
- if !ok {
|
|
|
|
- return false
|
|
|
|
- }
|
|
|
|
|
|
+ stream(c, ch)
|
|
|
|
+ return nil
|
|
|
|
+ })
|
|
|
|
|
|
- bts, err := json.Marshal(api.GenerateResponse{Response: s})
|
|
|
|
- if err != nil {
|
|
|
|
- return false
|
|
|
|
- }
|
|
|
|
|
|
+ if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
|
|
- bts = append(bts, '\n')
|
|
|
|
- if _, err := w.Write(bts); err != nil {
|
|
|
|
- return false
|
|
|
|
- }
|
|
|
|
|
|
+func pull(c *gin.Context) {
|
|
|
|
+ var req api.PullRequest
|
|
|
|
+ if err := c.ShouldBindJSON(&req); err != nil {
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ remote, err := getRemote(req.Model)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
|
|
- return true
|
|
|
|
|
|
+ ch := make(chan any)
|
|
|
|
+ g, _ := errgroup.WithContext(c.Request.Context())
|
|
|
|
+ g.Go(func() error {
|
|
|
|
+ defer close(ch)
|
|
|
|
+ return saveModel(remote, func(total, completed int64) {
|
|
|
|
+ ch <- api.PullProgress{
|
|
|
|
+ Total: total,
|
|
|
|
+ Completed: completed,
|
|
|
|
+ Percent: float64(total) / float64(completed) * 100,
|
|
|
|
+ }
|
|
})
|
|
})
|
|
|
|
+ })
|
|
|
|
|
|
|
|
+ g.Go(func() error {
|
|
|
|
+ stream(c, ch)
|
|
return nil
|
|
return nil
|
|
})
|
|
})
|
|
|
|
|
|
@@ -124,47 +143,7 @@ func Serve(ln net.Listener) error {
|
|
c.String(http.StatusOK, "Ollama is running")
|
|
c.String(http.StatusOK, "Ollama is running")
|
|
})
|
|
})
|
|
|
|
|
|
- r.POST("api/pull", func(c *gin.Context) {
|
|
|
|
- var req api.PullRequest
|
|
|
|
- if err := c.ShouldBindJSON(&req); err != nil {
|
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- progressCh := make(chan api.PullProgress)
|
|
|
|
- go func() {
|
|
|
|
- defer close(progressCh)
|
|
|
|
- if err := pull(req.Model, progressCh); err != nil {
|
|
|
|
- var opError *net.OpError
|
|
|
|
- if errors.As(err, &opError) {
|
|
|
|
- c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- }()
|
|
|
|
-
|
|
|
|
- c.Stream(func(w io.Writer) bool {
|
|
|
|
- progress, ok := <-progressCh
|
|
|
|
- if !ok {
|
|
|
|
- return false
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- bts, err := json.Marshal(progress)
|
|
|
|
- if err != nil {
|
|
|
|
- return false
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- bts = append(bts, '\n')
|
|
|
|
- if _, err := w.Write(bts); err != nil {
|
|
|
|
- return false
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return true
|
|
|
|
- })
|
|
|
|
- })
|
|
|
|
-
|
|
|
|
|
|
+ r.POST("api/pull", pull)
|
|
r.POST("/api/generate", generate)
|
|
r.POST("/api/generate", generate)
|
|
|
|
|
|
log.Printf("Listening on %s", ln.Addr())
|
|
log.Printf("Listening on %s", ln.Addr())
|
|
@@ -186,3 +165,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
|
|
|
|
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+func stream(c *gin.Context, ch chan any) {
|
|
|
|
+ c.Stream(func(w io.Writer) bool {
|
|
|
|
+ val, ok := <-ch
|
|
|
|
+ if !ok {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ bts, err := json.Marshal(val)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ bts = append(bts, '\n')
|
|
|
|
+ if _, err := w.Write(bts); err != nil {
|
|
|
|
+ return false
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return true
|
|
|
|
+ })
|
|
|
|
+}
|