Michael Yang 10 months ago
parent
commit
2c3fe1fd97
5 changed files with 223 additions and 112 deletions
  1. +19 -10 server/prompt.go
  2. +12 -22 server/prompt_test.go
  3. +25 -21 server/routes.go
  4. +25 -21 template/template.go
  5. +142 -38 template/template_test.go

server/prompt.go (+19 -10)

@@ -11,8 +11,13 @@ import (
 	"github.com/ollama/ollama/template"
 )
 
-func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
-	// extract system messages which should always be included
+type tokenizeFunc func(context.Context, string) ([]int, error)
+
+// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
+// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
+// latest message and 2) system messages.
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
+	// pull out any system messages which should always be included in the prompt
 	var system []api.Message
 	msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
 		if m.Role == "system" {
@@ -23,32 +28,35 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
 		return false
 	})
 
-	if len(system) == 0 && r.model.System != "" {
+	if len(system) == 0 && m.System != "" {
 		// add model system prompt since it wasn't provided
-		system = append(system, api.Message{Role: "system", Content: r.model.System})
+		system = append(system, api.Message{Role: "system", Content: m.System})
 	}
 
+	// always include the last message
 	n := len(msgs) - 1
+	// in reverse, find all messages that fit into the context window
 	for i := n - 1; i >= 0; i-- {
 		var b bytes.Buffer
-		if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
 			return "", nil, err
 		}
 
-		s, err := r.llama.Tokenize(ctx, b.String())
+		s, err := tokenize(ctx, b.String())
 		if err != nil {
 			return "", nil, err
 		}
 
 		c := len(s)
-		if r.model.ProjectorPaths != nil {
+		if m.ProjectorPaths != nil {
 			for _, m := range msgs[i:] {
-				// TODO: get image embedding length from project metadata
+				// each image is counted as a fixed 768-token embedding
+				// TODO: get the embedding length from the projector metadata
 				c += 768 * len(m.Images)
 			}
 		}
 
-		if c > r.NumCtx {
+		if c > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
@@ -56,8 +64,9 @@ func chatPrompt(ctx context.Context, r *runnerRef, msgs []api.Message) (prompt s
 		}
 	}
 
+	// render the prompt using only the messages that fit in the context window
 	var b bytes.Buffer
-	if err := r.model.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
 		return "", nil, err
 	}
 

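The refactored chatPrompt walks the history newest-to-oldest and keeps the largest suffix of messages that still fits, while system messages and the final message are always retained. Below is a simplified, self-contained sketch of that reverse scan, with a plain string slice and the word-count tokenizer from prompt_test.go standing in for the template render and the real tokenizer:

```go
package main

import (
	"fmt"
	"strings"
)

// truncate returns the largest suffix of msgs whose combined token count fits
// within budget, always keeping the final message even if it alone exceeds the
// budget. This mirrors chatPrompt's reverse scan; the real code re-renders the
// full template and tokenizes the candidate prompt on every iteration, and
// adds 768 tokens per image when a projector is loaded.
func truncate(msgs []string, budget int, count func(string) int) []string {
	n := len(msgs) - 1 // always include the last message
	for i := n - 1; i >= 0; i-- {
		if count(strings.Join(msgs[i:], " ")) > budget {
			break // older messages no longer fit; drop msgs[:n]
		}
		n = i
	}
	return msgs[n:]
}

func main() {
	// count whitespace-separated words as tokens, like the stub in server/prompt_test.go
	count := func(s string) int { return len(strings.Fields(s)) }

	msgs := []string{
		"You're a test, Harry!",
		"I-I'm a what?",
		"A test. And a thumping good one at that, I'd wager.",
	}
	fmt.Println(truncate(msgs, 15, count)) // keeps the last two messages
}
```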
server/prompt_test.go (+12 -22)

@@ -7,15 +7,10 @@ import (
 	"testing"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/template"
 )
 
-type mock struct {
-	llm.LlamaServer
-}
-
-func (m mock) Tokenize(_ context.Context, s string) (tokens []int, err error) {
+func tokenize(_ context.Context, s string) (tokens []int, err error) {
 	for range strings.Fields(s) {
 		tokens = append(tokens, len(tokens))
 	}
@@ -48,7 +43,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "truncate messages",
+			name:  "truncate messages",
 			limit: 1,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -60,7 +55,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "truncate messages with image",
+			name:  "truncate messages with image",
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -75,7 +70,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "truncate messages with images",
+			name:  "truncate messages with images",
 			limit: 64,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -90,7 +85,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "messages with images",
+			name:  "messages with images",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
@@ -106,7 +101,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "message with image tag",
+			name:  "message with image tag",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
@@ -122,7 +117,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "messages with interleaved images",
+			name:  "messages with interleaved images",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -140,7 +135,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "truncate message with interleaved images",
+			name:  "truncate message with interleaved images",
 			limit: 1024,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -157,7 +152,7 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name: "message with system prompt",
+			name:  "message with system prompt",
 			limit: 2048,
 			msgs: []api.Message{
 				{Role: "system", Content: "You are the Test Who Lived."},
@@ -181,14 +176,9 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tt := range cases {
 		t.Run(tt.name, func(t *testing.T) {
-			r := runnerRef{
-				llama:   mock{},
-				model:   &Model{Template: tmpl, ProjectorPaths: []string{"vision"}},
-				Options: &api.Options{},
-			}
-
-			r.NumCtx = tt.limit
-			prompt, images, err := chatPrompt(context.TODO(), &r, tt.msgs)
+			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
+			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
 			if err != nil {
 				t.Fatal(err)
 			}

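Replacing the llm.LlamaServer mock with a bare function value is what shrinks this test setup: anything matching tokenizeFunc can be injected. A hedged sketch of the pattern (countTokens is a hypothetical helper for illustration, not part of the server package):

```go
package main

import (
	"context"
	"fmt"
	"strings"
)

// tokenizeFunc matches the signature chatPrompt now accepts, so tests can pass
// a trivial word-splitting tokenizer instead of embedding llm.LlamaServer in a
// mock struct.
type tokenizeFunc func(context.Context, string) ([]int, error)

// countTokens shows the injection point: any function with the right
// signature works, whether it is the real runner or a test stub.
func countTokens(ctx context.Context, tokenize tokenizeFunc, s string) (int, error) {
	toks, err := tokenize(ctx, s)
	if err != nil {
		return 0, err
	}
	return len(toks), nil
}

func main() {
	// test stub: one token per whitespace-separated word
	stub := func(_ context.Context, s string) ([]int, error) {
		return make([]int, len(strings.Fields(s))), nil
	}
	n, _ := countTokens(context.TODO(), stub, "You're a test, Harry!")
	fmt.Println(n) // 4
}
```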
server/routes.go (+25 -21)

@@ -54,6 +54,8 @@ func init() {
 	gin.SetMode(mode)
 }
 
+var errRequired = errors.New("is required")
+
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
@@ -69,7 +71,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 
 func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
 	if name == "" {
-		return nil, errors.New("model is required")
+		return nil, fmt.Errorf("model %w", errRequired)
 	}
 
 	model, err := GetModel(name)
@@ -121,7 +123,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, err)
+		handleScheduleError(c, req.Model, err)
+		return
+	}
+
+	if req.Prompt == "" {
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       true,
+			DoneReason: "load",
+		})
 		return
 	}
 
@@ -139,23 +151,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
 		}
 
-		if req.Prompt != "" {
-			for _, i := range images {
-				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
-			}
-
-			msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
+		for _, i := range images {
+			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
 		}
 
-		if len(msgs) == 0 {
-			c.JSON(http.StatusOK, api.GenerateResponse{
-				Model:      req.Model,
-				CreatedAt:  time.Now().UTC(),
-				Done:       true,
-				DoneReason: "load",
-			})
-			return
-		}
+		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
 		tmpl := r.model.Template
 		if req.Template != "" {
@@ -256,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 
 	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
-		handleScheduleError(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -1135,7 +1135,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
 	} else if err != nil {
-		handleScheduleError(c, err)
+		handleScheduleError(c, req.Model, err)
 		return
 	}
 
@@ -1150,7 +1150,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), r, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1215,12 +1215,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
 
-func handleScheduleError(c *gin.Context, err error) {
+func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
+	case errors.Is(err, errRequired):
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})
 	case errors.Is(err, ErrMaxQueue):
 		c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()})
+	case errors.Is(err, os.ErrNotExist):
+		c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found, try pulling it first", name)})
 	default:
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 	}

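The errRequired sentinel lets scheduleRunner return a readable message ("model is required") while handleScheduleError still recovers the error's category via errors.Is, and the os.ErrNotExist case now turns a missing model into a 404 with a pull hint. The empty-prompt short-circuit likewise moves ahead of templating, so a bare load request returns immediately. A minimal sketch of the status mapping, returning bare codes instead of writing gin responses:

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

// errRequired is a sentinel: callers wrap it with context via %w, and the
// handler recovers the category with errors.Is regardless of the wrapping.
var errRequired = errors.New("is required")

// statusFor is a simplified sketch of handleScheduleError's switch.
func statusFor(err error) int {
	switch {
	case errors.Is(err, errRequired):
		return 400 // bad request: a required field was missing
	case errors.Is(err, os.ErrNotExist):
		return 404 // model not found; suggest pulling it first
	default:
		return 500
	}
}

func main() {
	err := fmt.Errorf("model %w", errRequired)
	fmt.Println(err, "->", statusFor(err)) // model is required -> 400
}
```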
template/template.go (+25 -21)

@@ -83,6 +83,7 @@ type Template struct {
 	raw string
 }
 
+// response is a template node that can be added to templates that don't already have one
 var response = parse.ActionNode{
 	NodeType: parse.NodeAction,
 	Pipe: &parse.PipeNode{
@@ -101,28 +102,25 @@ var response = parse.ActionNode{
 	},
 }
 
-func Parse(s string) (*Template, error) {
-	tmpl := template.New("").Option("missingkey=zero").Funcs(template.FuncMap{
-		"toJson": func(v any) string {
-			b, err := json.Marshal(v)
-			if err != nil {
-				return ""
-			}
-
-			return string(b)
-		},
-		"isLastMessage": func(s []*api.Message, m *api.Message) bool {
-			for i := len(s) - 1; i >= 0; i-- {
-				if m.Role != s[i].Role {
-					continue
-				}
+var funcs = template.FuncMap{
+	"toJson": func(v any) string {
+		b, err := json.Marshal(v)
+		if err != nil {
+			return ""
+		}
 
-				return m == s[i]
-			}
+		return string(b)
+	},
+	"add": func(a, b int) int {
+		return a + b
+	},
+	"sub": func(a, b int) int {
+		return a - b
+	},
+}
 
-			return false
-		},
-	})
+func Parse(s string) (*Template, error) {
+	tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
 
 	tmpl, err := tmpl.Parse(s)
 	if err != nil {
@@ -218,7 +216,13 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	return err
 }
 
-func collate(msgs []api.Message) (system string, collated []*api.Message) {
+type messages []*api.Message
+
+// collate messages based on role. consecutive messages of the same role are merged
+// into a single message. collate also pulls out and merges messages with Role == "system"
+// which are templated separately. As a side effect, it modifies message content, adding image
+// tags ([img-%d]) as needed.
+func collate(msgs []api.Message) (system string, collated messages) {
 	var n int
 	for i := range msgs {
 		msg := msgs[i]

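With isLastMessage removed, templates locate the final message using only the generic helpers: `eq (index $.Messages (sub (len $.Messages) 1)) .`. A standalone sketch of that idiom (the message type and template below are illustrative, not the package's own):

```go
package main

import (
	"os"
	"text/template"
)

type message struct{ Role, Content string }

// funcs mirrors the generic helpers that replace isLastMessage; only sub is
// needed for the last-message check.
var funcs = template.FuncMap{
	"sub": func(a, b int) int { return a - b },
}

// tmpl prefixes the final message with "last: " by comparing each element
// against index len-1, the same eq/index/sub idiom the new templates use.
// Note: this compares by value; the real package compares collated
// *api.Message pointers, so duplicate messages are unambiguous there.
const tmpl = `{{- range .Messages }}
{{- if eq (index $.Messages (sub (len $.Messages) 1)) . }}last: {{ end }}{{ .Content }}
{{ end }}`

func main() {
	t := template.Must(template.New("").Funcs(funcs).Parse(tmpl))
	err := t.Execute(os.Stdout, map[string]any{
		"Messages": []message{
			{Role: "user", Content: "Hello friend!"},
			{Role: "assistant", Content: "Hello human!"},
			{Role: "user", Content: "What is your name?"},
		},
	})
	if err != nil {
		panic(err)
	}
}
```

Running it prints each message on its own line, prefixing only the last one with "last: ".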
template/template_test.go (+142 -38)

@@ -8,6 +8,7 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"testing"
 	"text/template"
 
@@ -15,6 +16,98 @@ import (
 	"github.com/ollama/ollama/llm"
 )
 
+func TestFuncs(t *testing.T) {
+	t.Run("toJson", func(t *testing.T) {
+		cases := []struct {
+			input    any
+			expected string
+		}{
+			{nil, "null"},
+			{true, "true"},
+			{false, "false"},
+			{0, "0"},
+			{1, "1"},
+			{1.0, "1"},
+			{1.1, "1.1"},
+			{"", `""`},
+			{"hello", `"hello"`},
+			{[]int{1, 2, 3}, "[1,2,3]"},
+			{[]string{"a", "b", "c"}, `["a","b","c"]`},
+			{map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`},
+			{map[string]string{"a": "b", "c": "d"}, `{"a":"b","c":"d"}`},
+		}
+
+		for _, tt := range cases {
+			t.Run(tt.expected, func(t *testing.T) {
+				toJson, ok := funcs["toJson"].(func(any) string)
+				if !ok {
+					t.Fatal("toJson is not a function")
+				}
+
+				if s := toJson(tt.input); s != tt.expected {
+					t.Errorf("expected %q, got %q", tt.expected, s)
+				}
+			})
+		}
+	})
+
+	t.Run("add", func(t *testing.T) {
+		cases := []struct {
+			a, b     int
+			expected int
+		}{
+			{0, 0, 0},
+			{0, 1, 1},
+			{1, 0, 1},
+			{1, 1, 2},
+			{1, -1, 0},
+			{-1, 1, 0},
+			{-1, -1, -2},
+		}
+
+		for _, tt := range cases {
+			t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
+				add, ok := funcs["add"].(func(int, int) int)
+				if !ok {
+					t.Fatal("add is not a function")
+				}
+
+				if n := add(tt.a, tt.b); n != tt.expected {
+					t.Errorf("expected %d, got %d", tt.expected, n)
+				}
+			})
+		}
+	})
+
+	t.Run("sub", func(t *testing.T) {
+		cases := []struct {
+			a, b     int
+			expected int
+		}{
+			{0, 0, 0},
+			{0, 1, -1},
+			{1, 0, 1},
+			{1, 1, 0},
+			{1, -1, 2},
+			{-1, 1, -2},
+			{-1, -1, 0},
+		}
+
+		for _, tt := range cases {
+			t.Run(strconv.Itoa(tt.expected), func(t *testing.T) {
+				sub, ok := funcs["sub"].(func(int, int) int)
+				if !ok {
+					t.Fatal("sub is not a function")
+				}
+
+				if n := sub(tt.a, tt.b); n != tt.expected {
+					t.Errorf("expected %d, got %d", tt.expected, n)
+				}
+			})
+		}
+	})
+}
+
 func TestNamed(t *testing.T) {
 	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
 	if err != nil {
@@ -89,77 +182,86 @@ func TestParse(t *testing.T) {
 }
 
 func TestExecuteWithMessages(t *testing.T) {
+	type template struct {
+		name     string
+		template string
+	}
 	cases := []struct {
-		templates []string
+		name      string
+		templates []template
 		values    Values
 		expected  string
 	}{
 		{
-			[]string{
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
-				`{{- range .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
+			"mistral",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `{{- range .Messages }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
 {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
 {{- end }}
-{{- end }}`,
+{{- end }}`},
 			},
 			Values{
 				Messages: []api.Message{
 					{Role: "user", Content: "Hello friend!"},
 					{Role: "assistant", Content: "Hello human!"},
-					{Role: "user", Content: "Yay!"},
+					{Role: "user", Content: "What is your name?"},
 				},
 			},
-			`[INST] Hello friend![/INST] Hello human![INST] Yay![/INST] `,
+			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
 		},
 		{
-			[]string{
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `,
-				`[INST] {{ if .System }}{{ .System }}{{ print "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`,
-				`
+			"mistral system",
+			[]template{
+				{"no response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ if .System }}{{ .System }}{{ "\n\n" }}{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `
 {{- range .Messages }}
-{{- if eq .Role "user" }}[INST] {{ if and (isLastMessage $.Messages .) $.System }}{{ $.System }}{{ print "\n\n" }}
+{{- if eq .Role "user" }}[INST] {{ if and (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}{{ $.System }}{{ "\n\n" }}
 {{- end }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}
 {{- end }}
-{{- end }}`,
+{{- end }}`},
 			},
 			Values{
 				Messages: []api.Message{
 					{Role: "system", Content: "You are a helpful assistant!"},
 					{Role: "user", Content: "Hello friend!"},
 					{Role: "assistant", Content: "Hello human!"},
-					{Role: "user", Content: "Yay!"},
+					{Role: "user", Content: "What is your name?"},
 				},
 			},
 			`[INST] Hello friend![/INST] Hello human![INST] You are a helpful assistant!
 
-Yay![/INST] `,
+What is your name?[/INST] `,
 		},
 		{
-			[]string{
-				`{{ if .System }}<|im_start|>system
+			"chatml",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .System }}<|im_start|>system
 {{ .System }}<|im_end|>
 {{ end }}{{ if .Prompt }}<|im_start|>user
 {{ .Prompt }}<|im_end|>
 {{ end }}<|im_start|>assistant
 {{ .Response }}<|im_end|>
-`,
-				`
+`},
+				{"messages", `
 {{- range .Messages }}
-{{- if and (eq .Role "user") (isLastMessage $.Messages .) $.System }}<|im_start|>system
-{{ $.System }}<|im_end|>{{ print "\n" }}
+{{- if and (eq .Role "user") (eq (index $.Messages (sub (len $.Messages) 1)) .) $.System }}<|im_start|>system
+{{ $.System }}<|im_end|>{{ "\n" }}
 {{- end }}<|im_start|>{{ .Role }}
-{{ .Content }}<|im_end|>{{ print "\n" }}
+{{ .Content }}<|im_end|>{{ "\n" }}
 {{- end }}<|im_start|>assistant
-`,
+`},
 			},
 			Values{
 				Messages: []api.Message{
 					{Role: "system", Content: "You are a helpful assistant!"},
 					{Role: "user", Content: "Hello friend!"},
 					{Role: "assistant", Content: "Hello human!"},
-					{Role: "user", Content: "Yay!"},
+					{Role: "user", Content: "What is your name?"},
 				},
 			},
 			`<|im_start|>user
@@ -169,23 +271,25 @@ Hello human!<|im_end|>
 <|im_start|>system
 You are a helpful assistant!<|im_end|>
 <|im_start|>user
-Yay!<|im_end|>
+What is your name?<|im_end|>
 <|im_start|>assistant
 `,
 		},
 		{
-			[]string{
-				`{{ if .Prompt }}Question: {{ .Prompt }}
+			"moondream",
+			[]template{
+				// this does not have a "no response" test because it's impossible to render the same output
+				{"response", `{{ if .Prompt }}Question: {{ .Prompt }}
 
 {{ end }}Answer: {{ .Response }}
 
-`,
-				`
+`},
+				{"messages", `
 {{- range .Messages }}
-{{- if eq .Role "user" }}Question: {{ .Content }}{{ print "\n\n" }}
-{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ print "\n\n" }}
+{{- if eq .Role "user" }}Question: {{ .Content }}{{ "\n\n" }}
+{{- else if eq .Role "assistant" }}Answer: {{ .Content }}{{ "\n\n" }}
 {{- end }}
-{{- end }}Answer: `,
+{{- end }}Answer: `},
 			},
 			Values{
 				Messages: []api.Message{
@@ -211,10 +315,10 @@ Answer: `,
 	}
 
 	for _, tt := range cases {
-		t.Run("", func(t *testing.T) {
-			for _, tmpl := range tt.templates {
-				t.Run("", func(t *testing.T) {
-					tmpl, err := Parse(tmpl)
+		t.Run(tt.name, func(t *testing.T) {
+			for _, ttt := range tt.templates {
+				t.Run(ttt.name, func(t *testing.T) {
+					tmpl, err := Parse(ttt.template)
 					if err != nil {
 						t.Fatal(err)
 					}
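
A side benefit of naming every case and template variant: subtests can now be targeted directly, e.g. `go test ./template -run 'TestExecuteWithMessages/chatml/messages'`. Note that t.Run replaces spaces in subtest names with underscores, so the "no response" variants match as no_response.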