Browse Source

api: enable tool streaming (#7836)

Parth Sareen 5 months ago
parent
commit
ce7455a8e1
4 changed files with 289 additions and 13 deletions
  1. 9 4
      openai/openai.go
  2. 1 0
      server/model_test.go
  3. 31 1
      server/routes.go
  4. 248 8
      server/routes_generate_test.go

+ 9 - 4
openai/openai.go

@@ -200,9 +200,9 @@ func toolCallId() string {
 	return "call_" + strings.ToLower(string(b))
 }
 
-func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
-	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
-	for i, tc := range r.Message.ToolCalls {
+func toToolCalls(tc []api.ToolCall) []ToolCall {
+	toolCalls := make([]ToolCall, len(tc))
+	for i, tc := range tc {
 		toolCalls[i].ID = toolCallId()
 		toolCalls[i].Type = "function"
 		toolCalls[i].Function.Name = tc.Function.Name
@@ -215,7 +215,11 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 
 		toolCalls[i].Function.Arguments = string(args)
 	}
+	return toolCalls
+}
 
+func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	toolCalls := toToolCalls(r.Message.ToolCalls)
 	return ChatCompletion{
 		Id:                id,
 		Object:            "chat.completion",
@@ -244,6 +248,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 }
 
 func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+	toolCalls := toToolCalls(r.Message.ToolCalls)
 	return ChatCompletionChunk{
 		Id:                id,
 		Object:            "chat.completion.chunk",
@@ -252,7 +257,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 		SystemFingerprint: "fp_ollama",
 		Choices: []ChunkChoice{{
 			Index: 0,
-			Delta: Message{Role: "assistant", Content: r.Message.Content},
+			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason

+ 1 - 0
server/model_test.go

@@ -39,6 +39,7 @@ func TestExecuteWithTools(t *testing.T) {
 		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
 
 The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
+		{"mistral", `[TOOL_CALLS]  [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
 		{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
 
 		[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},

+ 31 - 1
server/routes.go

@@ -1458,6 +1458,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
+		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -1467,6 +1468,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
+		var sb strings.Builder
+		var hasToolCalls bool
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
@@ -1492,7 +1495,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}
 
-			ch <- res
+			// TODO: tool call checking and filtering should be moved outside of this callback once streaming
+			// however this was a simple change for now without reworking streaming logic of this (and other)
+			// handlers
+			if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
+				ch <- res
+				return
+			}
+
+			// Streaming tool calls:
+			// If tools are recognized, use a flag to track the sending of a tool downstream
+			// This ensures that content is cleared from the message on the last chunk sent
+			sb.WriteString(r.Content)
+			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+				res.Message.ToolCalls = toolCalls
+				res.Message.Content = ""
+				sb.Reset()
+				hasToolCalls = true
+				ch <- res
+				return
+			}
+
+			if r.Done {
+				// Send any remaining content if no tool calls were detected
+				if !hasToolCalls {
+					res.Message.Content = sb.String()
+				}
+				ch <- res
+			}
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}

+ 248 - 8
server/routes_generate_test.go

@@ -8,6 +8,7 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -25,10 +26,14 @@ type mockRunner struct {
 	// CompletionRequest is only valid until the next call to Completion
 	llm.CompletionRequest
 	llm.CompletionResponse
+	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
 }
 
-func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
 	m.CompletionRequest = r
+	if m.CompletionFn != nil {
+		return m.CompletionFn(ctx, r, fn)
+	}
 	fn(m.CompletionResponse)
 	return nil
 }
@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
 		Model: "test",
 		Modelfile: fmt.Sprintf(`FROM %s
 		TEMPLATE """
-{{- if .System }}System: {{ .System }} {{ end }}
-{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
-{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+{{- if .Tools }}
+{{ .Tools }}
+{{ end }}
+{{- range .Messages }}
+{{- .Role }}: {{ .Content }}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
+{{- end }}
+{{ end }}"""
 `, createBinFile(t, llm.KV{
 			"general.architecture":          "llama",
 			"llama.block_count":             uint32(1),
@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("expected status 200, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 
@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("expected status 200, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 
@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("expected status 200, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 
@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("expected status 200, got %d", w.Code)
 		}
 
-		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
 			t.Errorf("mismatch (-got +want):\n%s", diff)
 		}
 
 		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
 	})
+
+	t.Run("messages with tools (non-streaming)", func(t *testing.T) {
+		if w.Code != http.StatusOK {
+			t.Fatalf("failed to create test-system model: %d", w.Code)
+		}
+
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather",
+					Parameters: struct {
+						Type       string   `json:"type"`
+						Required   []string `json:"required"`
+						Properties map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						} `json:"properties"`
+					}{
+						Type:     "object",
+						Required: []string{"location"},
+						Properties: map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						}{
+							"location": {
+								Type:        "string",
+								Description: "The city and state",
+							},
+							"unit": {
+								Type: "string",
+								Enum: []string{"celsius", "fahrenheit"},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		mock.CompletionResponse = llm.CompletionResponse{
+			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
+			Done:               true,
+			DoneReason:         "done",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		}
+
+		streamRequest := true
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "What's the weather in Seattle?"},
+			},
+			Tools:  tools,
+			Stream: &streamRequest,
+		})
+
+		if w.Code != http.StatusOK {
+			var errResp struct {
+				Error string `json:"error"`
+			}
+			if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
+				t.Logf("Failed to decode error response: %v", err)
+			} else {
+				t.Logf("Error response: %s", errResp.Error)
+			}
+		}
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var resp api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		if resp.Message.ToolCalls == nil {
+			t.Error("expected tool calls, got nil")
+		}
+
+		expectedToolCall := api.ToolCall{
+			Function: api.ToolCallFunction{
+				Name: "get_weather",
+				Arguments: api.ToolCallFunctionArguments{
+					"location": "Seattle, WA",
+					"unit":     "celsius",
+				},
+			},
+		}
+
+		if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
+			t.Errorf("tool call mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("messages with tools (streaming)", func(t *testing.T) {
+		tools := []api.Tool{
+			{
+				Type: "function",
+				Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather",
+					Parameters: struct {
+						Type       string   `json:"type"`
+						Required   []string `json:"required"`
+						Properties map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						} `json:"properties"`
+					}{
+						Type:     "object",
+						Required: []string{"location"},
+						Properties: map[string]struct {
+							Type        string   `json:"type"`
+							Description string   `json:"description"`
+							Enum        []string `json:"enum,omitempty"`
+						}{
+							"location": {
+								Type:        "string",
+								Description: "The city and state",
+							},
+							"unit": {
+								Type: "string",
+								Enum: []string{"celsius", "fahrenheit"},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		// Simulate streaming response with multiple chunks
+		var wg sync.WaitGroup
+		wg.Add(1)
+
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			defer wg.Done()
+
+			// Send chunks with small delays to simulate streaming
+			responses := []llm.CompletionResponse{
+				{
+					Content:            `{"name":"get_`,
+					Done:               false,
+					PromptEvalCount:    1,
+					PromptEvalDuration: 1,
+				},
+				{
+					Content:            `weather","arguments":{"location":"Seattle`,
+					Done:               false,
+					PromptEvalCount:    2,
+					PromptEvalDuration: 1,
+				},
+				{
+					Content:            `, WA","unit":"celsius"}}`,
+					Done:               true,
+					DoneReason:         "tool_call",
+					PromptEvalCount:    3,
+					PromptEvalDuration: 1,
+				},
+			}
+
+			for _, resp := range responses {
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				default:
+					fn(resp)
+					time.Sleep(10 * time.Millisecond) // Small delay between chunks
+				}
+			}
+			return nil
+		}
+
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "What's the weather in Seattle?"},
+			},
+			Tools:  tools,
+			Stream: &stream,
+		})
+
+		wg.Wait()
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		// Read and validate the streamed responses
+		decoder := json.NewDecoder(w.Body)
+		var finalToolCall api.ToolCall
+
+		for {
+			var resp api.ChatResponse
+			if err := decoder.Decode(&resp); err == io.EOF {
+				break
+			} else if err != nil {
+				t.Fatal(err)
+			}
+
+			if resp.Done {
+				if len(resp.Message.ToolCalls) != 1 {
+					t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
+				}
+				finalToolCall = resp.Message.ToolCalls[0]
+			}
+		}
+
+		expectedToolCall := api.ToolCall{
+			Function: api.ToolCallFunction{
+				Name: "get_weather",
+				Arguments: api.ToolCallFunctionArguments{
+					"location": "Seattle, WA",
+					"unit":     "celsius",
+				},
+			},
+		}
+
+		if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
+			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
+		}
+	})
 }
 
 func TestGenerate(t *testing.T) {