|
@@ -8,6 +8,7 @@ import (
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"strings"
|
|
|
+ "sync"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
@@ -25,10 +26,14 @@ type mockRunner struct {
|
|
|
// CompletionRequest is only valid until the next call to Completion
|
|
|
llm.CompletionRequest
|
|
|
llm.CompletionResponse
|
|
|
+ CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
|
|
|
}
|
|
|
|
|
|
-func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
|
|
+func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
|
|
m.CompletionRequest = r
|
|
|
+ if m.CompletionFn != nil {
|
|
|
+ return m.CompletionFn(ctx, r, fn)
|
|
|
+ }
|
|
|
fn(m.CompletionResponse)
|
|
|
return nil
|
|
|
}
|
|
@@ -88,9 +93,14 @@ func TestGenerateChat(t *testing.T) {
|
|
|
Model: "test",
|
|
|
Modelfile: fmt.Sprintf(`FROM %s
|
|
|
TEMPLATE """
|
|
|
-{{- if .System }}System: {{ .System }} {{ end }}
|
|
|
-{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
|
|
-{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
|
|
|
+{{- if .Tools }}
|
|
|
+{{ .Tools }}
|
|
|
+{{ end }}
|
|
|
+{{- range .Messages }}
|
|
|
+{{- .Role }}: {{ .Content }}
|
|
|
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
|
|
+{{- end }}
|
|
|
+{{ end }}"""
|
|
|
`, createBinFile(t, llm.KV{
|
|
|
"general.architecture": "llama",
|
|
|
"llama.block_count": uint32(1),
|
|
@@ -263,7 +273,7 @@ func TestGenerateChat(t *testing.T) {
|
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
|
}
|
|
|
|
|
|
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
|
|
+ if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
}
|
|
|
|
|
@@ -292,7 +302,7 @@ func TestGenerateChat(t *testing.T) {
|
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
|
}
|
|
|
|
|
|
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
|
|
+ if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
}
|
|
|
|
|
@@ -314,7 +324,7 @@ func TestGenerateChat(t *testing.T) {
|
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
|
}
|
|
|
|
|
|
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
|
|
+ if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
}
|
|
|
|
|
@@ -337,12 +347,242 @@ func TestGenerateChat(t *testing.T) {
|
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
|
}
|
|
|
|
|
|
- if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
|
|
|
+ if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
|
}
|
|
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
|
})
|
|
|
+
|
|
|
+ t.Run("messages with tools (non-streaming)", func(t *testing.T) {
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
+ t.Fatalf("failed to create test-system model: %d", w.Code)
|
|
|
+ }
|
|
|
+
|
|
|
+ tools := []api.Tool{
|
|
|
+ {
|
|
|
+ Type: "function",
|
|
|
+ Function: api.ToolFunction{
|
|
|
+ Name: "get_weather",
|
|
|
+ Description: "Get the current weather",
|
|
|
+ Parameters: struct {
|
|
|
+ Type string `json:"type"`
|
|
|
+ Required []string `json:"required"`
|
|
|
+ Properties map[string]struct {
|
|
|
+ Type string `json:"type"`
|
|
|
+ Description string `json:"description"`
|
|
|
+ Enum []string `json:"enum,omitempty"`
|
|
|
+ } `json:"properties"`
|
|
|
+ }{
|
|
|
+ Type: "object",
|
|
|
+ Required: []string{"location"},
|
|
|
+ Properties: map[string]struct {
|
|
|
+ Type string `json:"type"`
|
|
|
+ Description string `json:"description"`
|
|
|
+ Enum []string `json:"enum,omitempty"`
|
|
|
+ }{
|
|
|
+ "location": {
|
|
|
+ Type: "string",
|
|
|
+ Description: "The city and state",
|
|
|
+ },
|
|
|
+ "unit": {
|
|
|
+ Type: "string",
|
|
|
+ Enum: []string{"celsius", "fahrenheit"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ mock.CompletionResponse = llm.CompletionResponse{
|
|
|
+ Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
|
|
|
+ Done: true,
|
|
|
+ DoneReason: "done",
|
|
|
+ PromptEvalCount: 1,
|
|
|
+ PromptEvalDuration: 1,
|
|
|
+ EvalCount: 1,
|
|
|
+ EvalDuration: 1,
|
|
|
+ }
|
|
|
+
|
|
|
+ streamRequest := true
|
|
|
+
|
|
|
+ w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
|
+ Model: "test-system",
|
|
|
+ Messages: []api.Message{
|
|
|
+ {Role: "user", Content: "What's the weather in Seattle?"},
|
|
|
+ },
|
|
|
+ Tools: tools,
|
|
|
+ Stream: &streamRequest,
|
|
|
+ })
|
|
|
+
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
+ var errResp struct {
|
|
|
+ Error string `json:"error"`
|
|
|
+ }
|
|
|
+ if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
|
|
|
+ t.Logf("Failed to decode error response: %v", err)
|
|
|
+ } else {
|
|
|
+ t.Logf("Error response: %s", errResp.Error)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
+ t.Errorf("expected status 200, got %d", w.Code)
|
|
|
+ }
|
|
|
+
|
|
|
+ var resp api.ChatResponse
|
|
|
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if resp.Message.ToolCalls == nil {
|
|
|
+ t.Error("expected tool calls, got nil")
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedToolCall := api.ToolCall{
|
|
|
+ Function: api.ToolCallFunction{
|
|
|
+ Name: "get_weather",
|
|
|
+ Arguments: api.ToolCallFunctionArguments{
|
|
|
+ "location": "Seattle, WA",
|
|
|
+ "unit": "celsius",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
|
|
|
+ t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ t.Run("messages with tools (streaming)", func(t *testing.T) {
|
|
|
+ tools := []api.Tool{
|
|
|
+ {
|
|
|
+ Type: "function",
|
|
|
+ Function: api.ToolFunction{
|
|
|
+ Name: "get_weather",
|
|
|
+ Description: "Get the current weather",
|
|
|
+ Parameters: struct {
|
|
|
+ Type string `json:"type"`
|
|
|
+ Required []string `json:"required"`
|
|
|
+ Properties map[string]struct {
|
|
|
+ Type string `json:"type"`
|
|
|
+ Description string `json:"description"`
|
|
|
+ Enum []string `json:"enum,omitempty"`
|
|
|
+ } `json:"properties"`
|
|
|
+ }{
|
|
|
+ Type: "object",
|
|
|
+ Required: []string{"location"},
|
|
|
+ Properties: map[string]struct {
|
|
|
+ Type string `json:"type"`
|
|
|
+ Description string `json:"description"`
|
|
|
+ Enum []string `json:"enum,omitempty"`
|
|
|
+ }{
|
|
|
+ "location": {
|
|
|
+ Type: "string",
|
|
|
+ Description: "The city and state",
|
|
|
+ },
|
|
|
+ "unit": {
|
|
|
+ Type: "string",
|
|
|
+ Enum: []string{"celsius", "fahrenheit"},
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ // Simulate streaming response with multiple chunks
|
|
|
+ var wg sync.WaitGroup
|
|
|
+ wg.Add(1)
|
|
|
+
|
|
|
+ mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
|
|
+ defer wg.Done()
|
|
|
+
|
|
|
+ // Send chunks with small delays to simulate streaming
|
|
|
+ responses := []llm.CompletionResponse{
|
|
|
+ {
|
|
|
+ Content: `{"name":"get_`,
|
|
|
+ Done: false,
|
|
|
+ PromptEvalCount: 1,
|
|
|
+ PromptEvalDuration: 1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Content: `weather","arguments":{"location":"Seattle`,
|
|
|
+ Done: false,
|
|
|
+ PromptEvalCount: 2,
|
|
|
+ PromptEvalDuration: 1,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Content: `, WA","unit":"celsius"}}`,
|
|
|
+ Done: true,
|
|
|
+ DoneReason: "tool_call",
|
|
|
+ PromptEvalCount: 3,
|
|
|
+ PromptEvalDuration: 1,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, resp := range responses {
|
|
|
+ select {
|
|
|
+ case <-ctx.Done():
|
|
|
+ return ctx.Err()
|
|
|
+ default:
|
|
|
+ fn(resp)
|
|
|
+ time.Sleep(10 * time.Millisecond) // Small delay between chunks
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
|
+ Model: "test-system",
|
|
|
+ Messages: []api.Message{
|
|
|
+ {Role: "user", Content: "What's the weather in Seattle?"},
|
|
|
+ },
|
|
|
+ Tools: tools,
|
|
|
+ Stream: &stream,
|
|
|
+ })
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ if w.Code != http.StatusOK {
|
|
|
+ t.Errorf("expected status 200, got %d", w.Code)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Read and validate the streamed responses
|
|
|
+ decoder := json.NewDecoder(w.Body)
|
|
|
+ var finalToolCall api.ToolCall
|
|
|
+
|
|
|
+ for {
|
|
|
+ var resp api.ChatResponse
|
|
|
+ if err := decoder.Decode(&resp); err == io.EOF {
|
|
|
+ break
|
|
|
+ } else if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if resp.Done {
|
|
|
+ if len(resp.Message.ToolCalls) != 1 {
|
|
|
+ t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
|
|
+ }
|
|
|
+ finalToolCall = resp.Message.ToolCalls[0]
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ expectedToolCall := api.ToolCall{
|
|
|
+ Function: api.ToolCallFunction{
|
|
|
+ Name: "get_weather",
|
|
|
+ Arguments: api.ToolCallFunctionArguments{
|
|
|
+ "location": "Seattle, WA",
|
|
|
+ "unit": "celsius",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
|
|
+ t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
|
|
+ }
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
func TestGenerate(t *testing.T) {
|