Browse Source

OpenAI: Function Based Testing (#5752)

* distinguish error forwarding

* more coverage

* rm comment
royjhan 9 months ago
parent
commit
c57317cbf0
2 changed files with 266 additions and 170 deletions
  1. 1 0
      openai/openai.go
  2. 265 170
      openai/openai_test.go

+ 1 - 0
openai/openai.go

@@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
 		chatReq, err := fromChatRequest(req)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
 		}
 
 		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {

+ 265 - 170
openai/openai_test.go

@@ -20,64 +20,195 @@ const prefix = `data:image/jpeg;base64,`
 const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
 const imageURL = prefix + image
 
-func TestMiddlewareRequests(t *testing.T) {
+func prepareRequest(req *http.Request, body any) {
+	bodyBytes, _ := json.Marshal(body)
+	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+	req.Header.Set("Content-Type", "application/json")
+}
+
+func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		bodyBytes, _ := io.ReadAll(c.Request.Body)
+		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+		err := json.Unmarshal(bodyBytes, capturedRequest)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
+		}
+		c.Next()
+	}
+}
+
+func TestChatMiddleware(t *testing.T) {
 	type testCase struct {
 		Name     string
-		Method   string
-		Path     string
-		Handler  func() gin.HandlerFunc
 		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
 	}
 
-	var capturedRequest *http.Request
-
-	captureRequestMiddleware := func() gin.HandlerFunc {
-		return func(c *gin.Context) {
-			bodyBytes, _ := io.ReadAll(c.Request.Body)
-			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-			capturedRequest = c.Request
-			c.Next()
-		}
-	}
+	var capturedRequest *api.ChatRequest
 
 	testCases := []testCase{
 		{
-			Name:    "chat handler",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
+			Name: "chat handler",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model:    "test-model",
 					Messages: []Message{{Role: "user", Content: "Hello"}},
 				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
 
-				bodyBytes, _ := json.Marshal(body)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
+				}
 
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
+				}
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+		},
+		{
+			Name: "chat handler with image content",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
+					Messages: []Message{
+						{
+							Role: "user", Content: []map[string]any{
+								{"type": "text", "text": "Hello"},
+								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
+							},
+						},
+					},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
+
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
+				}
+
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
+				}
+
+				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
+
+				if req.Messages[1].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
+				}
+
+				if !bytes.Equal(req.Messages[1].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
+				}
+			},
+		},
+		{
+			Name: "chat handler with tools",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
+					Messages: []Message{
+						{Role: "user", Content: "What's the weather like in Paris Today?"},
+						{Role: "assistant", ToolCalls: []ToolCall{{
+							ID:   "id",
+							Type: "function",
+							Function: struct {
+								Name      string `json:"name"`
+								Arguments string `json:"arguments"`
+							}{
+								Name:      "get_current_weather",
+								Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
+							},
+						}}},
+					},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != 200 {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
+
+				if req.Messages[0].Content != "What's the weather like in Paris Today?" {
+					t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
 				}
 
-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
+					t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
 				}
 
-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
+					t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
 				}
 			},
 		},
 		{
-			Name:    "completions handler",
-			Method:  http.MethodPost,
-			Path:    "/api/generate",
-			Handler: CompletionsMiddleware,
+			Name: "chat handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model:    "test-model",
+					Messages: []Message{{Role: "user", Content: 2}},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid message content type") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/chat", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
+
+			tc.Setup(t, req)
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestCompletionsMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	testCases := []testCase{
+		{
+			Name: "completions handler",
 			Setup: func(t *testing.T, req *http.Request) {
 				temp := float32(0.8)
 				body := CompletionRequest{
@@ -87,27 +218,18 @@ func TestMiddlewareRequests(t *testing.T) {
 					Stop:        []string{"\n", "stop"},
 					Suffix:      "suffix",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var genReq api.GenerateRequest
-				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
-					t.Fatal(err)
-				}
-
-				if genReq.Prompt != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if req.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Prompt)
 				}
 
-				if genReq.Options["temperature"] != 1.6 {
-					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
+				if req.Options["temperature"] != 1.6 {
+					t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
 				}
 
-				stopTokens, ok := genReq.Options["stop"].([]any)
+				stopTokens, ok := req.Options["stop"].([]any)
 
 				if !ok {
 					t.Fatalf("expected stop tokens to be a list")
@@ -117,113 +239,100 @@ func TestMiddlewareRequests(t *testing.T) {
 					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
 				}
 
-				if genReq.Suffix != "suffix" {
-					t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
+				if req.Suffix != "suffix" {
+					t.Fatalf("expected 'suffix', got %s", req.Suffix)
 				}
 			},
 		},
 		{
-			Name:    "chat handler with image content",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
+			Name: "completions handler error forwarding",
 			Setup: func(t *testing.T, req *http.Request) {
-				body := ChatCompletionRequest{
-					Model: "test-model",
-					Messages: []Message{
-						{
-							Role: "user", Content: []map[string]any{
-								{"type": "text", "text": "Hello"},
-								{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
-							},
-						},
-					},
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: nil,
+					Stop:        []int{1, 2},
+					Suffix:      "suffix",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
 				}
 
-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
+					t.Fatalf("error was not forwarded")
 				}
+			},
+		},
+	}
 
-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
-				}
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
 
-				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
 
-				if chatReq.Messages[1].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
-				}
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
 
-				if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
-					t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
-				}
-			},
-		},
+			tc.Setup(t, req)
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestEmbeddingsMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
+	}
+
+	var capturedRequest *api.EmbedRequest
+
+	testCases := []testCase{
 		{
-			Name:    "embed handler single input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
+			Name: "embed handler single input",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: "Hello",
 					Model: "test-model",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if req.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Input)
 				}
 
-				if embedReq.Input != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
-				}
-
-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
 				}
 			},
 		},
 		{
-			Name:    "embed handler batch input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
+			Name: "embed handler batch input",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: []string{"Hello", "World"},
 					Model: "test-model",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
-				}
-
-				input, ok := embedReq.Input.([]any)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				input, ok := req.Input.([]any)
 
 				if !ok {
 					t.Fatalf("expected input to be a list")
@@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
 					t.Fatalf("expected 'World', got %s", input[1])
 				}
 
-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
 				}
 			},
 		},
-	}
+		{
+			Name: "embed handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Model: "test-model",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
 
-	gin.SetMode(gin.TestMode)
-	router := gin.New()
+				if !strings.Contains(resp.Body.String(), "invalid input") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+	}
 
 	endpoint := func(c *gin.Context) {
 		c.Status(http.StatusOK)
 	}
 
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/embed", endpoint)
+
 	for _, tc := range testCases {
 		t.Run(tc.Name, func(t *testing.T) {
-			router = gin.New()
-			router.Use(captureRequestMiddleware())
-			router.Use(tc.Handler())
-			router.Handle(tc.Method, tc.Path, endpoint)
-			req, _ := http.NewRequest(tc.Method, tc.Path, nil)
+			req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
 
-			if tc.Setup != nil {
-				tc.Setup(t, req)
-			}
+			tc.Setup(t, req)
 
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
-			tc.Expected(t, capturedRequest)
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
 		})
 	}
 }
@@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
 	}
 
 	testCases := []testCase{
-		{
-			Name:     "completions handler error forwarding",
-			Method:   http.MethodPost,
-			Path:     "/api/generate",
-			TestPath: "/api/generate",
-			Handler:  CompletionsMiddleware,
-			Endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
-			},
-			Setup: func(t *testing.T, req *http.Request) {
-				body := CompletionRequest{
-					Model:  "test-model",
-					Prompt: "Hello",
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
-
-				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
-					t.Fatalf("error was not forwarded")
-				}
-			},
-		},
 		{
 			Name:     "list handler",
 			Method:   http.MethodGet,
@@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
 				})
 			},
 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				assert.Equal(t, http.StatusOK, resp.Code)
-
 				var listResp ListCompletion
 				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
 					t.Fatal(err)
@@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
+			assert.Equal(t, http.StatusOK, resp.Code)
+
 			tc.Expected(t, resp)
 		})
 	}