|
@@ -22,7 +22,10 @@ const (
|
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
|
)
|
|
)
|
|
|
|
|
|
-var False = false
|
|
|
|
|
|
+var (
|
|
|
|
+ False = false
|
|
|
|
+ True = true
|
|
|
|
+)
|
|
|
|
|
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
return func(c *gin.Context) {
|
|
@@ -70,6 +73,44 @@ func TestChatMiddleware(t *testing.T) {
|
|
Stream: &False,
|
|
Stream: &False,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
|
|
+ {
|
|
|
|
+ name: "chat handler with options",
|
|
|
|
+ body: `{
|
|
|
|
+ "model": "test-model",
|
|
|
|
+ "messages": [
|
|
|
|
+ {"role": "user", "content": "Hello"}
|
|
|
|
+ ],
|
|
|
|
+ "stream": true,
|
|
|
|
+ "max_tokens": 999,
|
|
|
|
+ "seed": 123,
|
|
|
|
+ "stop": ["\n", "stop"],
|
|
|
|
+ "temperature": 3.0,
|
|
|
|
+ "frequency_penalty": 4.0,
|
|
|
|
+ "presence_penalty": 5.0,
|
|
|
|
+ "top_p": 6.0,
|
|
|
|
+ "response_format": {"type": "json_object"}
|
|
|
|
+ }`,
|
|
|
|
+ req: api.ChatRequest{
|
|
|
|
+ Model: "test-model",
|
|
|
|
+ Messages: []api.Message{
|
|
|
|
+ {
|
|
|
|
+ Role: "user",
|
|
|
|
+ Content: "Hello",
|
|
|
|
+ },
|
|
|
|
+ },
|
|
|
|
+ Options: map[string]any{
|
|
|
|
+ "num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
|
|
|
+ "seed": 123.0,
|
|
|
|
+ "stop": []any{"\n", "stop"},
|
|
|
|
+ "temperature": 6.0,
|
|
|
|
+ "frequency_penalty": 8.0,
|
|
|
|
+ "presence_penalty": 10.0,
|
|
|
|
+ "top_p": 6.0,
|
|
|
|
+ },
|
|
|
|
+ Format: "json",
|
|
|
|
+ Stream: &True,
|
|
|
|
+ },
|
|
|
|
+ },
|
|
{
|
|
{
|
|
name: "chat handler with image content",
|
|
name: "chat handler with image content",
|
|
body: `{
|
|
body: `{
|
|
@@ -186,6 +227,8 @@ func TestChatMiddleware(t *testing.T) {
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
|
|
|
+ defer func() { capturedRequest = nil }()
|
|
|
|
+
|
|
resp := httptest.NewRecorder()
|
|
resp := httptest.NewRecorder()
|
|
router.ServeHTTP(resp, req)
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
@@ -202,7 +245,6 @@ func TestChatMiddleware(t *testing.T) {
|
|
if !reflect.DeepEqual(tc.err, errResp) {
|
|
if !reflect.DeepEqual(tc.err, errResp) {
|
|
t.Fatal("errors did not match")
|
|
t.Fatal("errors did not match")
|
|
}
|
|
}
|
|
- capturedRequest = nil
|
|
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|