Просмотр исходного кода

pass the template to the `/api/chat` endpoint

Patrick Devine 9 месяцев назад
Родитель
Сommit
3c0d043b79
4 измененных файлов с 34 добавлено и 8 удалено
  1. 3 0
      api/types.go
  2. 1 0
      cmd/cmd.go
  3. 19 4
      cmd/interactive.go
  4. 11 4
      server/routes.go

+ 3 - 0
api/types.go

@@ -84,6 +84,9 @@ type ChatRequest struct {
 	// Model is the model name, as in [GenerateRequest].
 	Model string `json:"model"`
 
+	// Template overrides the model's default prompt template.
+	Template string `json:"template"`
+
 	// Messages is the messages of the chat - can be used to keep a chat memory.
 	Messages []Message `json:"messages"`
 

+ 1 - 0
cmd/cmd.go

@@ -947,6 +947,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 
 	req := &api.ChatRequest{
 		Model:    opts.Model,
+		Template: opts.Template,
 		Messages: opts.Messages,
 		Format:   opts.Format,
 		Options:  opts.Options,

+ 19 - 4
cmd/interactive.go

@@ -18,6 +18,7 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
+	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 )
 
@@ -205,9 +206,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				fmt.Println("Set system message.")
 				sb.Reset()
 			case MultilineTemplate:
-				opts.Template = sb.String()
-				fmt.Println("Set prompt template.")
+				mTemplate := sb.String()
 				sb.Reset()
+				_, err := template.Parse(mTemplate)
+				if err != nil {
+					multiline = MultilineNone
+					scanner.Prompt.UseAlt = false
+					fmt.Println("The template is invalid.")
+					continue
+				}
+				opts.Template = mTemplate
+				fmt.Println("Set prompt template.")
 			}
 
 			multiline = MultilineNone
@@ -369,9 +378,15 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 						fmt.Println("Set system message.")
 						sb.Reset()
 					} else if args[1] == "template" {
-						opts.Template = sb.String()
-						fmt.Println("Set prompt template.")
+						mTemplate := sb.String()
 						sb.Reset()
+						_, err := template.Parse(mTemplate)
+						if err != nil {
+							fmt.Println("The template is invalid.")
+							continue
+						}
+						opts.Template = mTemplate
+						fmt.Println("Set prompt template.")
 					}
 
 					sb.Reset()

+ 11 - 4
server/routes.go

@@ -71,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 
 // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
 // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
+func (s *Server) scheduleRunner(ctx context.Context, name string, mTemplate string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
 		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
@@ -81,6 +81,13 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 		return nil, nil, nil, err
 	}
 
+	if mTemplate != "" {
+		model.Template, err = template.Parse(mTemplate)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+	}
+
 	if err := model.CheckCapabilities(caps...); err != nil {
 		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
@@ -120,7 +127,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, "", caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
@@ -256,7 +263,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, "", []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -1132,7 +1139,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, req.Template, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return