Procházet zdrojové kódy

separate request tests (#5578)

royjhan před 9 měsíci
rodič
revize
0aff67877e
1 změnil soubory, kde provedl 75 přidání a 113 odebrání
  1. 75 113
      openai/openai_test.go

+ 75 - 113
openai/openai_test.go

@@ -3,7 +3,6 @@ package openai
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
-	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
@@ -16,49 +15,33 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-func TestMiddleware(t *testing.T) {
+func TestMiddlewareRequests(t *testing.T) {
 	type testCase struct {
 	type testCase struct {
 		Name     string
 		Name     string
 		Method   string
 		Method   string
 		Path     string
 		Path     string
-		TestPath string
 		Handler  func() gin.HandlerFunc
 		Handler  func() gin.HandlerFunc
-		Endpoint func(c *gin.Context)
 		Setup    func(t *testing.T, req *http.Request)
 		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+		Expected func(t *testing.T, req *http.Request)
 	}
 	}
 
 
-	testCases := []testCase{
-		{
-			Name:     "chat handler",
-			Method:   http.MethodPost,
-			Path:     "/api/chat",
-			TestPath: "/api/chat",
-			Handler:  ChatMiddleware,
-			Endpoint: func(c *gin.Context) {
-				var chatReq api.ChatRequest
-				if err := c.ShouldBindJSON(&chatReq); err != nil {
-					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
-					return
-				}
-
-				userMessage := chatReq.Messages[0].Content
-				var assistantMessage string
+	var capturedRequest *http.Request
 
 
-				switch userMessage {
-				case "Hello":
-					assistantMessage = "Hello!"
-				default:
-					assistantMessage = "I'm not sure how to respond to that."
-				}
+	captureRequestMiddleware := func() gin.HandlerFunc {
+		return func(c *gin.Context) {
+			bodyBytes, _ := io.ReadAll(c.Request.Body)
+			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+			capturedRequest = c.Request
+			c.Next()
+		}
+	}
 
 
-				c.JSON(http.StatusOK, api.ChatResponse{
-					Message: api.Message{
-						Role:    "assistant",
-						Content: assistantMessage,
-					},
-				})
-			},
+	testCases := []testCase{
+		{
+			Name:    "chat handler",
+			Method:  http.MethodPost,
+			Path:    "/api/chat",
+			Handler: ChatMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 				body := ChatCompletionRequest{
 					Model:    "test-model",
 					Model:    "test-model",
@@ -70,38 +53,32 @@ func TestMiddleware(t *testing.T) {
 				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 				req.Header.Set("Content-Type", "application/json")
 				req.Header.Set("Content-Type", "application/json")
 			},
 			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				assert.Equal(t, http.StatusOK, resp.Code)
-
-				var chatResp ChatCompletion
-				if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+			Expected: func(t *testing.T, req *http.Request) {
+				var chatReq api.ChatRequest
+				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
 					t.Fatal(err)
 					t.Fatal(err)
 				}
 				}
 
 
-				if chatResp.Object != "chat.completion" {
-					t.Fatalf("expected chat.completion, got %s", chatResp.Object)
+				if chatReq.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
 				}
 				}
 
 
-				if chatResp.Choices[0].Message.Content != "Hello!" {
-					t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
+				if chatReq.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
 				}
 				}
 			},
 			},
 		},
 		},
 		{
 		{
-			Name:     "completions handler",
-			Method:   http.MethodPost,
-			Path:     "/api/generate",
-			TestPath: "/api/generate",
-			Handler:  CompletionsMiddleware,
-			Endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusOK, api.GenerateResponse{
-					Response: "Hello!",
-				})
-			},
+			Name:    "completions handler",
+			Method:  http.MethodPost,
+			Path:    "/api/generate",
+			Handler: CompletionsMiddleware,
 			Setup: func(t *testing.T, req *http.Request) {
 			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
 				body := CompletionRequest{
 				body := CompletionRequest{
-					Model:  "test-model",
-					Prompt: "Hello",
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: &temp,
 				}
 				}
 
 
 				bodyBytes, _ := json.Marshal(body)
 				bodyBytes, _ := json.Marshal(body)
@@ -109,80 +86,65 @@ func TestMiddleware(t *testing.T) {
 				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
 				req.Header.Set("Content-Type", "application/json")
 				req.Header.Set("Content-Type", "application/json")
 			},
 			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				assert.Equal(t, http.StatusOK, resp.Code)
-				var completionResp Completion
-				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
+			Expected: func(t *testing.T, req *http.Request) {
+				var genReq api.GenerateRequest
+				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
 					t.Fatal(err)
 					t.Fatal(err)
 				}
 				}
 
 
-				if completionResp.Object != "text_completion" {
-					t.Fatalf("expected text_completion, got %s", completionResp.Object)
+				if genReq.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
 				}
 				}
 
 
-				if completionResp.Choices[0].Text != "Hello!" {
-					t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
+				if genReq.Options["temperature"] != 1.6 {
+					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
 				}
 				}
 			},
 			},
 		},
 		},
-		{
-			Name:     "completions handler with params",
-			Method:   http.MethodPost,
-			Path:     "/api/generate",
-			TestPath: "/api/generate",
-			Handler:  CompletionsMiddleware,
-			Endpoint: func(c *gin.Context) {
-				var generateReq api.GenerateRequest
-				if err := c.ShouldBindJSON(&generateReq); err != nil {
-					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
-					return
-				}
+	}
 
 
-				temperature := generateReq.Options["temperature"].(float64)
-				var assistantMessage string
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
 
 
-				switch temperature {
-				case 1.6:
-					assistantMessage = "Received temperature of 1.6"
-				default:
-					assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
-				}
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
 
 
-				c.JSON(http.StatusOK, api.GenerateResponse{
-					Response: assistantMessage,
-				})
-			},
-			Setup: func(t *testing.T, req *http.Request) {
-				temp := float32(0.8)
-				body := CompletionRequest{
-					Model:       "test-model",
-					Prompt:      "Hello",
-					Temperature: &temp,
-				}
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			router = gin.New()
+			router.Use(captureRequestMiddleware())
+			router.Use(tc.Handler())
+			router.Handle(tc.Method, tc.Path, endpoint)
+			req, _ := http.NewRequest(tc.Method, tc.Path, nil)
 
 
-				bodyBytes, _ := json.Marshal(body)
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
 
 
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				assert.Equal(t, http.StatusOK, resp.Code)
-				var completionResp Completion
-				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
-					t.Fatal(err)
-				}
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
 
 
-				if completionResp.Object != "text_completion" {
-					t.Fatalf("expected text_completion, got %s", completionResp.Object)
-				}
+			tc.Expected(t, capturedRequest)
+		})
+	}
+}
 
 
-				if completionResp.Choices[0].Text != "Received temperature of 1.6" {
-					t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
-				}
-			},
-		},
+func TestMiddlewareResponses(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Method   string
+		Path     string
+		TestPath string
+		Handler  func() gin.HandlerFunc
+		Endpoint func(c *gin.Context)
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
+	}
+
+	testCases := []testCase{
 		{
 		{
-			Name:     "completions handler with error",
+			Name:     "completions handler error forwarding",
 			Method:   http.MethodPost,
 			Method:   http.MethodPost,
 			Path:     "/api/generate",
 			Path:     "/api/generate",
 			TestPath: "/api/generate",
 			TestPath: "/api/generate",