Forráskód Böngészése

OpenAI: Simplify input output in testing (#5858)

* simplify input output

* direct comp

* in line image

* rm error pointer type

* update response testing

* lint
royjhan 8 hónapja
szülő
commit
01d544d373
1 módosított fájl, 335 hozzáadás és 315 törlés
  1. 335 315
      openai/openai_test.go

+ 335 - 315
openai/openai_test.go

@@ -7,27 +7,22 @@ import (
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"reflect"
 	"strings"
 	"testing"
 	"time"
 
 	"github.com/gin-gonic/gin"
-	"github.com/stretchr/testify/assert"
 
 	"github.com/ollama/ollama/api"
 )
 
 const (
-	prefix   = `data:image/jpeg;base64,`
-	image    = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
-	imageURL = prefix + image
+	prefix = `data:image/jpeg;base64,`
+	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
 )
 
-func prepareRequest(req *http.Request, body any) {
-	bodyBytes, _ := json.Marshal(body)
-	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-	req.Header.Set("Content-Type", "application/json")
-}
+var False = false
 
 func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
 	return func(c *gin.Context) {
@@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
 
 func TestChatMiddleware(t *testing.T) {
 	type testCase struct {
-		Name     string
-		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
+		name string
+		body string
+		req  api.ChatRequest
+		err  ErrorResponse
 	}
 
 	var capturedRequest *api.ChatRequest
 
 	testCases := []testCase{
 		{
-			Name: "chat handler",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := ChatCompletionRequest{
-					Model:    "test-model",
-					Messages: []Message{{Role: "user", Content: "Hello"}},
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusOK {
-					t.Fatalf("expected 200, got %d", resp.Code)
-				}
-
-				if req.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
-				}
-
-				if req.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
-				}
+			name: "chat handler",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "Hello"}
+				]
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+				},
+				Options: map[string]any{
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &False,
 			},
 		},
 		{
-			Name: "chat handler with image content",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := ChatCompletionRequest{
-					Model: "test-model",
-					Messages: []Message{
-						{
-							Role: "user", Content: []map[string]any{
-								{"type": "text", "text": "Hello"},
-								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
+			name: "chat handler with image content",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{
+						"role": "user",
+						"content": [
+							{
+								"type": "text",
+								"text": "Hello"
 							},
+							{
+								"type": "image_url",
+								"image_url": {
+									"url": "` + prefix + image + `"
+								}
+							}
+						]
+					}
+				]
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+					{
+						Role: "user",
+						Images: []api.ImageData{
+							func() []byte {
+								img, _ := base64.StdEncoding.DecodeString(image)
+								return img
+							}(),
 						},
 					},
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusOK {
-					t.Fatalf("expected 200, got %d", resp.Code)
-				}
-
-				if req.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
-				}
-
-				if req.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
-				}
-
-				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
-
-				if req.Messages[1].Role != "user" {
-					t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
-				}
-
-				if !bytes.Equal(req.Messages[1].Images[0], img) {
-					t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
-				}
+				},
+				Options: map[string]any{
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &False,
 			},
 		},
 		{
-			Name: "chat handler with tools",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := ChatCompletionRequest{
-					Model: "test-model",
-					Messages: []Message{
-						{Role: "user", Content: "What's the weather like in Paris Today?"},
-						{Role: "assistant", ToolCalls: []ToolCall{{
-							ID:   "id",
-							Type: "function",
-							Function: struct {
-								Name      string `json:"name"`
-								Arguments string `json:"arguments"`
-							}{
-								Name:      "get_current_weather",
-								Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
+			name: "chat handler with tools",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "What's the weather like in Paris Today?"},
+					{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
+				]
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "What's the weather like in Paris Today?",
+					},
+					{
+						Role: "assistant",
+						ToolCalls: []api.ToolCall{
+							{
+								Function: api.ToolCallFunction{
+									Name: "get_current_weather",
+									Arguments: map[string]interface{}{
+										"location": "Paris, France",
+										"format":   "celsius",
+									},
+								},
 							},
-						}}},
+						},
 					},
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
-				if resp.Code != 200 {
-					t.Fatalf("expected 200, got %d", resp.Code)
-				}
-
-				if req.Messages[0].Content != "What's the weather like in Paris Today?" {
-					t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
-				}
-
-				if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
-					t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
-				}
-
-				if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
-					t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
-				}
+				},
+				Options: map[string]any{
+					"temperature": 1.0,
+					"top_p":       1.0,
+				},
+				Stream: &False,
 			},
 		},
-		{
-			Name: "chat handler error forwarding",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := ChatCompletionRequest{
-					Model:    "test-model",
-					Messages: []Message{{Role: "user", Content: 2}},
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
 
-				if !strings.Contains(resp.Body.String(), "invalid message content type") {
-					t.Fatalf("error was not forwarded")
-				}
+		{
+			name: "chat handler error forwarding",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": 2}
+				]
+			}`,
+			err: ErrorResponse{
+				Error: Error{
+					Message: "invalid message content type: float64",
+					Type:    "invalid_request_error",
+				},
 			},
 		},
 	}
@@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
 	router.Handle(http.MethodPost, "/api/chat", endpoint)
 
 	for _, tc := range testCases {
-		t.Run(tc.Name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
-
-			tc.Setup(t, req)
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
 
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
-			tc.Expected(t, capturedRequest, resp)
+			var errResp ErrorResponse
+			if resp.Code != http.StatusOK {
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+			}
+			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+				t.Fatal("requests did not match")
+			}
 
+			if !reflect.DeepEqual(tc.err, errResp) {
+				t.Fatal("errors did not match")
+			}
 			capturedRequest = nil
 		})
 	}
@@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
 
 func TestCompletionsMiddleware(t *testing.T) {
 	type testCase struct {
-		Name     string
-		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
+		name string
+		body string
+		req  api.GenerateRequest
+		err  ErrorResponse
 	}
 
 	var capturedRequest *api.GenerateRequest
 
 	testCases := []testCase{
 		{
-			Name: "completions handler",
-			Setup: func(t *testing.T, req *http.Request) {
-				temp := float32(0.8)
-				body := CompletionRequest{
-					Model:       "test-model",
-					Prompt:      "Hello",
-					Temperature: &temp,
-					Stop:        []string{"\n", "stop"},
-					Suffix:      "suffix",
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
-				if req.Prompt != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", req.Prompt)
-				}
-
-				if req.Options["temperature"] != 1.6 {
-					t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
-				}
-
-				stopTokens, ok := req.Options["stop"].([]any)
-
-				if !ok {
-					t.Fatalf("expected stop tokens to be a list")
-				}
-
-				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
-					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
-				}
-
-				if req.Suffix != "suffix" {
-					t.Fatalf("expected 'suffix', got %s", req.Suffix)
-				}
+			name: "completions handler",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"temperature": 0.8,
+				"stop": ["\n", "stop"],
+				"suffix": "suffix"
+			}`,
+			req: api.GenerateRequest{
+				Model:  "test-model",
+				Prompt: "Hello",
+				Options: map[string]any{
+					"frequency_penalty": 0.0,
+					"presence_penalty":  0.0,
+					"temperature":       1.6,
+					"top_p":             1.0,
+					"stop":              []any{"\n", "stop"},
+				},
+				Suffix: "suffix",
+				Stream: &False,
 			},
 		},
 		{
-			Name: "completions handler error forwarding",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := CompletionRequest{
-					Model:       "test-model",
-					Prompt:      "Hello",
-					Temperature: nil,
-					Stop:        []int{1, 2},
-					Suffix:      "suffix",
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
-
-				if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
-					t.Fatalf("error was not forwarded")
-				}
+			name: "completions handler error forwarding",
+			body: `{
+				"model": "test-model",
+				"prompt": "Hello",
+				"temperature": null,
+				"stop": [1, 2],
+				"suffix": "suffix"
+			}`,
+			err: ErrorResponse{
+				Error: Error{
+					Message: "invalid type for 'stop' field: float64",
+					Type:    "invalid_request_error",
+				},
 			},
 		},
 	}
@@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
 	router.Handle(http.MethodPost, "/api/generate", endpoint)
 
 	for _, tc := range testCases {
-		t.Run(tc.Name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
-
-			tc.Setup(t, req)
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
 
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
-			tc.Expected(t, capturedRequest, resp)
+			var errResp ErrorResponse
+			if resp.Code != http.StatusOK {
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+				t.Fatal("requests did not match")
+			}
+
+			if !reflect.DeepEqual(tc.err, errResp) {
+				t.Fatal("errors did not match")
+			}
 
 			capturedRequest = nil
 		})
@@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
 
 func TestEmbeddingsMiddleware(t *testing.T) {
 	type testCase struct {
-		Name     string
-		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
+		name string
+		body string
+		req  api.EmbedRequest
+		err  ErrorResponse
 	}
 
 	var capturedRequest *api.EmbedRequest
 
 	testCases := []testCase{
 		{
-			Name: "embed handler single input",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := EmbedRequest{
-					Input: "Hello",
-					Model: "test-model",
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
-				if req.Input != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", req.Input)
-				}
-
-				if req.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", req.Model)
-				}
+			name: "embed handler single input",
+			body: `{
+				"input": "Hello",
+				"model": "test-model"
+			}`,
+			req: api.EmbedRequest{
+				Input: "Hello",
+				Model: "test-model",
 			},
 		},
 		{
-			Name: "embed handler batch input",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := EmbedRequest{
-					Input: []string{"Hello", "World"},
-					Model: "test-model",
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
-				input, ok := req.Input.([]any)
-
-				if !ok {
-					t.Fatalf("expected input to be a list")
-				}
-
-				if input[0].(string) != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", input[0])
-				}
-
-				if input[1].(string) != "World" {
-					t.Fatalf("expected 'World', got %s", input[1])
-				}
-
-				if req.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", req.Model)
-				}
+			name: "embed handler batch input",
+			body: `{
+				"input": ["Hello", "World"],
+				"model": "test-model"
+			}`,
+			req: api.EmbedRequest{
+				Input: []any{"Hello", "World"},
+				Model: "test-model",
 			},
 		},
 		{
-			Name: "embed handler error forwarding",
-			Setup: func(t *testing.T, req *http.Request) {
-				body := EmbedRequest{
-					Model: "test-model",
-				}
-				prepareRequest(req, body)
-			},
-			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
-
-				if !strings.Contains(resp.Body.String(), "invalid input") {
-					t.Fatalf("error was not forwarded")
-				}
+			name: "embed handler error forwarding",
+			body: `{
+				"model": "test-model"
+			}`,
+			err: ErrorResponse{
+				Error: Error{
+					Message: "invalid input",
+					Type:    "invalid_request_error",
+				},
 			},
 		},
 	}
@@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
 	router.Handle(http.MethodPost, "/api/embed", endpoint)
 
 	for _, tc := range testCases {
-		t.Run(tc.Name, func(t *testing.T) {
-			req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
-
-			tc.Setup(t, req)
+		t.Run(tc.name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
+			req.Header.Set("Content-Type", "application/json")
 
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
-			tc.Expected(t, capturedRequest, resp)
+			var errResp ErrorResponse
+			if resp.Code != http.StatusOK {
+				if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
+				t.Fatal("requests did not match")
+			}
+
+			if !reflect.DeepEqual(tc.err, errResp) {
+				t.Fatal("errors did not match")
+			}
 
 			capturedRequest = nil
 		})
 	}
 }
 
-func TestMiddlewareResponses(t *testing.T) {
+func TestListMiddleware(t *testing.T) {
 	type testCase struct {
-		Name     string
-		Method   string
-		Path     string
-		TestPath string
-		Handler  func() gin.HandlerFunc
-		Endpoint func(c *gin.Context)
-		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+		name     string
+		endpoint func(c *gin.Context)
+		resp     string
 	}
 
 	testCases := []testCase{
 		{
-			Name:     "list handler",
-			Method:   http.MethodGet,
-			Path:     "/api/tags",
-			TestPath: "/api/tags",
-			Handler:  ListMiddleware,
-			Endpoint: func(c *gin.Context) {
+			name: "list handler",
+			endpoint: func(c *gin.Context) {
 				c.JSON(http.StatusOK, api.ListResponse{
 					Models: []api.ListModelResponse{
 						{
-							Name: "Test Model",
+							Name:       "test-model",
+							ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
 						},
 					},
 				})
 			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				var listResp ListCompletion
-				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
-					t.Fatal(err)
-				}
+			resp: `{
+				"object": "list",
+				"data": [
+					{
+						"id": "test-model",
+						"object": "model",
+						"created": 1686935002,
+						"owned_by": "library"
+					}
+				]
+			}`,
+		},
+		{
+			name: "list handler empty output",
+			endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusOK, api.ListResponse{})
+			},
+			resp: `{
+				"object": "list",
+				"data": null
+			}`,
+		},
+	}
 
-				if listResp.Object != "list" {
-					t.Fatalf("expected list, got %s", listResp.Object)
-				}
+	gin.SetMode(gin.TestMode)
 
-				if len(listResp.Data) != 1 {
-					t.Fatalf("expected 1, got %d", len(listResp.Data))
-				}
+	for _, tc := range testCases {
+		router := gin.New()
+		router.Use(ListMiddleware())
+		router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
+		req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
 
-				if listResp.Data[0].Id != "Test Model" {
-					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
-				}
-			},
-		},
+		resp := httptest.NewRecorder()
+		router.ServeHTTP(resp, req)
+
+		var expected, actual map[string]any
+		err := json.Unmarshal([]byte(tc.resp), &expected)
+		if err != nil {
+			t.Fatalf("failed to unmarshal expected response: %v", err)
+		}
+
+		err = json.Unmarshal(resp.Body.Bytes(), &actual)
+		if err != nil {
+			t.Fatalf("failed to unmarshal actual response: %v", err)
+		}
+
+		if !reflect.DeepEqual(expected, actual) {
+			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		}
+	}
+}
+
+func TestRetrieveMiddleware(t *testing.T) {
+	type testCase struct {
+		name     string
+		endpoint func(c *gin.Context)
+		resp     string
+	}
+
+	testCases := []testCase{
 		{
-			Name:     "retrieve model",
-			Method:   http.MethodGet,
-			Path:     "/api/show/:model",
-			TestPath: "/api/show/test-model",
-			Handler:  RetrieveMiddleware,
-			Endpoint: func(c *gin.Context) {
+			name: "retrieve handler",
+			endpoint: func(c *gin.Context) {
 				c.JSON(http.StatusOK, api.ShowResponse{
-					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
+					ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
 				})
 			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				var retrieveResp Model
-				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
-					t.Fatal(err)
-				}
-
-				if retrieveResp.Object != "model" {
-					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
-				}
-
-				if retrieveResp.Id != "test-model" {
-					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
-				}
+			resp: `{
+				"id":"test-model",
+				"object":"model",
+				"created":1686935002,
+				"owned_by":"library"}
+			`,
+		},
+		{
+			name: "retrieve handler error forwarding",
+			endpoint: func(c *gin.Context) {
+				c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
 			},
+			resp: `{
+				"error": {
+				  "code": null,
+				  "message": "model not found",
+				  "param": null,
+				  "type": "api_error"
+				}
+			}`,
 		},
 	}
 
 	gin.SetMode(gin.TestMode)
-	router := gin.New()
 
 	for _, tc := range testCases {
-		t.Run(tc.Name, func(t *testing.T) {
-			router = gin.New()
-			router.Use(tc.Handler())
-			router.Handle(tc.Method, tc.Path, tc.Endpoint)
-			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
-
-			if tc.Setup != nil {
-				tc.Setup(t, req)
-			}
+		router := gin.New()
+		router.Use(RetrieveMiddleware())
+		router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
+		req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
 
-			resp := httptest.NewRecorder()
-			router.ServeHTTP(resp, req)
+		resp := httptest.NewRecorder()
+		router.ServeHTTP(resp, req)
 
-			assert.Equal(t, http.StatusOK, resp.Code)
+		var expected, actual map[string]any
+		err := json.Unmarshal([]byte(tc.resp), &expected)
+		if err != nil {
+			t.Fatalf("failed to unmarshal expected response: %v", err)
+		}
 
-			tc.Expected(t, resp)
-		})
+		err = json.Unmarshal(resp.Body.Bytes(), &actual)
+		if err != nil {
+			t.Fatalf("failed to unmarshal actual response: %v", err)
+		}
+
+		if !reflect.DeepEqual(expected, actual) {
+			t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+		}
 	}
 }