|
@@ -1,13 +1,13 @@
|
|
|
package server
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"cmp"
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
- "io/fs"
|
|
|
"log/slog"
|
|
|
"net"
|
|
|
"net/http"
|
|
@@ -67,163 +67,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
|
|
|
return opts, nil
|
|
|
}
|
|
|
|
|
|
-func isSupportedImageType(image []byte) bool {
|
|
|
- contentType := http.DetectContentType(image)
|
|
|
- allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
|
|
|
- return slices.Contains(allowedTypes, contentType)
|
|
|
-}
|
|
|
-
|
|
|
-func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
- checkpointStart := time.Now()
|
|
|
- var req api.GenerateRequest
|
|
|
- err := c.ShouldBindJSON(&req)
|
|
|
-
|
|
|
- switch {
|
|
|
- case errors.Is(err, io.EOF):
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
- return
|
|
|
- case err != nil:
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
|
|
|
+ if name == "" {
|
|
|
+ return nil, errors.New("model is required")
|
|
|
}
|
|
|
|
|
|
- // validate the request
|
|
|
- switch {
|
|
|
- case req.Model == "":
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
- return
|
|
|
- case len(req.Format) > 0 && req.Format != "json":
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
|
|
- return
|
|
|
- case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- for _, img := range req.Images {
|
|
|
- if !isSupportedImageType(img) {
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
|
|
- return
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- model, err := GetModel(req.Model)
|
|
|
+ model, err := GetModel(name)
|
|
|
if err != nil {
|
|
|
- var pErr *fs.PathError
|
|
|
- if errors.As(err, &pErr) {
|
|
|
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
- return
|
|
|
- }
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- if !model.Has(CapabilityCompletion) {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
|
|
|
- return
|
|
|
+ if err := model.CheckCapabilities(caps...); err != nil {
|
|
|
+ return nil, fmt.Errorf("%s %w", name, err)
|
|
|
}
|
|
|
|
|
|
- opts, err := modelOptions(model, req.Options)
|
|
|
+ opts, err := modelOptions(model, requestOpts)
|
|
|
if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
|
|
+ runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
|
|
|
var runner *runnerRef
|
|
|
select {
|
|
|
- case runner = <-rCh:
|
|
|
- case err = <-eCh:
|
|
|
- handleErrorResponse(c, err)
|
|
|
- return
|
|
|
+ case runner = <-runnerCh:
|
|
|
+ case err = <-errCh:
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- // an empty request loads the model
|
|
|
- // note: for a short while template was used in lieu
|
|
|
- // of `raw` mode so we need to check for it too
|
|
|
- if req.Prompt == "" && req.Template == "" && req.System == "" {
|
|
|
- c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
- CreatedAt: time.Now().UTC(),
|
|
|
- Model: req.Model,
|
|
|
- Done: true,
|
|
|
- DoneReason: "load",
|
|
|
- })
|
|
|
+ return runner, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
+ var req api.GenerateRequest
|
|
|
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
+ return
|
|
|
+ } else if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- tmpl, err := template.Parse(req.Template)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ if req.Format != "" && req.Format != "json" {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
|
|
|
+ return
|
|
|
+ } else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- checkpointLoaded := time.Now()
|
|
|
+ caps := []Capability{CapabilityCompletion}
|
|
|
+ r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
|
|
+ if errors.Is(err, errCapabilityCompletion) {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
|
|
+ return
|
|
|
+ } else if err != nil {
|
|
|
+ handleScheduleError(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- var prompt string
|
|
|
- switch {
|
|
|
- case req.Raw:
|
|
|
- prompt = req.Prompt
|
|
|
- case req.Prompt != "":
|
|
|
- if req.Template == "" {
|
|
|
- tmpl = model.Template
|
|
|
- }
|
|
|
+ images := make([]llm.ImageData, len(req.Images))
|
|
|
+ for i := range req.Images {
|
|
|
+ images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
|
|
+ }
|
|
|
|
|
|
- if req.System == "" {
|
|
|
- req.System = model.System
|
|
|
+ prompt := req.Prompt
|
|
|
+ if !req.Raw {
|
|
|
+ var msgs []api.Message
|
|
|
+ if req.System != "" {
|
|
|
+ msgs = append(msgs, api.Message{Role: "system", Content: req.System})
|
|
|
+ } else if r.model.System != "" {
|
|
|
+ msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
|
|
|
}
|
|
|
|
|
|
- slog.Debug("generate handler", "prompt", req.Prompt)
|
|
|
- slog.Debug("generate handler", "template", req.Template)
|
|
|
- slog.Debug("generate handler", "system", req.System)
|
|
|
+ if req.Prompt != "" {
|
|
|
+ for _, i := range images {
|
|
|
+ msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
|
|
|
+ }
|
|
|
|
|
|
- var sb strings.Builder
|
|
|
- for i := range req.Images {
|
|
|
- fmt.Fprintf(&sb, "[img-%d] ", i)
|
|
|
+ msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
|
|
}
|
|
|
|
|
|
- sb.WriteString(req.Prompt)
|
|
|
-
|
|
|
- p, err := Prompt(tmpl, req.System, sb.String(), "", true)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ if len(msgs) == 0 {
|
|
|
+ c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
+ Model: req.Model,
|
|
|
+ CreatedAt: time.Now().UTC(),
|
|
|
+ Done: true,
|
|
|
+ DoneReason: "load",
|
|
|
+ })
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- sb.Reset()
|
|
|
+ tmpl := r.model.Template
|
|
|
+ if req.Template != "" {
|
|
|
+ tmpl, err = template.Parse(req.Template)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var b bytes.Buffer
|
|
|
if req.Context != nil {
|
|
|
- prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
|
|
|
+ s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
|
|
|
if err != nil {
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- sb.WriteString(prev)
|
|
|
+ b.WriteString(s)
|
|
|
}
|
|
|
|
|
|
- sb.WriteString(p)
|
|
|
+ if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- prompt = sb.String()
|
|
|
+ prompt = b.String()
|
|
|
}
|
|
|
|
|
|
- slog.Debug("generate handler", "prompt", prompt)
|
|
|
+ slog.Debug("generate request", "prompt", prompt, "images", images)
|
|
|
|
|
|
ch := make(chan any)
|
|
|
- var generated strings.Builder
|
|
|
go func() {
|
|
|
defer close(ch)
|
|
|
-
|
|
|
- fn := func(r llm.CompletionResponse) {
|
|
|
- // Build up the full response
|
|
|
- if _, err := generated.WriteString(r.Content); err != nil {
|
|
|
- ch <- gin.H{"error": err.Error()}
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- resp := api.GenerateResponse{
|
|
|
+ if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
|
|
+ Prompt: prompt,
|
|
|
+ Images: images,
|
|
|
+ Format: req.Format,
|
|
|
+ Options: *r.Options,
|
|
|
+ }, func(r llm.CompletionResponse) {
|
|
|
+ ch <- api.GenerateResponse{
|
|
|
Model: req.Model,
|
|
|
CreatedAt: time.Now().UTC(),
|
|
|
- Done: r.Done,
|
|
|
Response: r.Content,
|
|
|
+ Done: r.Done,
|
|
|
DoneReason: r.DoneReason,
|
|
|
Metrics: api.Metrics{
|
|
|
PromptEvalCount: r.PromptEvalCount,
|
|
@@ -232,77 +209,35 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
EvalDuration: r.EvalDuration,
|
|
|
},
|
|
|
}
|
|
|
-
|
|
|
- if r.Done {
|
|
|
- resp.TotalDuration = time.Since(checkpointStart)
|
|
|
- resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
|
-
|
|
|
- if !req.Raw {
|
|
|
- p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- // TODO (jmorganca): encode() should not strip special tokens
|
|
|
- tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
|
|
|
- if err != nil {
|
|
|
- ch <- gin.H{"error": err.Error()}
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- resp.Context = append(req.Context, tokens...)
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- ch <- resp
|
|
|
- }
|
|
|
-
|
|
|
- var images []llm.ImageData
|
|
|
- for i := range req.Images {
|
|
|
- images = append(images, llm.ImageData{
|
|
|
- ID: i,
|
|
|
- Data: req.Images[i],
|
|
|
- })
|
|
|
- }
|
|
|
-
|
|
|
- // Start prediction
|
|
|
- req := llm.CompletionRequest{
|
|
|
- Prompt: prompt,
|
|
|
- Format: req.Format,
|
|
|
- Images: images,
|
|
|
- Options: opts,
|
|
|
- }
|
|
|
- if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
|
|
+ }); err != nil {
|
|
|
ch <- gin.H{"error": err.Error()}
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
if req.Stream != nil && !*req.Stream {
|
|
|
- // Accumulate responses into the final response
|
|
|
- var final api.GenerateResponse
|
|
|
+ var r api.GenerateResponse
|
|
|
var sb strings.Builder
|
|
|
- for resp := range ch {
|
|
|
- switch r := resp.(type) {
|
|
|
+ for rr := range ch {
|
|
|
+ switch t := rr.(type) {
|
|
|
case api.GenerateResponse:
|
|
|
- sb.WriteString(r.Response)
|
|
|
- final = r
|
|
|
+ sb.WriteString(t.Response)
|
|
|
+ r = t
|
|
|
case gin.H:
|
|
|
- if errorMsg, ok := r["error"].(string); ok {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
|
|
- return
|
|
|
- } else {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
|
|
- return
|
|
|
+ msg, ok := t["error"].(string)
|
|
|
+ if !ok {
|
|
|
+ msg = "unexpected error format in response"
|
|
|
}
|
|
|
+
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
|
|
+ return
|
|
|
default:
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- final.Response = sb.String()
|
|
|
- c.JSON(http.StatusOK, final)
|
|
|
+ r.Response = sb.String()
|
|
|
+ c.JSON(http.StatusOK, r)
|
|
|
return
|
|
|
}
|
|
|
|
|
@@ -311,44 +246,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|
|
|
|
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
var req api.EmbeddingRequest
|
|
|
- err := c.ShouldBindJSON(&req)
|
|
|
- switch {
|
|
|
- case errors.Is(err, io.EOF):
|
|
|
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
return
|
|
|
- case err != nil:
|
|
|
+ } else if err != nil {
|
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- if req.Model == "" {
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- model, err := GetModel(req.Model)
|
|
|
+ r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
|
if err != nil {
|
|
|
- var pErr *fs.PathError
|
|
|
- if errors.As(err, &pErr) {
|
|
|
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
- return
|
|
|
- }
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- opts, err := modelOptions(model, req.Options)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
|
|
- var runner *runnerRef
|
|
|
- select {
|
|
|
- case runner = <-rCh:
|
|
|
- case err = <-eCh:
|
|
|
- handleErrorResponse(c, err)
|
|
|
+ handleScheduleError(c, err)
|
|
|
return
|
|
|
}
|
|
|
|
|
@@ -358,17 +266,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
|
|
|
+ embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
|
|
|
if err != nil {
|
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- resp := api.EmbeddingResponse{
|
|
|
- Embedding: embedding,
|
|
|
- }
|
|
|
- c.JSON(http.StatusOK, resp)
|
|
|
+ c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
|
|
|
}
|
|
|
|
|
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
|
@@ -649,9 +554,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- msgs := make([]api.Message, 0)
|
|
|
- for _, msg := range m.Messages {
|
|
|
- msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
|
|
|
+ msgs := make([]api.Message, len(m.Messages))
|
|
|
+ for i, msg := range m.Messages {
|
|
|
+ msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
|
|
}
|
|
|
|
|
|
n := model.ParseName(req.Model)
|
|
@@ -1214,132 +1119,55 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
|
|
}
|
|
|
|
|
|
-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
|
|
-func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
|
|
|
- encode := func(s string) ([]int, error) {
|
|
|
- return runner.llama.Tokenize(ctx, s)
|
|
|
- }
|
|
|
-
|
|
|
- prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
- }
|
|
|
-
|
|
|
- return prompt, nil
|
|
|
-}
|
|
|
-
|
|
|
func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
- checkpointStart := time.Now()
|
|
|
-
|
|
|
var req api.ChatRequest
|
|
|
- err := c.ShouldBindJSON(&req)
|
|
|
- switch {
|
|
|
- case errors.Is(err, io.EOF):
|
|
|
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
return
|
|
|
- case err != nil:
|
|
|
+ } else if err != nil {
|
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // validate the request
|
|
|
- switch {
|
|
|
- case req.Model == "":
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
- return
|
|
|
- case len(req.Format) > 0 && req.Format != "json":
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- model, err := GetModel(req.Model)
|
|
|
- if err != nil {
|
|
|
- var pErr *fs.PathError
|
|
|
- if errors.As(err, &pErr) {
|
|
|
- c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
- return
|
|
|
- }
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- if !model.Has(CapabilityCompletion) {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- opts, err := modelOptions(model, req.Options)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
|
|
- var runner *runnerRef
|
|
|
- select {
|
|
|
- case runner = <-rCh:
|
|
|
- case err = <-eCh:
|
|
|
- handleErrorResponse(c, err)
|
|
|
+ caps := []Capability{CapabilityCompletion}
|
|
|
+ r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
|
|
+ if errors.Is(err, errCapabilityCompletion) {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
|
|
return
|
|
|
- }
|
|
|
-
|
|
|
- checkpointLoaded := time.Now()
|
|
|
-
|
|
|
- // if the first message is not a system message, then add the model's default system message
|
|
|
- if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
|
|
|
- req.Messages = append([]api.Message{
|
|
|
- {
|
|
|
- Role: "system",
|
|
|
- Content: model.System,
|
|
|
- },
|
|
|
- }, req.Messages...)
|
|
|
- }
|
|
|
-
|
|
|
- prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
|
|
|
- if err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ } else if err != nil {
|
|
|
+ handleScheduleError(c, err)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // an empty request loads the model
|
|
|
- if len(req.Messages) == 0 || prompt == "" {
|
|
|
- resp := api.ChatResponse{
|
|
|
- CreatedAt: time.Now().UTC(),
|
|
|
+ if len(req.Messages) == 0 {
|
|
|
+ c.JSON(http.StatusOK, api.ChatResponse{
|
|
|
Model: req.Model,
|
|
|
+ CreatedAt: time.Now().UTC(),
|
|
|
+ Message: api.Message{Role: "assistant"},
|
|
|
Done: true,
|
|
|
DoneReason: "load",
|
|
|
- Message: api.Message{Role: "assistant"},
|
|
|
- }
|
|
|
- c.JSON(http.StatusOK, resp)
|
|
|
+ })
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- // only send images that are in the prompt
|
|
|
- var i int
|
|
|
- var images []llm.ImageData
|
|
|
- for _, m := range req.Messages {
|
|
|
- for _, img := range m.Images {
|
|
|
- if !isSupportedImageType(img) {
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
|
|
|
- images = append(images, llm.ImageData{Data: img, ID: i})
|
|
|
- }
|
|
|
- i += 1
|
|
|
- }
|
|
|
+ prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- slog.Debug("chat handler", "prompt", prompt, "images", len(images))
|
|
|
+ slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
|
|
|
|
|
ch := make(chan any)
|
|
|
-
|
|
|
go func() {
|
|
|
defer close(ch)
|
|
|
-
|
|
|
- fn := func(r llm.CompletionResponse) {
|
|
|
- resp := api.ChatResponse{
|
|
|
+ if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
|
|
+ Prompt: prompt,
|
|
|
+ Images: images,
|
|
|
+ Format: req.Format,
|
|
|
+ Options: *r.Options,
|
|
|
+ }, func(r llm.CompletionResponse) {
|
|
|
+ ch <- api.ChatResponse{
|
|
|
Model: req.Model,
|
|
|
CreatedAt: time.Now().UTC(),
|
|
|
Message: api.Message{Role: "assistant", Content: r.Content},
|
|
@@ -1352,64 +1180,48 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|
|
EvalDuration: r.EvalDuration,
|
|
|
},
|
|
|
}
|
|
|
-
|
|
|
- if r.Done {
|
|
|
- resp.TotalDuration = time.Since(checkpointStart)
|
|
|
- resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
|
- }
|
|
|
-
|
|
|
- ch <- resp
|
|
|
- }
|
|
|
-
|
|
|
- if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
|
|
- Prompt: prompt,
|
|
|
- Format: req.Format,
|
|
|
- Images: images,
|
|
|
- Options: opts,
|
|
|
- }, fn); err != nil {
|
|
|
+ }); err != nil {
|
|
|
ch <- gin.H{"error": err.Error()}
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
if req.Stream != nil && !*req.Stream {
|
|
|
- // Accumulate responses into the final response
|
|
|
- var final api.ChatResponse
|
|
|
+ var r api.ChatResponse
|
|
|
var sb strings.Builder
|
|
|
- for resp := range ch {
|
|
|
- switch r := resp.(type) {
|
|
|
+ for rr := range ch {
|
|
|
+ switch t := rr.(type) {
|
|
|
case api.ChatResponse:
|
|
|
- sb.WriteString(r.Message.Content)
|
|
|
- final = r
|
|
|
+ sb.WriteString(t.Message.Content)
|
|
|
+ r = t
|
|
|
case gin.H:
|
|
|
- if errorMsg, ok := r["error"].(string); ok {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
|
|
- return
|
|
|
- } else {
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
|
|
|
- return
|
|
|
+ msg, ok := t["error"].(string)
|
|
|
+ if !ok {
|
|
|
+ msg = "unexpected error format in response"
|
|
|
}
|
|
|
+
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
|
|
+ return
|
|
|
default:
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- final.Message = api.Message{Role: "assistant", Content: sb.String()}
|
|
|
- c.JSON(http.StatusOK, final)
|
|
|
+ r.Message.Content = sb.String()
|
|
|
+ c.JSON(http.StatusOK, r)
|
|
|
return
|
|
|
}
|
|
|
|
|
|
streamResponse(c, ch)
|
|
|
}
|
|
|
|
|
|
-func handleErrorResponse(c *gin.Context, err error) {
|
|
|
- if errors.Is(err, context.Canceled) {
|
|
|
+func handleScheduleError(c *gin.Context, err error) {
|
|
|
+ switch {
|
|
|
+ case errors.Is(err, context.Canceled):
|
|
|
c.JSON(499, gin.H{"error": "request canceled"})
|
|
|
- return
|
|
|
- }
|
|
|
- if errors.Is(err, ErrMaxQueue) {
|
|
|
+ case errors.Is(err, ErrMaxQueue):
|
|
|
c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
|
|
|
- return
|
|
|
+ default:
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
}
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
}
|