@@ -31,6 +31,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	tmpl, err := template.Parse(req.Template)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
 	checkpointLoaded := time.Now()
 
 	var prompt string
@@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		prompt = req.Prompt
 	case req.Prompt != "":
 		if req.Template == "" {
-			req.Template = model.Template
+			model.Template, err = template.Parse(req.Template)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
 		}
 
 		if req.System == "" {
@@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		sb.WriteString(req.Prompt)
 
-		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
+		p, err := Prompt(tmpl, req.System, sb.String(), "", true)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 
 		if !req.Raw {
-			p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
+			p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	}
 
 	if req.Template != "" {
-		m.Template = req.Template
+		m.Template, err = template.Parse(req.Template)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	msgs := make([]api.Message, 0)
@@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	resp := &api.ShowResponse{
 		License:    strings.Join(m.License, "\n"),
 		System:     m.System,
-		Template:   m.Template,
+		Template:   m.Template.String(),
 		Details:    modelDetails,
 		Messages:   msgs,
 		ModifiedAt: manifest.fi.ModTime(),
@@ -1246,7 +1260,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	}
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
 		return runner.llama.Tokenize(ctx, s)
 	}
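
---

Note on the new dependency: the patch exercises the `template` package only through `template.Parse`, the `String()` method on its result, and the `*template.Template` type in `chatPrompt`'s new signature. Below is a minimal sketch of the kind of prompt templating this presumably wraps, written against the standard library `text/template`; the template text and the `System`/`Prompt` field names are hypothetical stand-ins for whatever `Prompt()` actually feeds the parsed template.

package main

import (
	"os"
	"text/template"
)

func main() {
	// Hypothetical prompt template; real templates ship with each model.
	tmpl, err := template.New("prompt").Parse(
		"{{ if .System }}<<SYS>>{{ .System }}<</SYS>>\n{{ end }}{{ .Prompt }}")
	if err != nil {
		panic(err)
	}

	// Renders:
	//   <<SYS>>You are concise.<</SYS>>
	//   Why is the sky blue?
	if err := tmpl.Execute(os.Stdout, map[string]string{
		"System": "You are concise.",
		"Prompt": "Why is the sky blue?",
	}); err != nil {
		panic(err)
	}
}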