|
@@ -16,7 +16,6 @@ import (
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"github.com/lithammer/fuzzysearch/fuzzy"
|
|
|
- "golang.org/x/sync/errgroup"
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
|
"github.com/jmorganca/ollama/llama"
|
|
@@ -56,12 +55,8 @@ func generate(c *gin.Context) {
|
|
|
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
|
|
|
}
|
|
|
|
|
|
- llm, err := llama.New(req.Model, req.Options)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
- defer llm.Close()
|
|
|
+ ch := make(chan any)
|
|
|
+ go stream(c, ch)
|
|
|
|
|
|
templateNames := make([]string, 0, len(templates.Templates()))
|
|
|
for _, template := range templates.Templates() {
|
|
@@ -79,24 +74,22 @@ func generate(c *gin.Context) {
|
|
|
req.Prompt = sb.String()
|
|
|
}
|
|
|
|
|
|
- ch := make(chan any)
|
|
|
- g, _ := errgroup.WithContext(c.Request.Context())
|
|
|
- g.Go(func() error {
|
|
|
- defer close(ch)
|
|
|
- return llm.Predict(req.Prompt, func(s string) {
|
|
|
- ch <- api.GenerateResponse{Response: s}
|
|
|
- })
|
|
|
- })
|
|
|
+ llm, err := llama.New(req.Model, req.Options)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer llm.Close()
|
|
|
|
|
|
- g.Go(func() error {
|
|
|
- stream(c, ch)
|
|
|
- return nil
|
|
|
- })
|
|
|
+ fn := func(s string) {
|
|
|
+ ch <- api.GenerateResponse{Response: s}
|
|
|
+ }
|
|
|
|
|
|
- if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
|
|
+ if err := llm.Predict(req.Prompt, fn); err != nil {
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
}
|
|
|
|
|
|
func pull(c *gin.Context) {
|
|
@@ -113,24 +106,17 @@ func pull(c *gin.Context) {
|
|
|
}
|
|
|
|
|
|
ch := make(chan any)
|
|
|
- g, _ := errgroup.WithContext(c.Request.Context())
|
|
|
- g.Go(func() error {
|
|
|
- defer close(ch)
|
|
|
- return saveModel(remote, func(total, completed int64) {
|
|
|
- ch <- api.PullProgress{
|
|
|
- Total: total,
|
|
|
- Completed: completed,
|
|
|
- Percent: float64(total) / float64(completed) * 100,
|
|
|
- }
|
|
|
- })
|
|
|
- })
|
|
|
+ go stream(c, ch)
|
|
|
|
|
|
- g.Go(func() error {
|
|
|
- stream(c, ch)
|
|
|
- return nil
|
|
|
- })
|
|
|
+ fn := func(total, completed int64) {
|
|
|
+ ch <- api.PullProgress{
|
|
|
+ Total: total,
|
|
|
+ Completed: completed,
|
|
|
+ Percent: float64(total) / float64(completed) * 100,
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
|
|
+ if err := saveModel(remote, fn); err != nil {
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
return
|
|
|
}
|