Browse Source

Merge pull request #5196 from ollama/mxyng/messages-2

include modelfile messages
Michael Yang 9 months ago
parent
commit
c4c84b7a0d
4 changed files with 36 additions and 40 deletions
  1. 16 1
      cmd/cmd.go
  2. 1 20
      cmd/interactive.go
  3. 1 6
      server/images.go
  4. 18 13
      server/routes.go

+ 16 - 1
cmd/cmd.go

@@ -362,9 +362,24 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 
 	opts.MultiModal = slices.Contains(info.Details.Families, "clip")
 	opts.ParentModel = info.Details.ParentModel
-	opts.Messages = append(opts.Messages, info.Messages...)
 
 	if interactive {
+		if err := loadModel(cmd, &opts); err != nil {
+			return err
+		}
+
+		for _, msg := range info.Messages {
+			switch msg.Role {
+			case "user":
+				fmt.Printf(">>> %s\n", msg.Content)
+			case "assistant":
+				state := &displayResponseState{}
+				displayResponse(msg.Content, opts.WordWrap, state)
+				fmt.Println()
+				fmt.Println()
+			}
+		}
+
 		return generateInteractive(cmd, opts)
 	}
 	return generate(cmd, opts)

+ 1 - 20
cmd/interactive.go

@@ -48,29 +48,10 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
 		KeepAlive: opts.KeepAlive,
 	}
 
-	return client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
-		p.StopAndClear()
-		for _, msg := range opts.Messages {
-			switch msg.Role {
-			case "user":
-				fmt.Printf(">>> %s\n", msg.Content)
-			case "assistant":
-				state := &displayResponseState{}
-				displayResponse(msg.Content, opts.WordWrap, state)
-				fmt.Println()
-				fmt.Println()
-			}
-		}
-		return nil
-	})
+	return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
 }
 
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
-	err := loadModel(cmd, &opts)
-	if err != nil {
-		return err
-	}
-
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
 		fmt.Fprintln(os.Stderr, "  /set            Set session variables")

+ 1 - 6
server/images.go

@@ -70,7 +70,7 @@ type Model struct {
 	License        []string
 	Digest         string
 	Options        map[string]interface{}
-	Messages       []Message
+	Messages       []api.Message
 
 	Template *template.Template
 }
@@ -191,11 +191,6 @@ func (m *Model) String() string {
 	return modelfile.String()
 }
 
-type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
 type ConfigV2 struct {
 	ModelFormat   string   `json:"model_format"`
 	ModelFamily   string   `json:"model_family"`

+ 18 - 13
server/routes.go

@@ -164,17 +164,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 		}
 
-		var b bytes.Buffer
-		if req.Context != nil {
-			s, err := r.Detokenize(c.Request.Context(), req.Context)
-			if err != nil {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-				return
-			}
-
-			b.WriteString(s)
-		}
-
 		var values template.Values
 		if req.Suffix != "" {
 			values.Prompt = prompt
@@ -187,6 +176,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 			}
 
+			if req.Context == nil {
+				msgs = append(msgs, m.Messages...)
+			}
+
 			for _, i := range images {
 				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
 			}
@@ -194,11 +187,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 		}
 
+		var b bytes.Buffer
 		if err := tmpl.Execute(&b, values); err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
 
+		if req.Context != nil {
+			s, err := r.Detokenize(c.Request.Context(), req.Context)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+
+			b.WriteString(s)
+		}
+
 		prompt = b.String()
 	}
 
@@ -1329,11 +1333,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	msgs := append(m.Messages, req.Messages...)
 	if req.Messages[0].Role != "system" && m.System != "" {
-		req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
+		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return