@@ -1,15 +1,14 @@
 package server

 import (
+	"bytes"
 	"cmp"
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log/slog"
-	"math"
 	"net"
 	"net/http"
 	"net/netip"
@@ -17,7 +16,6 @@ import (
 	"os/signal"
 	"path/filepath"
 	"slices"
-	"strconv"
 	"strings"
 	"syscall"
 	"time"
@@ -31,6 +29,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -55,7 +54,7 @@ func init() {
 	gin.SetMode(mode)
 }

-var defaultSessionDuration = 5 * time.Minute
+var errRequired = errors.New("is required")

 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -70,164 +69,140 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }

-func isSupportedImageType(image []byte) bool {
-	contentType := http.DetectContentType(image)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
-	return slices.Contains(allowedTypes, contentType)
-}
-
-func (s *Server) GenerateHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-	var req api.GenerateRequest
-	err := c.ShouldBindJSON(&req)
-
-	switch {
-	case errors.Is(err, io.EOF):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
-		return
-	case err != nil:
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, the model instance, and the consolidated options on success, or an error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+	if name == "" {
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}

-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
-		return
+	model, err := GetModel(name)
+	if err != nil {
+		return nil, nil, nil, err
 	}

-	for _, img := range req.Images {
-		if !isSupportedImageType(img) {
-			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-			return
-		}
+	if err := model.CheckCapabilities(caps...); err != nil {
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}

-	model, err := GetModel(req.Model)
+	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
+		return nil, nil, nil, err
 	}

-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
-		return
+	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+	var runner *runnerRef
+	select {
+	case runner = <-runnerCh:
+	case err = <-errCh:
+		return nil, nil, nil, err
 	}

-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	return runner.llama, model, &opts, nil
+}
+
+func (s *Server) GenerateHandler(c *gin.Context) {
+	var req api.GenerateRequest
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	} else if err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
+	if req.Format != "" && req.Format != "json" {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
+		return
+	} else if req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0) {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
+		return
 	}

-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
+		return
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}

-	// an empty request loads the model
-	// note: for a short while template was used in lieu
-	// of `raw` mode so we need to check for it too
-	if req.Prompt == "" && req.Template == "" && req.System == "" {
+	if req.Prompt == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
 			Done:       true,
 			DoneReason: "load",
 		})
 		return
 	}

-	checkpointLoaded := time.Now()
-
-	var prompt string
-	switch {
-	case req.Raw:
-		prompt = req.Prompt
-	case req.Prompt != "":
-		if req.Template == "" {
-			req.Template = model.Template
-		}
+	images := make([]llm.ImageData, len(req.Images))
+	for i := range req.Images {
+		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
+	}

-		if req.System == "" {
-			req.System = model.System
+	prompt := req.Prompt
+	if !req.Raw {
+		var msgs []api.Message
+		if req.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+		} else if m.System != "" {
+			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 		}

-		slog.Debug("generate handler", "prompt", req.Prompt)
-		slog.Debug("generate handler", "template", req.Template)
-		slog.Debug("generate handler", "system", req.System)
-
-		var sb strings.Builder
-		for i := range req.Images {
-			fmt.Fprintf(&sb, "[img-%d] ", i)
+		for _, i := range images {
+			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
 		}

-		sb.WriteString(req.Prompt)
+		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})

-		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
+		tmpl := m.Template
+		if req.Template != "" {
+			tmpl, err = template.Parse(req.Template)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
 		}

-		sb.Reset()
+		var b bytes.Buffer
 		if req.Context != nil {
-			prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
+			s, err := r.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}

-			sb.WriteString(prev)
+			b.WriteString(s)
 		}

-		sb.WriteString(p)
+		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}

-		prompt = sb.String()
+		prompt = b.String()
 	}

-	slog.Debug("generate handler", "prompt", prompt)
+	slog.Debug("generate request", "prompt", prompt, "images", images)

 	ch := make(chan any)
-	var generated strings.Builder
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			// Build up the full response
-			if _, err := generated.WriteString(r.Content); err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
-			}
-
-			resp := api.GenerateResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
-				Done:       r.Done,
 				Response:   r.Content,
+				Done:       r.Done,
 				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount: r.PromptEvalCount,
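The new scheduleRunner helper folds the per-handler boilerplate (model lookup, capability check, option merging, scheduling) into one call, and its select over runnerCh/errCh is the piece worth internalizing: the scheduler replies on exactly one of two channels. A minimal, self-contained sketch of that two-channel handoff, with toy names rather than the real scheduler:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// getRunner mimics the scheduler handoff: the worker answers on
// exactly one of two buffered channels, so the caller can block in
// a select until it holds either a runner or an error.
func getRunner(ctx context.Context, ok bool) (<-chan string, <-chan error) {
	runnerCh := make(chan string, 1)
	errCh := make(chan error, 1)
	go func() {
		time.Sleep(10 * time.Millisecond) // pretend to load a model
		if !ok {
			errCh <- errors.New("out of memory")
			return
		}
		runnerCh <- "llama-runner"
	}()
	return runnerCh, errCh
}

func main() {
	runnerCh, errCh := getRunner(context.Background(), true)
	select {
	case r := <-runnerCh:
		fmt.Println("got runner:", r)
	case err := <-errCh:
		fmt.Println("schedule failed:", err)
	}
}
```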
@@ -236,156 +211,54 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration: r.EvalDuration,
 				},
 			}
-
-			if r.Done {
-				resp.TotalDuration = time.Since(checkpointStart)
-				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-
-				if !req.Raw {
-					p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
-					if err != nil {
-						c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-						return
-					}
-
-					// TODO (jmorganca): encode() should not strip special tokens
-					tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
-					if err != nil {
-						ch <- gin.H{"error": err.Error()}
-						return
-					}
-
-					resp.Context = append(req.Context, tokens...)
-				}
-			}
-
-			ch <- resp
-		}
-
-		var images []llm.ImageData
-		for i := range req.Images {
-			images = append(images, llm.ImageData{
-				ID:   i,
-				Data: req.Images[i],
-			})
-		}
-
-		// Start prediction
-		req := llm.CompletionRequest{
-			Prompt:  prompt,
-			Format:  req.Format,
-			Images:  images,
-			Options: opts,
-		}
-		if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.GenerateResponse
+		var r api.GenerateResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.GenerateResponse:
-				sb.WriteString(r.Response)
-				final = r
+				sb.WriteString(t.Response)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}

-		final.Response = sb.String()
-		c.JSON(http.StatusOK, final)
+		r.Response = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}

 	streamResponse(c, ch)
 }

-func getDefaultSessionDuration() time.Duration {
-	if envconfig.KeepAlive != "" {
-		v, err := strconv.Atoi(envconfig.KeepAlive)
-		if err != nil {
-			d, err := time.ParseDuration(envconfig.KeepAlive)
-			if err != nil {
-				return defaultSessionDuration
-			}
-
-			if d < 0 {
-				return time.Duration(math.MaxInt64)
-			}
-
-			return d
-		}
-
-		d := time.Duration(v) * time.Second
-		if d < 0 {
-			return time.Duration(math.MaxInt64)
-		}
-		return d
-	}
-
-	return defaultSessionDuration
-}
-
 func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	var req api.EmbeddingRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	if req.Model == "" {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
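With getDefaultSessionDuration deleted, handlers now pass req.KeepAlive (a possibly nil *api.Duration) straight through scheduleRunner to GetRunner, which implies the default keep-alive and the negative-means-forever rule now live behind the scheduler. A hypothetical sketch of that resolution; defaultKeepAlive and resolveKeepAlive are illustrative names, not code from this patch:

```go
package main

import (
	"fmt"
	"time"
)

// Duration stands in for api.Duration here.
type Duration struct{ time.Duration }

const defaultKeepAlive = 5 * time.Minute

// resolveKeepAlive is hypothetical: nil falls back to the default,
// and a negative value pins the model in memory indefinitely.
func resolveKeepAlive(keepAlive *Duration) time.Duration {
	if keepAlive == nil {
		return defaultKeepAlive
	}
	if keepAlive.Duration < 0 {
		return time.Duration(1<<63 - 1) // effectively forever
	}
	return keepAlive.Duration
}

func main() {
	fmt.Println(resolveKeepAlive(nil))                         // 5m0s
	fmt.Println(resolveKeepAlive(&Duration{30 * time.Second})) // 30s
	fmt.Println(resolveKeepAlive(&Duration{-1}))               // max duration
}
```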
@@ -395,17 +268,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
 		return
 	}

-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
-	}
-	c.JSON(http.StatusOK, resp)
+	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
 }

 func (s *Server) PullModelHandler(c *gin.Context) {
@@ -680,12 +550,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}

 	if req.Template != "" {
-		m.Template = req.Template
+		m.Template, err = template.Parse(req.Template)
+		if err != nil {
+			return nil, err
+		}
 	}

-	msgs := make([]api.Message, 0)
-	for _, msg := range m.Messages {
-		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
+	msgs := make([]api.Message, len(m.Messages))
+	for i, msg := range m.Messages {
+		msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
 	}

 	n := model.ParseName(req.Model)
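GetModelInfo now parses req.Template instead of storing the raw string, so a malformed template fails here rather than later at render time, and the parsed m.Template can still be echoed back via String(). A small sketch of that parse-once wrapper pattern over text/template; this is illustrative, not Ollama's actual template package:

```go
package main

import (
	"fmt"
	"os"
	"text/template"
)

// Template pairs a parsed template with its raw source so the
// source can be returned verbatim, as ShowResponse does.
type Template struct {
	*template.Template
	raw string
}

func Parse(s string) (*Template, error) {
	t, err := template.New("").Parse(s)
	if err != nil {
		return nil, err // syntax errors surface at parse time
	}
	return &Template{Template: t, raw: s}, nil
}

func (t *Template) String() string { return t.raw }

func main() {
	t, err := Parse("{{ .System }} USER: {{ .Prompt }}")
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		return
	}
	t.Execute(os.Stdout, map[string]string{"System": "Be brief.", "Prompt": "hi"})
	fmt.Println("\nsource:", t.String())
}
```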
@@ -701,7 +574,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	resp := &api.ShowResponse{
 		License:    strings.Join(m.License, "\n"),
 		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
 		Details:    modelDetails,
 		Messages:   msgs,
 		ModifiedAt: manifest.fi.ModTime(),
@@ -754,7 +627,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 }

 func getKVData(digest string, verbose bool) (llm.KV, error) {
-	kvData, err := llm.LoadModel(digest)
+	maxArraySize := 0
+	if verbose {
+		maxArraySize = -1
+	}
+	kvData, err := llm.LoadModel(digest, maxArraySize)
 	if err != nil {
 		return nil, err
 	}
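LoadModel gains a maxArraySize argument. From this hunk the convention appears to be that a non-negative value caps array-valued KV metadata (tokenizer vocabularies can run to tens of thousands of entries) while -1 lifts the cap for verbose output. An illustrative cap function under that assumption; truncate is hypothetical, not llm.LoadModel:

```go
package main

import "fmt"

// truncate caps an array unless maxArraySize is negative,
// mirroring the 0-vs-minus-1 convention suggested above.
func truncate[T any](arr []T, maxArraySize int) []T {
	if maxArraySize < 0 || len(arr) <= maxArraySize {
		return arr
	}
	return arr[:maxArraySize]
}

func main() {
	tokens := []string{"<s>", "</s>", "the", "cat", "sat"}
	fmt.Println(truncate(tokens, 0))  // []  — default show skips big arrays
	fmt.Println(truncate(tokens, -1)) // all — verbose show decodes everything
}
```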
@@ -1035,7 +912,10 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/api/ps", s.ProcessHandler)

 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)

 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
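Besides /v1/chat/completions, the OpenAI-compatible surface now covers completions and model listing/retrieval, each translated by its own middleware onto the existing native handlers. A quick smoke test of the new completions route from Go, assuming a local server on the default port 11434 and a pulled model named "llama3":

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	// stream=false returns a single JSON body instead of chunks.
	body := strings.NewReader(`{"model": "llama3", "prompt": "Say hi", "stream": false}`)
	resp, err := http.Post("http://localhost:11434/v1/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	b, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(b))
}
```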
@@ -1101,11 +981,20 @@ func Serve(ln net.Listener) error {
 	schedCtx, schedDone := context.WithCancel(ctx)
 	sched := InitScheduler(schedCtx)
 	s := &Server{addr: ln.Addr(), sched: sched}
-	r := s.GenerateRoutes()
+
+	http.Handle("/", s.GenerateRoutes())

 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
 	srvr := &http.Server{
-		Handler: r,
+		// Use http.DefaultServeMux so we get net/http/pprof for
+		// free.
+		//
+		// TODO(bmizerany): Decide if we want to make this
+		// configurable so it is not exposed by default, or allow
+		// users to bind it to a different port. This was a quick
+		// and easy way to get pprof, but it may not be the best
+		// way.
+		Handler: nil,
 	}

 	// listen for a ctrl+c and stop any loaded llm
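The TODO comment explains the trade-off; mechanically, a blank import of net/http/pprof registers its handlers on http.DefaultServeMux, and an http.Server whose Handler is nil serves from that mux, so the gin routes mounted with http.Handle("/") and the profiler share one listener. The same mechanism in miniature:

```go
package main

import (
	"fmt"
	"net/http"
	_ "net/http/pprof" // registers /debug/pprof/* on the default mux
)

func main() {
	// App routes go on the default mux alongside pprof's routes.
	http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "app routes live here")
	}))

	// Handler: nil => fall back to http.DefaultServeMux, as above.
	srv := &http.Server{Addr: ":8080", Handler: nil}
	srv.ListenAndServe()
}
```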
@@ -1224,142 +1113,63 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 		models = append(models, mr)
 	}

-	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
-}
-
-// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
-	encode := func(s string) ([]int, error) {
-		return runner.llama.Tokenize(ctx, s)
-	}
-
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
-	if err != nil {
-		return "", err
-	}
+	slices.SortStableFunc(models, func(i, j api.ProcessModelResponse) int {
+		// longest duration remaining listed first
+		return cmp.Compare(j.ExpiresAt.Unix(), i.ExpiresAt.Unix())
+	})

-	return prompt, nil
+	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }

 func (s *Server) ChatHandler(c *gin.Context) {
-	checkpointStart := time.Now()
-
 	var req api.ChatRequest
-	err := c.ShouldBindJSON(&req)
-	switch {
-	case errors.Is(err, io.EOF):
+	if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
 		return
-	case err != nil:
+	} else if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}

-	// validate the request
-	switch {
-	case req.Model == "":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
-		return
-	case len(req.Format) > 0 && req.Format != "json":
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
-		return
-	}
-
-	model, err := GetModel(req.Model)
-	if err != nil {
-		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
-			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
-	if model.IsEmbedding() {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
-		return
-	}
-
-	opts, err := modelOptions(model, req.Options)
-	if err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	caps := []Capability{CapabilityCompletion}
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	if errors.Is(err, errCapabilityCompletion) {
+		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
-	}
-
-	var sessionDuration time.Duration
-	if req.KeepAlive == nil {
-		sessionDuration = getDefaultSessionDuration()
-	} else {
-		sessionDuration = req.KeepAlive.Duration
-	}
-
-	rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
-	var runner *runnerRef
-	select {
-	case runner = <-rCh:
-	case err = <-eCh:
-		handleErrorResponse(c, err)
-		return
-	}
-
-	checkpointLoaded := time.Now()
-
-	// if the first message is not a system message, then add the model's default system message
-	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
-		req.Messages = append([]api.Message{
-			{
-				Role:    "system",
-				Content: model.System,
-			},
-		}, req.Messages...)
-	}
-
-	prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
-	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	} else if err != nil {
+		handleScheduleError(c, req.Model, err)
 		return
 	}

-	// an empty request loads the model
-	if len(req.Messages) == 0 || prompt == "" {
-		resp := api.ChatResponse{
-			CreatedAt:  time.Now().UTC(),
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
 			Done:       true,
 			DoneReason: "load",
-			Message:    api.Message{Role: "assistant"},
-		}
-		c.JSON(http.StatusOK, resp)
+		})
 		return
 	}

-	// only send images that are in the prompt
-	var i int
-	var images []llm.ImageData
-	for _, m := range req.Messages {
-		for _, img := range m.Images {
-			if !isSupportedImageType(img) {
-				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
-				return
-			}
-
-			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
-				images = append(images, llm.ImageData{Data: img, ID: i})
-			}
-			i += 1
-		}
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
 	}

-	slog.Debug("chat handler", "prompt", prompt, "images", len(images))
+	slog.Debug("chat request", "images", len(images), "prompt", prompt)

 	ch := make(chan any)
-
 	go func() {
 		defer close(ch)
-
-		fn := func(r llm.CompletionResponse) {
-			resp := api.ChatResponse{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
+		}, func(r llm.CompletionResponse) {
+			ch <- api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
 				Message:   api.Message{Role: "assistant", Content: r.Content},
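Both handlers now share the same shape: the Completion callback pushes each partial response onto a channel from a producer goroutine, and the HTTP side either streams the chunks or folds them into a single response when stream=false. A toy version of that produce/consume shape:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	ch := make(chan string)
	go func() {
		defer close(ch)
		for _, chunk := range []string{"Hel", "lo ", "wor", "ld"} {
			ch <- chunk // in the server, this is the Completion callback
		}
	}()

	stream := false
	var sb strings.Builder
	for chunk := range ch {
		if stream {
			fmt.Println("chunk:", chunk) // the server streams these as JSON lines
			continue
		}
		sb.WriteString(chunk) // stream=false: accumulate into one response
	}
	if !stream {
		fmt.Println("final:", sb.String())
	}
}
```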
@@ -1372,64 +1182,52 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				EvalDuration: r.EvalDuration,
 			},
 		}
-
-		if r.Done {
-			resp.TotalDuration = time.Since(checkpointStart)
-			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-		}
-
-		ch <- resp
-	}
-
-	if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
-		Prompt:  prompt,
-		Format:  req.Format,
-		Images:  images,
-		Options: opts,
-	}, fn); err != nil {
+		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

 	if req.Stream != nil && !*req.Stream {
-		// Accumulate responses into the final response
-		var final api.ChatResponse
+		var r api.ChatResponse
 		var sb strings.Builder
-		for resp := range ch {
-			switch r := resp.(type) {
+		for rr := range ch {
+			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(r.Message.Content)
-				final = r
+				sb.WriteString(t.Message.Content)
+				r = t
 			case gin.H:
-				if errorMsg, ok := r["error"].(string); ok {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
-					return
-				} else {
-					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
-					return
+				msg, ok := t["error"].(string)
+				if !ok {
+					msg = "unexpected error format in response"
 				}
+
+				c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+				return
 			default:
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
+				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
 				return
 			}
 		}

-		final.Message = api.Message{Role: "assistant", Content: sb.String()}
-		c.JSON(http.StatusOK, final)
+		r.Message.Content = sb.String()
+		c.JSON(http.StatusOK, r)
 		return
 	}

 	streamResponse(c, ch)
 }

-func handleErrorResponse(c *gin.Context, err error) {
-	if errors.Is(err, context.Canceled) {
+func handleScheduleError(c *gin.Context, name string, err error) {
+	switch {
+	case errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
-		return
-	}
-	if errors.Is(err, ErrMaxQueue) {
+	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
-		return
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
+	default:
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}
-	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 }
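handleScheduleError works because scheduleRunner wraps its sentinel with %w (fmt.Errorf("model %w", errRequired)) and the model-lookup path can surface os.ErrNotExist, so errors.Is matches through the wrapping and one helper maps every scheduling failure to a status code. A compact demonstration of that mapping:

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

var errRequired = errors.New("is required")

// statusFor mirrors the switch above: errors.Is unwraps %w chains,
// so wrapped sentinels still select the right status code.
func statusFor(err error) int {
	switch {
	case errors.Is(err, errRequired):
		return 400
	case errors.Is(err, os.ErrNotExist):
		return 404
	default:
		return 500
	}
}

func main() {
	fmt.Println(statusFor(fmt.Errorf("model %w", errRequired)))             // 400
	fmt.Println(statusFor(fmt.Errorf("open manifest: %w", os.ErrNotExist))) // 404
	fmt.Println(statusFor(errors.New("boom")))                              // 500
}
```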