Browse Source

Merge pull request #5207 from ollama/mxyng/suffix

add insert support to generate endpoint
Michael Yang 9 tháng trước cách đây
mục cha
commit
cd0853f2d5
6 tập tin đã thay đổi với 155 bổ sung27 xóa
  1. 3 0
      api/types.go
  2. 14 3
      server/images.go
  3. 25 15
      server/routes.go
  4. 69 8
      server/routes_generate_test.go
  5. 9 1
      template/template.go
  6. 35 0
      template/template_test.go

+ 3 - 0
api/types.go

@@ -47,6 +47,9 @@ type GenerateRequest struct {
 	// Prompt is the textual prompt to send to the model.
 	Prompt string `json:"prompt"`
 
+	// Suffix is the text that comes after the inserted text.
+	Suffix string `json:"suffix"`
+
 	// System overrides the model's default system message/prompt.
 	System string `json:"system"`
 

+ 14 - 3
server/images.go

@@ -34,13 +34,19 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-var errCapabilityCompletion = errors.New("completion")
+var (
+	errCapabilities         = errors.New("does not support")
+	errCapabilityCompletion = errors.New("completion")
+	errCapabilityTools      = errors.New("tools")
+	errCapabilityInsert     = errors.New("insert")
+)
 
 type Capability string
 
 const (
 	CapabilityCompletion = Capability("completion")
 	CapabilityTools      = Capability("tools")
+	CapabilityInsert     = Capability("insert")
 )
 
 type registryOptions struct {
@@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 			}
 		case CapabilityTools:
 			if !slices.Contains(m.Template.Vars(), "tools") {
-				errs = append(errs, errors.New("tools"))
+				errs = append(errs, errCapabilityTools)
+			}
+		case CapabilityInsert:
+			vars := m.Template.Vars()
+			if !slices.Contains(vars, "suffix") {
+				errs = append(errs, errCapabilityInsert)
 			}
 		default:
 			slog.Error("unknown capability", "capability", cap)
@@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
 	}
 
 	if err := errors.Join(errs...); err != nil {
-		return fmt.Errorf("does not support %w", errors.Join(errs...))
+		return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
 	}
 
 	return nil

+ 25 - 15
server/routes.go

@@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
+	if req.Suffix != "" {
+		caps = append(caps, CapabilityInsert)
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	prompt := req.Prompt
 	if !req.Raw {
-		var msgs []api.Message
-		if req.System != "" {
-			msgs = append(msgs, api.Message{Role: "system", Content: req.System})
-		} else if m.System != "" {
-			msgs = append(msgs, api.Message{Role: "system", Content: m.System})
-		}
-
-		for _, i := range images {
-			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
-		}
-
-		msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
-
 		tmpl := m.Template
 		if req.Template != "" {
 			tmpl, err = template.Parse(req.Template)
@@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			b.WriteString(s)
 		}
 
-		if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
+		var values template.Values
+		if req.Suffix != "" {
+			values.Prompt = prompt
+			values.Suffix = req.Suffix
+		} else {
+			var msgs []api.Message
+			if req.System != "" {
+				msgs = append(msgs, api.Message{Role: "system", Content: req.System})
+			} else if m.System != "" {
+				msgs = append(msgs, api.Message{Role: "system", Content: m.System})
+			}
+
+			for _, i := range images {
+				msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
+			}
+
+			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
+		}
+
+		if err := tmpl.Execute(&b, values); err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
@@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 func handleScheduleError(c *gin.Context, name string, err error) {
 	switch {
-	case errors.Is(err, errRequired):
+	case errors.Is(err, errCapabilities), errors.Is(err, errRequired):
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 	case errors.Is(err, context.Canceled):
 		c.JSON(499, gin.H{"error": "request canceled"})

+ 69 - 8
server/routes_generate_test.go

@@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) {
 			getCpuFn:      gpu.GetCPUInfo,
 			reschedDelay:  250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				// add 10ms delay to simulate loading
+				time.Sleep(10 * time.Millisecond)
 				req.successCh <- &runnerRef{
 					llama: &mock,
 				}
@@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) {
 	go s.sched.Run(context.TODO())
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
-		Name: "test",
+		Model: "test",
 		Modelfile: fmt.Sprintf(`FROM %s
 		TEMPLATE """
 {{- if .System }}System: {{ .System }} {{ end }}
@@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) {
 		}
 	})
 
-	t.Run("missing capabilities", func(t *testing.T) {
+	t.Run("missing capabilities chat", func(t *testing.T) {
 		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
-			Name: "bert",
+			Model: "bert",
 			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
 				"general.architecture": "bert",
 				"bert.pooling_type":    uint32(0),
@@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) {
 		}
 
 		if actual.TotalDuration == 0 {
-			t.Errorf("expected load duration > 0, got 0")
+			t.Errorf("expected total duration > 0, got 0")
 		}
 	}
 
@@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) {
 	go s.sched.Run(context.TODO())
 
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
-		Name: "test",
+		Model: "test",
 		Modelfile: fmt.Sprintf(`FROM %s
 		TEMPLATE """
 {{- if .System }}System: {{ .System }} {{ end }}
@@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) {
 		}
 	})
 
-	t.Run("missing capabilities", func(t *testing.T) {
+	t.Run("missing capabilities generate", func(t *testing.T) {
 		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
-			Name: "bert",
+			Model: "bert",
 			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
 				"general.architecture": "bert",
 				"bert.pooling_type":    uint32(0),
@@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) {
 		}
 	})
 
+	t.Run("missing capabilities suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("load model", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 			Model: "test",
@@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) {
 		}
 
 		if actual.TotalDuration == 0 {
-			t.Errorf("expected load duration > 0, got 0")
+			t.Errorf("expected total duration > 0, got 0")
 		}
 	}
 
@@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) {
 		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
 	})
 
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model: "test-suffix",
+		Modelfile: `FROM test
+TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
+{{- else }}{{ .Prompt }}
+{{- end }}"""`,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF>    return c <MID>"); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("prompt without suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("raw", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 			Model:  "test-system",

+ 9 - 1
template/template.go

@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
 type Values struct {
 	Messages []api.Message
 	Tools    []api.Tool
+	Prompt   string
+	Suffix   string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
 
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
-	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+	if v.Prompt != "" && v.Suffix != "" {
+		return t.Template.Execute(w, map[string]any{
+			"Prompt":   v.Prompt,
+			"Suffix":   v.Suffix,
+			"Response": "",
+		})
+	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,

+ 35 - 0
template/template_test.go

@@ -359,3 +359,38 @@ Answer: `,
 		})
 	}
 }
+
+func TestExecuteWithSuffix(t *testing.T) {
+	tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
+{{- else }}{{ .Prompt }}
+{{- end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cases := []struct {
+		name   string
+		values Values
+		expect string
+	}{
+		{
+			"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
+		},
+		{
+			"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			var b bytes.Buffer
+			if err := tmpl.Execute(&b, tt.values); err != nil {
+				t.Fatal(err)
+			}
+
+			if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}