|
@@ -60,17 +60,26 @@ var loaded struct {
|
|
var defaultSessionDuration = 5 * time.Minute
|
|
var defaultSessionDuration = 5 * time.Minute
|
|
|
|
|
|
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
|
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
|
-func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
|
|
|
|
|
|
+func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
|
|
|
|
+ model, err := GetModel(modelName)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ workDir := c.GetString("workDir")
|
|
|
|
+
|
|
opts := api.DefaultOptions()
|
|
opts := api.DefaultOptions()
|
|
if err := opts.FromMap(model.Options); err != nil {
|
|
if err := opts.FromMap(model.Options); err != nil {
|
|
log.Printf("could not load model options: %v", err)
|
|
log.Printf("could not load model options: %v", err)
|
|
- return err
|
|
|
|
|
|
+ return nil, err
|
|
}
|
|
}
|
|
|
|
|
|
if err := opts.FromMap(reqOpts); err != nil {
|
|
if err := opts.FromMap(reqOpts); err != nil {
|
|
- return err
|
|
|
|
|
|
+ return nil, err
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ ctx := c.Request.Context()
|
|
|
|
+
|
|
// check if the loaded model is still running in a subprocess, in case something unexpected happened
|
|
// check if the loaded model is still running in a subprocess, in case something unexpected happened
|
|
if loaded.runner != nil {
|
|
if loaded.runner != nil {
|
|
if err := loaded.runner.Ping(ctx); err != nil {
|
|
if err := loaded.runner.Ping(ctx); err != nil {
|
|
@@ -106,7 +115,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
|
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
|
|
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
|
|
}
|
|
}
|
|
|
|
|
|
- return err
|
|
|
|
|
|
+ return nil, err
|
|
}
|
|
}
|
|
|
|
|
|
loaded.Model = model
|
|
loaded.Model = model
|
|
@@ -140,7 +149,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
|
}
|
|
}
|
|
|
|
|
|
loaded.expireTimer.Reset(sessionDuration)
|
|
loaded.expireTimer.Reset(sessionDuration)
|
|
- return nil
|
|
|
|
|
|
+ return model, nil
|
|
}
|
|
}
|
|
|
|
|
|
func GenerateHandler(c *gin.Context) {
|
|
func GenerateHandler(c *gin.Context) {
|
|
@@ -173,88 +182,135 @@ func GenerateHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- model, err := GetModel(req.Model)
|
|
|
|
|
|
+ sessionDuration := defaultSessionDuration
|
|
|
|
+ model, err := load(c, req.Model, req.Options, sessionDuration)
|
|
if err != nil {
|
|
if err != nil {
|
|
var pErr *fs.PathError
|
|
var pErr *fs.PathError
|
|
- if errors.As(err, &pErr) {
|
|
|
|
|
|
+ switch {
|
|
|
|
+ case errors.As(err, &pErr):
|
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
- return
|
|
|
|
|
|
+ case errors.Is(err, api.ErrInvalidOpts):
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ default:
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
}
|
|
}
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- workDir := c.GetString("workDir")
|
|
|
|
-
|
|
|
|
- // TODO: set this duration from the request if specified
|
|
|
|
- sessionDuration := defaultSessionDuration
|
|
|
|
- if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
|
|
|
|
- if errors.Is(err, api.ErrInvalidOpts) {
|
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
|
|
+ // an empty request loads the model
|
|
|
|
+ if req.Prompt == "" && req.Template == "" && req.System == "" {
|
|
|
|
+ c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
checkpointLoaded := time.Now()
|
|
checkpointLoaded := time.Now()
|
|
|
|
|
|
- prompt := req.Prompt
|
|
|
|
- if !req.Raw {
|
|
|
|
- prompt, err = model.Prompt(req)
|
|
|
|
|
|
+ var prompt string
|
|
|
|
+ switch {
|
|
|
|
+ case req.Raw:
|
|
|
|
+ prompt = req.Prompt
|
|
|
|
+ case req.Prompt != "":
|
|
|
|
+ if req.Template != "" {
|
|
|
|
+ // override the default model template
|
|
|
|
+ model.Template = req.Template
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var rebuild strings.Builder
|
|
|
|
+ if req.Context != nil {
|
|
|
|
+ // TODO: context is deprecated, at some point the context logic within this conditional should be removed
|
|
|
|
+ prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Remove leading spaces from prevCtx if present
|
|
|
|
+ prevCtx = strings.TrimPrefix(prevCtx, " ")
|
|
|
|
+ rebuild.WriteString(prevCtx)
|
|
|
|
+ }
|
|
|
|
+ p, err := model.Prompt(PromptVars{
|
|
|
|
+ System: req.System,
|
|
|
|
+ Prompt: req.Prompt,
|
|
|
|
+ First: len(req.Context) == 0,
|
|
|
|
+ })
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+ rebuild.WriteString(p)
|
|
|
|
+ prompt = rebuild.String()
|
|
}
|
|
}
|
|
|
|
|
|
ch := make(chan any)
|
|
ch := make(chan any)
|
|
|
|
+ var generated strings.Builder
|
|
go func() {
|
|
go func() {
|
|
defer close(ch)
|
|
defer close(ch)
|
|
- // an empty request loads the model
|
|
|
|
- if req.Prompt == "" && req.Template == "" && req.System == "" {
|
|
|
|
- ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
|
|
|
|
- fn := func(r api.GenerateResponse) {
|
|
|
|
|
|
+ fn := func(r llm.PredictResult) {
|
|
|
|
+ // Update model expiration
|
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
|
loaded.expireTimer.Reset(sessionDuration)
|
|
loaded.expireTimer.Reset(sessionDuration)
|
|
|
|
|
|
- r.Model = req.Model
|
|
|
|
- r.CreatedAt = time.Now().UTC()
|
|
|
|
- if r.Done {
|
|
|
|
- r.TotalDuration = time.Since(checkpointStart)
|
|
|
|
- r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
|
|
|
|
|
+ // Build up the full response
|
|
|
|
+ if _, err := generated.WriteString(r.Content); err != nil {
|
|
|
|
+ ch <- gin.H{"error": err.Error()}
|
|
|
|
+ return
|
|
}
|
|
}
|
|
|
|
|
|
- if req.Raw {
|
|
|
|
- // in raw mode the client must manage history on their own
|
|
|
|
- r.Context = nil
|
|
|
|
|
|
+ resp := api.GenerateResponse{
|
|
|
|
+ Model: r.Model,
|
|
|
|
+ CreatedAt: r.CreatedAt,
|
|
|
|
+ Done: r.Done,
|
|
|
|
+ Response: r.Content,
|
|
|
|
+ Metrics: api.Metrics{
|
|
|
|
+ TotalDuration: r.TotalDuration,
|
|
|
|
+ LoadDuration: r.LoadDuration,
|
|
|
|
+ PromptEvalCount: r.PromptEvalCount,
|
|
|
|
+ PromptEvalDuration: r.PromptEvalDuration,
|
|
|
|
+ EvalCount: r.EvalCount,
|
|
|
|
+ EvalDuration: r.EvalDuration,
|
|
|
|
+ },
|
|
}
|
|
}
|
|
|
|
|
|
- ch <- r
|
|
|
|
|
|
+ if r.Done && !req.Raw {
|
|
|
|
+ embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String())
|
|
|
|
+ if err != nil {
|
|
|
|
+ ch <- gin.H{"error": err.Error()}
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ resp.Context = embd
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ ch <- resp
|
|
}
|
|
}
|
|
|
|
|
|
- if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
|
|
|
|
|
|
+ // Start prediction
|
|
|
|
+ predictReq := llm.PredictOpts{
|
|
|
|
+ Model: model.Name,
|
|
|
|
+ Prompt: prompt,
|
|
|
|
+ Format: req.Format,
|
|
|
|
+ CheckpointStart: checkpointStart,
|
|
|
|
+ CheckpointLoaded: checkpointLoaded,
|
|
|
|
+ }
|
|
|
|
+ if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
|
ch <- gin.H{"error": err.Error()}
|
|
ch <- gin.H{"error": err.Error()}
|
|
}
|
|
}
|
|
}()
|
|
}()
|
|
|
|
|
|
if req.Stream != nil && !*req.Stream {
|
|
if req.Stream != nil && !*req.Stream {
|
|
- var response api.GenerateResponse
|
|
|
|
- generated := ""
|
|
|
|
|
|
+ // Wait for the channel to close
|
|
|
|
+ var r api.GenerateResponse
|
|
|
|
+ var sb strings.Builder
|
|
for resp := range ch {
|
|
for resp := range ch {
|
|
- if r, ok := resp.(api.GenerateResponse); ok {
|
|
|
|
- generated += r.Response
|
|
|
|
- response = r
|
|
|
|
- } else {
|
|
|
|
|
|
+ var ok bool
|
|
|
|
+ if r, ok = resp.(api.GenerateResponse); !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+ sb.WriteString(r.Response)
|
|
}
|
|
}
|
|
- response.Response = generated
|
|
|
|
- c.JSON(http.StatusOK, response)
|
|
|
|
|
|
+ r.Response = sb.String()
|
|
|
|
+ c.JSON(http.StatusOK, r)
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
@@ -281,15 +337,18 @@ func EmbeddingHandler(c *gin.Context) {
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
- model, err := GetModel(req.Model)
|
|
|
|
|
|
+ sessionDuration := defaultSessionDuration
|
|
|
|
+ _, err = load(c, req.Model, req.Options, sessionDuration)
|
|
if err != nil {
|
|
if err != nil {
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- workDir := c.GetString("workDir")
|
|
|
|
- if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
|
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
|
|
+ var pErr *fs.PathError
|
|
|
|
+ switch {
|
|
|
|
+ case errors.As(err, &pErr):
|
|
|
|
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
|
+ case errors.Is(err, api.ErrInvalidOpts):
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ default:
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ }
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
|
|
@@ -767,6 +826,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
|
|
|
|
|
r.POST("/api/pull", PullModelHandler)
|
|
r.POST("/api/pull", PullModelHandler)
|
|
r.POST("/api/generate", GenerateHandler)
|
|
r.POST("/api/generate", GenerateHandler)
|
|
|
|
+ r.POST("/api/chat", ChatHandler)
|
|
r.POST("/api/embeddings", EmbeddingHandler)
|
|
r.POST("/api/embeddings", EmbeddingHandler)
|
|
r.POST("/api/create", CreateModelHandler)
|
|
r.POST("/api/create", CreateModelHandler)
|
|
r.POST("/api/push", PushModelHandler)
|
|
r.POST("/api/push", PushModelHandler)
|
|
@@ -860,3 +920,125 @@ func streamResponse(c *gin.Context, ch chan any) {
|
|
return true
|
|
return true
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+func ChatHandler(c *gin.Context) {
|
|
|
|
+ loaded.mu.Lock()
|
|
|
|
+ defer loaded.mu.Unlock()
|
|
|
|
+
|
|
|
|
+ checkpointStart := time.Now()
|
|
|
|
+
|
|
|
|
+ var req api.ChatRequest
|
|
|
|
+ err := c.ShouldBindJSON(&req)
|
|
|
|
+ switch {
|
|
|
|
+ case errors.Is(err, io.EOF):
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
|
+ return
|
|
|
|
+ case err != nil:
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // validate the request
|
|
|
|
+ switch {
|
|
|
|
+ case req.Model == "":
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
|
+ return
|
|
|
|
+ case len(req.Format) > 0 && req.Format != "json":
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ sessionDuration := defaultSessionDuration
|
|
|
|
+ model, err := load(c, req.Model, req.Options, sessionDuration)
|
|
|
|
+ if err != nil {
|
|
|
|
+ var pErr *fs.PathError
|
|
|
|
+ switch {
|
|
|
|
+ case errors.As(err, &pErr):
|
|
|
|
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
|
+ case errors.Is(err, api.ErrInvalidOpts):
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ default:
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ }
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // an empty request loads the model
|
|
|
|
+ if len(req.Messages) == 0 {
|
|
|
|
+ c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ checkpointLoaded := time.Now()
|
|
|
|
+
|
|
|
|
+ prompt, err := model.ChatPrompt(req.Messages)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ ch := make(chan any)
|
|
|
|
+
|
|
|
|
+ go func() {
|
|
|
|
+ defer close(ch)
|
|
|
|
+
|
|
|
|
+ fn := func(r llm.PredictResult) {
|
|
|
|
+ // Update model expiration
|
|
|
|
+ loaded.expireAt = time.Now().Add(sessionDuration)
|
|
|
|
+ loaded.expireTimer.Reset(sessionDuration)
|
|
|
|
+
|
|
|
|
+ resp := api.ChatResponse{
|
|
|
|
+ Model: r.Model,
|
|
|
|
+ CreatedAt: r.CreatedAt,
|
|
|
|
+ Done: r.Done,
|
|
|
|
+ Metrics: api.Metrics{
|
|
|
|
+ TotalDuration: r.TotalDuration,
|
|
|
|
+ LoadDuration: r.LoadDuration,
|
|
|
|
+ PromptEvalCount: r.PromptEvalCount,
|
|
|
|
+ PromptEvalDuration: r.PromptEvalDuration,
|
|
|
|
+ EvalCount: r.EvalCount,
|
|
|
|
+ EvalDuration: r.EvalDuration,
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if !r.Done {
|
|
|
|
+ resp.Message = &api.Message{Role: "assistant", Content: r.Content}
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ ch <- resp
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Start prediction
|
|
|
|
+ predictReq := llm.PredictOpts{
|
|
|
|
+ Model: model.Name,
|
|
|
|
+ Prompt: prompt,
|
|
|
|
+ Format: req.Format,
|
|
|
|
+ CheckpointStart: checkpointStart,
|
|
|
|
+ CheckpointLoaded: checkpointLoaded,
|
|
|
|
+ }
|
|
|
|
+ if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
|
|
|
+ ch <- gin.H{"error": err.Error()}
|
|
|
|
+ }
|
|
|
|
+ }()
|
|
|
|
+
|
|
|
|
+ if req.Stream != nil && !*req.Stream {
|
|
|
|
+ // Wait for the channel to close
|
|
|
|
+ var r api.ChatResponse
|
|
|
|
+ var sb strings.Builder
|
|
|
|
+ for resp := range ch {
|
|
|
|
+ var ok bool
|
|
|
|
+ if r, ok = resp.(api.ChatResponse); !ok {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ if r.Message != nil {
|
|
|
|
+ sb.WriteString(r.Message.Content)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ r.Message = &api.Message{Role: "assistant", Content: sb.String()}
|
|
|
|
+ c.JSON(http.StatusOK, r)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ streamResponse(c, ch)
|
|
|
|
+}
|