|
@@ -7,27 +7,22 @@ import (
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
+ "reflect"
|
|
|
"strings"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
- "github.com/stretchr/testify/assert"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
- prefix = `data:image/jpeg;base64,`
|
|
|
- image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
|
|
- imageURL = prefix + image
|
|
|
+ prefix = `data:image/jpeg;base64,`
|
|
|
+ image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
|
|
)
|
|
|
|
|
|
-func prepareRequest(req *http.Request, body any) {
|
|
|
- bodyBytes, _ := json.Marshal(body)
|
|
|
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
-}
|
|
|
+var False = false
|
|
|
|
|
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
|
|
return func(c *gin.Context) {
|
|
@@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
|
|
|
|
|
func TestChatMiddleware(t *testing.T) {
|
|
|
type testCase struct {
|
|
|
- Name string
|
|
|
- Setup func(t *testing.T, req *http.Request)
|
|
|
- Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
|
|
+ name string
|
|
|
+ body string
|
|
|
+ req api.ChatRequest
|
|
|
+ err ErrorResponse
|
|
|
}
|
|
|
|
|
|
var capturedRequest *api.ChatRequest
|
|
|
|
|
|
testCases := []testCase{
|
|
|
{
|
|
|
- Name: "chat handler",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := ChatCompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Messages: []Message{{Role: "user", Content: "Hello"}},
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if resp.Code != http.StatusOK {
|
|
|
- t.Fatalf("expected 200, got %d", resp.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[0].Role != "user" {
|
|
|
- t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[0].Content != "Hello" {
|
|
|
- t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
|
|
- }
|
|
|
+ name: "chat handler",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [
|
|
|
+ {"role": "user", "content": "Hello"}
|
|
|
+ ]
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: "Hello",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ Options: map[string]any{
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+ },
|
|
|
+ Stream: &False,
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- Name: "chat handler with image content",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := ChatCompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Messages: []Message{
|
|
|
- {
|
|
|
- Role: "user", Content: []map[string]any{
|
|
|
- {"type": "text", "text": "Hello"},
|
|
|
- {"type": "image_url", "image_url": map[string]string{"url": imageURL}},
|
|
|
+ name: "chat handler with image content",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": [
|
|
|
+ {
|
|
|
+ "type": "text",
|
|
|
+ "text": "Hello"
|
|
|
},
|
|
|
+ {
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {
|
|
|
+ "url": "` + prefix + image + `"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: "Hello",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Images: []api.ImageData{
|
|
|
+ func() []byte {
|
|
|
+ img, _ := base64.StdEncoding.DecodeString(image)
|
|
|
+ return img
|
|
|
+ }(),
|
|
|
},
|
|
|
},
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if resp.Code != http.StatusOK {
|
|
|
- t.Fatalf("expected 200, got %d", resp.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[0].Role != "user" {
|
|
|
- t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[0].Content != "Hello" {
|
|
|
- t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
|
|
- }
|
|
|
-
|
|
|
- img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
|
|
-
|
|
|
- if req.Messages[1].Role != "user" {
|
|
|
- t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
|
|
- }
|
|
|
-
|
|
|
- if !bytes.Equal(req.Messages[1].Images[0], img) {
|
|
|
- t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
|
|
- }
|
|
|
+ },
|
|
|
+ Options: map[string]any{
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+ },
|
|
|
+ Stream: &False,
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- Name: "chat handler with tools",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := ChatCompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Messages: []Message{
|
|
|
- {Role: "user", Content: "What's the weather like in Paris Today?"},
|
|
|
- {Role: "assistant", ToolCalls: []ToolCall{{
|
|
|
- ID: "id",
|
|
|
- Type: "function",
|
|
|
- Function: struct {
|
|
|
- Name string `json:"name"`
|
|
|
- Arguments string `json:"arguments"`
|
|
|
- }{
|
|
|
- Name: "get_current_weather",
|
|
|
- Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
|
|
+ name: "chat handler with tools",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [
|
|
|
+ {"role": "user", "content": "What's the weather like in Paris Today?"},
|
|
|
+ {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
|
|
+ ]
|
|
|
+ }`,
|
|
|
+ req: api.ChatRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Messages: []api.Message{
|
|
|
+ {
|
|
|
+ Role: "user",
|
|
|
+ Content: "What's the weather like in Paris Today?",
|
|
|
+ },
|
|
|
+ {
|
|
|
+ Role: "assistant",
|
|
|
+ ToolCalls: []api.ToolCall{
|
|
|
+ {
|
|
|
+ Function: api.ToolCallFunction{
|
|
|
+ Name: "get_current_weather",
|
|
|
+ Arguments: map[string]interface{}{
|
|
|
+ "location": "Paris, France",
|
|
|
+ "format": "celsius",
|
|
|
+ },
|
|
|
+ },
|
|
|
},
|
|
|
- }}},
|
|
|
+ },
|
|
|
},
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if resp.Code != 200 {
|
|
|
- t.Fatalf("expected 200, got %d", resp.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
|
|
- t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
|
|
- t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
|
|
- }
|
|
|
-
|
|
|
- if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
|
|
- t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
|
|
- }
|
|
|
+ },
|
|
|
+ Options: map[string]any{
|
|
|
+ "temperature": 1.0,
|
|
|
+ "top_p": 1.0,
|
|
|
+ },
|
|
|
+ Stream: &False,
|
|
|
},
|
|
|
},
|
|
|
- {
|
|
|
- Name: "chat handler error forwarding",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := ChatCompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Messages: []Message{{Role: "user", Content: 2}},
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if resp.Code != http.StatusBadRequest {
|
|
|
- t.Fatalf("expected 400, got %d", resp.Code)
|
|
|
- }
|
|
|
|
|
|
- if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
|
|
- t.Fatalf("error was not forwarded")
|
|
|
- }
|
|
|
+ {
|
|
|
+ name: "chat handler error forwarding",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "messages": [
|
|
|
+ {"role": "user", "content": 2}
|
|
|
+ ]
|
|
|
+ }`,
|
|
|
+ err: ErrorResponse{
|
|
|
+ Error: Error{
|
|
|
+ Message: "invalid message content type: float64",
|
|
|
+ Type: "invalid_request_error",
|
|
|
+ },
|
|
|
},
|
|
|
},
|
|
|
}
|
|
@@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
|
|
|
|
|
for _, tc := range testCases {
|
|
|
- t.Run(tc.Name, func(t *testing.T) {
|
|
|
- req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
|
|
-
|
|
|
- tc.Setup(t, req)
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
- tc.Expected(t, capturedRequest, resp)
|
|
|
+ var errResp ErrorResponse
|
|
|
+ if resp.Code != http.StatusOK {
|
|
|
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
+ t.Fatal("requests did not match")
|
|
|
+ }
|
|
|
|
|
|
+ if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
+ t.Fatal("errors did not match")
|
|
|
+ }
|
|
|
capturedRequest = nil
|
|
|
})
|
|
|
}
|
|
@@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
|
|
|
|
|
|
func TestCompletionsMiddleware(t *testing.T) {
|
|
|
type testCase struct {
|
|
|
- Name string
|
|
|
- Setup func(t *testing.T, req *http.Request)
|
|
|
- Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
|
|
+ name string
|
|
|
+ body string
|
|
|
+ req api.GenerateRequest
|
|
|
+ err ErrorResponse
|
|
|
}
|
|
|
|
|
|
var capturedRequest *api.GenerateRequest
|
|
|
|
|
|
testCases := []testCase{
|
|
|
{
|
|
|
- Name: "completions handler",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- temp := float32(0.8)
|
|
|
- body := CompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Prompt: "Hello",
|
|
|
- Temperature: &temp,
|
|
|
- Stop: []string{"\n", "stop"},
|
|
|
- Suffix: "suffix",
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if req.Prompt != "Hello" {
|
|
|
- t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Options["temperature"] != 1.6 {
|
|
|
- t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
|
|
- }
|
|
|
-
|
|
|
- stopTokens, ok := req.Options["stop"].([]any)
|
|
|
-
|
|
|
- if !ok {
|
|
|
- t.Fatalf("expected stop tokens to be a list")
|
|
|
- }
|
|
|
-
|
|
|
- if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
|
|
- t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Suffix != "suffix" {
|
|
|
- t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
|
|
- }
|
|
|
+ name: "completions handler",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "prompt": "Hello",
|
|
|
+ "temperature": 0.8,
|
|
|
+ "stop": ["\n", "stop"],
|
|
|
+ "suffix": "suffix"
|
|
|
+ }`,
|
|
|
+ req: api.GenerateRequest{
|
|
|
+ Model: "test-model",
|
|
|
+ Prompt: "Hello",
|
|
|
+ Options: map[string]any{
|
|
|
+ "frequency_penalty": 0.0,
|
|
|
+ "presence_penalty": 0.0,
|
|
|
+ "temperature": 1.6,
|
|
|
+ "top_p": 1.0,
|
|
|
+ "stop": []any{"\n", "stop"},
|
|
|
+ },
|
|
|
+ Suffix: "suffix",
|
|
|
+ Stream: &False,
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- Name: "completions handler error forwarding",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := CompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Prompt: "Hello",
|
|
|
- Temperature: nil,
|
|
|
- Stop: []int{1, 2},
|
|
|
- Suffix: "suffix",
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if resp.Code != http.StatusBadRequest {
|
|
|
- t.Fatalf("expected 400, got %d", resp.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
|
|
- t.Fatalf("error was not forwarded")
|
|
|
- }
|
|
|
+ name: "completions handler error forwarding",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model",
|
|
|
+ "prompt": "Hello",
|
|
|
+ "temperature": null,
|
|
|
+ "stop": [1, 2],
|
|
|
+ "suffix": "suffix"
|
|
|
+ }`,
|
|
|
+ err: ErrorResponse{
|
|
|
+ Error: Error{
|
|
|
+ Message: "invalid type for 'stop' field: float64",
|
|
|
+ Type: "invalid_request_error",
|
|
|
+ },
|
|
|
},
|
|
|
},
|
|
|
}
|
|
@@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
|
|
|
|
|
for _, tc := range testCases {
|
|
|
- t.Run(tc.Name, func(t *testing.T) {
|
|
|
- req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
|
|
-
|
|
|
- tc.Setup(t, req)
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
- tc.Expected(t, capturedRequest, resp)
|
|
|
+ var errResp ErrorResponse
|
|
|
+ if resp.Code != http.StatusOK {
|
|
|
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
+ t.Fatal("requests did not match")
|
|
|
+ }
|
|
|
+
|
|
|
+ if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
+ t.Fatal("errors did not match")
|
|
|
+ }
|
|
|
|
|
|
capturedRequest = nil
|
|
|
})
|
|
@@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
|
|
|
|
|
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
|
|
type testCase struct {
|
|
|
- Name string
|
|
|
- Setup func(t *testing.T, req *http.Request)
|
|
|
- Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
|
|
+ name string
|
|
|
+ body string
|
|
|
+ req api.EmbedRequest
|
|
|
+ err ErrorResponse
|
|
|
}
|
|
|
|
|
|
var capturedRequest *api.EmbedRequest
|
|
|
|
|
|
testCases := []testCase{
|
|
|
{
|
|
|
- Name: "embed handler single input",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := EmbedRequest{
|
|
|
- Input: "Hello",
|
|
|
- Model: "test-model",
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if req.Input != "Hello" {
|
|
|
- t.Fatalf("expected 'Hello', got %s", req.Input)
|
|
|
- }
|
|
|
-
|
|
|
- if req.Model != "test-model" {
|
|
|
- t.Fatalf("expected 'test-model', got %s", req.Model)
|
|
|
- }
|
|
|
+ name: "embed handler single input",
|
|
|
+ body: `{
|
|
|
+ "input": "Hello",
|
|
|
+ "model": "test-model"
|
|
|
+ }`,
|
|
|
+ req: api.EmbedRequest{
|
|
|
+ Input: "Hello",
|
|
|
+ Model: "test-model",
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- Name: "embed handler batch input",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := EmbedRequest{
|
|
|
- Input: []string{"Hello", "World"},
|
|
|
- Model: "test-model",
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
|
|
- input, ok := req.Input.([]any)
|
|
|
-
|
|
|
- if !ok {
|
|
|
- t.Fatalf("expected input to be a list")
|
|
|
- }
|
|
|
-
|
|
|
- if input[0].(string) != "Hello" {
|
|
|
- t.Fatalf("expected 'Hello', got %s", input[0])
|
|
|
- }
|
|
|
-
|
|
|
- if input[1].(string) != "World" {
|
|
|
- t.Fatalf("expected 'World', got %s", input[1])
|
|
|
- }
|
|
|
-
|
|
|
- if req.Model != "test-model" {
|
|
|
- t.Fatalf("expected 'test-model', got %s", req.Model)
|
|
|
- }
|
|
|
+ name: "embed handler batch input",
|
|
|
+ body: `{
|
|
|
+ "input": ["Hello", "World"],
|
|
|
+ "model": "test-model"
|
|
|
+ }`,
|
|
|
+ req: api.EmbedRequest{
|
|
|
+ Input: []any{"Hello", "World"},
|
|
|
+ Model: "test-model",
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- Name: "embed handler error forwarding",
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- body := EmbedRequest{
|
|
|
- Model: "test-model",
|
|
|
- }
|
|
|
- prepareRequest(req, body)
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
|
|
- if resp.Code != http.StatusBadRequest {
|
|
|
- t.Fatalf("expected 400, got %d", resp.Code)
|
|
|
- }
|
|
|
-
|
|
|
- if !strings.Contains(resp.Body.String(), "invalid input") {
|
|
|
- t.Fatalf("error was not forwarded")
|
|
|
- }
|
|
|
+ name: "embed handler error forwarding",
|
|
|
+ body: `{
|
|
|
+ "model": "test-model"
|
|
|
+ }`,
|
|
|
+ err: ErrorResponse{
|
|
|
+ Error: Error{
|
|
|
+ Message: "invalid input",
|
|
|
+ Type: "invalid_request_error",
|
|
|
+ },
|
|
|
},
|
|
|
},
|
|
|
}
|
|
@@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
|
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
|
|
|
|
|
for _, tc := range testCases {
|
|
|
- t.Run(tc.Name, func(t *testing.T) {
|
|
|
- req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
|
|
-
|
|
|
- tc.Setup(t, req)
|
|
|
+ t.Run(tc.name, func(t *testing.T) {
|
|
|
+ req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
- tc.Expected(t, capturedRequest, resp)
|
|
|
+ var errResp ErrorResponse
|
|
|
+ if resp.Code != http.StatusOK {
|
|
|
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
|
|
+ t.Fatal("requests did not match")
|
|
|
+ }
|
|
|
+
|
|
|
+ if !reflect.DeepEqual(tc.err, errResp) {
|
|
|
+ t.Fatal("errors did not match")
|
|
|
+ }
|
|
|
|
|
|
capturedRequest = nil
|
|
|
})
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestMiddlewareResponses(t *testing.T) {
|
|
|
+func TestListMiddleware(t *testing.T) {
|
|
|
type testCase struct {
|
|
|
- Name string
|
|
|
- Method string
|
|
|
- Path string
|
|
|
- TestPath string
|
|
|
- Handler func() gin.HandlerFunc
|
|
|
- Endpoint func(c *gin.Context)
|
|
|
- Setup func(t *testing.T, req *http.Request)
|
|
|
- Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
|
|
+ name string
|
|
|
+ endpoint func(c *gin.Context)
|
|
|
+ resp string
|
|
|
}
|
|
|
|
|
|
testCases := []testCase{
|
|
|
{
|
|
|
- Name: "list handler",
|
|
|
- Method: http.MethodGet,
|
|
|
- Path: "/api/tags",
|
|
|
- TestPath: "/api/tags",
|
|
|
- Handler: ListMiddleware,
|
|
|
- Endpoint: func(c *gin.Context) {
|
|
|
+ name: "list handler",
|
|
|
+ endpoint: func(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.ListResponse{
|
|
|
Models: []api.ListModelResponse{
|
|
|
{
|
|
|
- Name: "Test Model",
|
|
|
+ Name: "test-model",
|
|
|
+ ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
|
|
},
|
|
|
},
|
|
|
})
|
|
|
},
|
|
|
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
- var listResp ListCompletion
|
|
|
- if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
|
|
- t.Fatal(err)
|
|
|
- }
|
|
|
+ resp: `{
|
|
|
+ "object": "list",
|
|
|
+ "data": [
|
|
|
+ {
|
|
|
+ "id": "test-model",
|
|
|
+ "object": "model",
|
|
|
+ "created": 1686935002,
|
|
|
+ "owned_by": "library"
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ }`,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "list handler empty output",
|
|
|
+ endpoint: func(c *gin.Context) {
|
|
|
+ c.JSON(http.StatusOK, api.ListResponse{})
|
|
|
+ },
|
|
|
+ resp: `{
|
|
|
+ "object": "list",
|
|
|
+ "data": null
|
|
|
+ }`,
|
|
|
+ },
|
|
|
+ }
|
|
|
|
|
|
- if listResp.Object != "list" {
|
|
|
- t.Fatalf("expected list, got %s", listResp.Object)
|
|
|
- }
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
|
|
|
- if len(listResp.Data) != 1 {
|
|
|
- t.Fatalf("expected 1, got %d", len(listResp.Data))
|
|
|
- }
|
|
|
+ for _, tc := range testCases {
|
|
|
+ router := gin.New()
|
|
|
+ router.Use(ListMiddleware())
|
|
|
+ router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
|
|
+ req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
|
|
|
|
|
- if listResp.Data[0].Id != "Test Model" {
|
|
|
- t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
|
|
|
- }
|
|
|
- },
|
|
|
- },
|
|
|
+ resp := httptest.NewRecorder()
|
|
|
+ router.ServeHTTP(resp, req)
|
|
|
+
|
|
|
+ var expected, actual map[string]any
|
|
|
+ err := json.Unmarshal([]byte(tc.resp), &expected)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to unmarshal expected response: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !reflect.DeepEqual(expected, actual) {
|
|
|
+ t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestRetrieveMiddleware(t *testing.T) {
|
|
|
+ type testCase struct {
|
|
|
+ name string
|
|
|
+ endpoint func(c *gin.Context)
|
|
|
+ resp string
|
|
|
+ }
|
|
|
+
|
|
|
+ testCases := []testCase{
|
|
|
{
|
|
|
- Name: "retrieve model",
|
|
|
- Method: http.MethodGet,
|
|
|
- Path: "/api/show/:model",
|
|
|
- TestPath: "/api/show/test-model",
|
|
|
- Handler: RetrieveMiddleware,
|
|
|
- Endpoint: func(c *gin.Context) {
|
|
|
+ name: "retrieve handler",
|
|
|
+ endpoint: func(c *gin.Context) {
|
|
|
c.JSON(http.StatusOK, api.ShowResponse{
|
|
|
- ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
|
|
|
+ ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
|
|
})
|
|
|
},
|
|
|
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
- var retrieveResp Model
|
|
|
- if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
|
|
|
- t.Fatal(err)
|
|
|
- }
|
|
|
-
|
|
|
- if retrieveResp.Object != "model" {
|
|
|
- t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
|
|
|
- }
|
|
|
-
|
|
|
- if retrieveResp.Id != "test-model" {
|
|
|
- t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
|
|
|
- }
|
|
|
+ resp: `{
|
|
|
+ "id":"test-model",
|
|
|
+ "object":"model",
|
|
|
+ "created":1686935002,
|
|
|
+ "owned_by":"library"}
|
|
|
+ `,
|
|
|
+ },
|
|
|
+ {
|
|
|
+ name: "retrieve handler error forwarding",
|
|
|
+ endpoint: func(c *gin.Context) {
|
|
|
+ c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
|
|
},
|
|
|
+ resp: `{
|
|
|
+ "error": {
|
|
|
+ "code": null,
|
|
|
+ "message": "model not found",
|
|
|
+ "param": null,
|
|
|
+ "type": "api_error"
|
|
|
+ }
|
|
|
+ }`,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
- router := gin.New()
|
|
|
|
|
|
for _, tc := range testCases {
|
|
|
- t.Run(tc.Name, func(t *testing.T) {
|
|
|
- router = gin.New()
|
|
|
- router.Use(tc.Handler())
|
|
|
- router.Handle(tc.Method, tc.Path, tc.Endpoint)
|
|
|
- req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
|
|
|
-
|
|
|
- if tc.Setup != nil {
|
|
|
- tc.Setup(t, req)
|
|
|
- }
|
|
|
+ router := gin.New()
|
|
|
+ router.Use(RetrieveMiddleware())
|
|
|
+ router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
|
|
+ req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
|
|
|
|
|
- resp := httptest.NewRecorder()
|
|
|
- router.ServeHTTP(resp, req)
|
|
|
+ resp := httptest.NewRecorder()
|
|
|
+ router.ServeHTTP(resp, req)
|
|
|
|
|
|
- assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
+ var expected, actual map[string]any
|
|
|
+ err := json.Unmarshal([]byte(tc.resp), &expected)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to unmarshal expected response: %v", err)
|
|
|
+ }
|
|
|
|
|
|
- tc.Expected(t, resp)
|
|
|
- })
|
|
|
+ err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("failed to unmarshal actual response: %v", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if !reflect.DeepEqual(expected, actual) {
|
|
|
+ t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
|
|
+ }
|
|
|
}
|
|
|
}
|