|
@@ -7,7 +7,6 @@ import (
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
- "reflect"
|
|
|
"strings"
|
|
|
"testing"
|
|
|
"time"
|
|
@@ -82,7 +81,7 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
{"role": "user", "content": "Hello"}
|
|
|
],
|
|
|
"stream": true,
|
|
|
- "max_tokens": 999,
|
|
|
+ "max_completion_tokens": 999,
|
|
|
"seed": 123,
|
|
|
"stop": ["\n", "stop"],
|
|
|
"temperature": 3.0,
|
|
@@ -315,6 +314,61 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
Stream: &True,
|
|
|
},
|
|
|
},
|
|
|
+ {
|
|
|
+ name: "chat handler with num_ctx",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [{"role": "user", "content": "Hello"}],
|
|
|
+ "num_ctx": 4096
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
|
|
+ Options: map[string]any{
|
|
|
+ "num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+ },
|
|
|
+ Stream: &False,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "chat handler with max_completion_tokens < num_ctx",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [{"role": "user", "content": "Hello"}],
|
|
|
+ "max_completion_tokens": 2
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
|
|
+ Options: map[string]any{
|
|
|
+ "num_predict": 2.0, // float because JSON doesn't distinguish between float and int
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+ },
|
|
|
+ Stream: &False,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "chat handler with max_completion_tokens > num_ctx",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [{"role": "user", "content": "Hello"}],
|
|
|
+ "max_completion_tokens": 4096
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
|
|
+ Options: map[string]any{
|
|
|
+ "num_predict": 4096.0, // float because JSON doesn't distinguish between float and int
|
|
|
+ "num_ctx": 4096.0,
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+ },
|
|
|
+ Stream: &False,
|
|
|
+ },
|
|
|
+ },
|
|
|
{
|
|
|
name: "chat handler error forwarding",
|
|
|
body: `{
|
|
@@ -359,7 +413,7 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
return
|
|
|
}
|
|
|
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
|
|
- t.Fatalf("requests did not match: %+v", diff)
|
|
|
+ t.Fatalf("requests did not match (-want +got):\n%s", diff)
|
|
|
}
|
|
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
|
|
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
|
@@ -493,12 +547,14 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
- t.Fatal("requests did not match")
|
|
|
+ if capturedRequest != nil {
|
|
|
+ if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
|
|
|
+ t.Fatalf("requests did not match (-want +got):\n%s", diff)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
- t.Fatal("errors did not match")
|
|
|
+ if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
|
|
+ t.Fatalf("errors did not match (-want +got):\n%s", diff)
|
|
|
}
|
|
|
|
|
|
capturedRequest = nil
|
|
@@ -577,12 +633,14 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
- t.Fatal("requests did not match")
|
|
|
+ if capturedRequest != nil {
|
|
|
+ if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
|
|
|
+ t.Fatalf("requests did not match (-want +got):\n%s", diff)
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
- t.Fatal("errors did not match")
|
|
|
+ if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
|
|
+ t.Fatalf("errors did not match (-want +got):\n%s", diff)
|
|
|
}
|
|
|
|
|
|
capturedRequest = nil
|
|
@@ -656,8 +714,8 @@ func TestListMiddleware(t *testing.T) {
|
|
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(expected, actual) {
|
|
|
- t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
|
+ if diff := cmp.Diff(expected, actual); diff != "" {
|
|
|
+ t.Errorf("responses did not match (-want +got):\n%s", diff)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -722,8 +780,8 @@ func TestRetrieveMiddleware(t *testing.T) {
|
|
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
|
}
|
|
|
|
|
|
- if !reflect.DeepEqual(expected, actual) {
|
|
|
- t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
|
+ if diff := cmp.Diff(expected, actual); diff != "" {
|
|
|
+ t.Errorf("responses did not match (-want +got):\n%s", diff)
|
|
|
}
|
|
|
}
|
|
|
}
|