|
@@ -3,7 +3,6 @@ package openai
|
|
|
import (
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
- "fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
@@ -16,49 +15,33 @@ import (
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
)
|
|
|
|
|
|
-func TestMiddleware(t *testing.T) {
|
|
|
+func TestMiddlewareRequests(t *testing.T) {
|
|
|
type testCase struct {
|
|
|
Name string
|
|
|
Method string
|
|
|
Path string
|
|
|
- TestPath string
|
|
|
Handler func() gin.HandlerFunc
|
|
|
- Endpoint func(c *gin.Context)
|
|
|
Setup func(t *testing.T, req *http.Request)
|
|
|
- Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
|
|
+ Expected func(t *testing.T, req *http.Request)
|
|
|
}
|
|
|
|
|
|
- testCases := []testCase{
|
|
|
- {
|
|
|
- Name: "chat handler",
|
|
|
- Method: http.MethodPost,
|
|
|
- Path: "/api/chat",
|
|
|
- TestPath: "/api/chat",
|
|
|
- Handler: ChatMiddleware,
|
|
|
- Endpoint: func(c *gin.Context) {
|
|
|
- var chatReq api.ChatRequest
|
|
|
- if err := c.ShouldBindJSON(&chatReq); err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- userMessage := chatReq.Messages[0].Content
|
|
|
- var assistantMessage string
|
|
|
+ var capturedRequest *http.Request
|
|
|
|
|
|
- switch userMessage {
|
|
|
- case "Hello":
|
|
|
- assistantMessage = "Hello!"
|
|
|
- default:
|
|
|
- assistantMessage = "I'm not sure how to respond to that."
|
|
|
- }
|
|
|
+ captureRequestMiddleware := func() gin.HandlerFunc {
|
|
|
+ return func(c *gin.Context) {
|
|
|
+ bodyBytes, _ := io.ReadAll(c.Request.Body)
|
|
|
+ c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
+ capturedRequest = c.Request
|
|
|
+ c.Next()
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- c.JSON(http.StatusOK, api.ChatResponse{
|
|
|
- Message: api.Message{
|
|
|
- Role: "assistant",
|
|
|
- Content: assistantMessage,
|
|
|
- },
|
|
|
- })
|
|
|
- },
|
|
|
+ testCases := []testCase{
|
|
|
+ {
|
|
|
+ Name: "chat handler",
|
|
|
+ Method: http.MethodPost,
|
|
|
+ Path: "/api/chat",
|
|
|
+ Handler: ChatMiddleware,
|
|
|
Setup: func(t *testing.T, req *http.Request) {
|
|
|
body := ChatCompletionRequest{
|
|
|
Model: "test-model",
|
|
@@ -70,38 +53,32 @@ func TestMiddleware(t *testing.T) {
|
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
},
|
|
|
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
- assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
-
|
|
|
- var chatResp ChatCompletion
|
|
|
- if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
|
+ Expected: func(t *testing.T, req *http.Request) {
|
|
|
+ var chatReq api.ChatRequest
|
|
|
+ if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- if chatResp.Object != "chat.completion" {
|
|
|
- t.Fatalf("expected chat.completion, got %s", chatResp.Object)
|
|
|
+ if chatReq.Messages[0].Role != "user" {
|
|
|
+ t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
|
|
}
|
|
|
|
|
|
- if chatResp.Choices[0].Message.Content != "Hello!" {
|
|
|
- t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
|
|
|
+ if chatReq.Messages[0].Content != "Hello" {
|
|
|
+ t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
|
|
}
|
|
|
},
|
|
|
},
|
|
|
{
|
|
|
- Name: "completions handler",
|
|
|
- Method: http.MethodPost,
|
|
|
- Path: "/api/generate",
|
|
|
- TestPath: "/api/generate",
|
|
|
- Handler: CompletionsMiddleware,
|
|
|
- Endpoint: func(c *gin.Context) {
|
|
|
- c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
- Response: "Hello!",
|
|
|
- })
|
|
|
- },
|
|
|
+ Name: "completions handler",
|
|
|
+ Method: http.MethodPost,
|
|
|
+ Path: "/api/generate",
|
|
|
+ Handler: CompletionsMiddleware,
|
|
|
Setup: func(t *testing.T, req *http.Request) {
|
|
|
+ temp := float32(0.8)
|
|
|
body := CompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Prompt: "Hello",
|
|
|
+ Model: "test-model",
|
|
|
+ Prompt: "Hello",
|
|
|
+ Temperature: &temp,
|
|
|
}
|
|
|
|
|
|
bodyBytes, _ := json.Marshal(body)
|
|
@@ -109,80 +86,65 @@ func TestMiddleware(t *testing.T) {
|
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
},
|
|
|
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
- assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
- var completionResp Completion
|
|
|
- if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
|
|
+ Expected: func(t *testing.T, req *http.Request) {
|
|
|
+ var genReq api.GenerateRequest
|
|
|
+ if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
|
|
|
- if completionResp.Object != "text_completion" {
|
|
|
- t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
|
|
+ if genReq.Prompt != "Hello" {
|
|
|
+ t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
|
|
}
|
|
|
|
|
|
- if completionResp.Choices[0].Text != "Hello!" {
|
|
|
- t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
|
|
|
+ if genReq.Options["temperature"] != 1.6 {
|
|
|
+ t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
|
|
}
|
|
|
},
|
|
|
},
|
|
|
- {
|
|
|
- Name: "completions handler with params",
|
|
|
- Method: http.MethodPost,
|
|
|
- Path: "/api/generate",
|
|
|
- TestPath: "/api/generate",
|
|
|
- Handler: CompletionsMiddleware,
|
|
|
- Endpoint: func(c *gin.Context) {
|
|
|
- var generateReq api.GenerateRequest
|
|
|
- if err := c.ShouldBindJSON(&generateReq); err != nil {
|
|
|
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
|
- return
|
|
|
- }
|
|
|
+ }
|
|
|
|
|
|
- temperature := generateReq.Options["temperature"].(float64)
|
|
|
- var assistantMessage string
|
|
|
+ gin.SetMode(gin.TestMode)
|
|
|
+ router := gin.New()
|
|
|
|
|
|
- switch temperature {
|
|
|
- case 1.6:
|
|
|
- assistantMessage = "Received temperature of 1.6"
|
|
|
- default:
|
|
|
- assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
|
|
|
- }
|
|
|
+ endpoint := func(c *gin.Context) {
|
|
|
+ c.Status(http.StatusOK)
|
|
|
+ }
|
|
|
|
|
|
- c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
- Response: assistantMessage,
|
|
|
- })
|
|
|
- },
|
|
|
- Setup: func(t *testing.T, req *http.Request) {
|
|
|
- temp := float32(0.8)
|
|
|
- body := CompletionRequest{
|
|
|
- Model: "test-model",
|
|
|
- Prompt: "Hello",
|
|
|
- Temperature: &temp,
|
|
|
- }
|
|
|
+ for _, tc := range testCases {
|
|
|
+ t.Run(tc.Name, func(t *testing.T) {
|
|
|
+ router = gin.New()
|
|
|
+ router.Use(captureRequestMiddleware())
|
|
|
+ router.Use(tc.Handler())
|
|
|
+ router.Handle(tc.Method, tc.Path, endpoint)
|
|
|
+ req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
|
|
|
|
|
- bodyBytes, _ := json.Marshal(body)
|
|
|
+ if tc.Setup != nil {
|
|
|
+ tc.Setup(t, req)
|
|
|
+ }
|
|
|
|
|
|
- req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
- },
|
|
|
- Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
- assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
- var completionResp Completion
|
|
|
- if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
|
|
- t.Fatal(err)
|
|
|
- }
|
|
|
+ resp := httptest.NewRecorder()
|
|
|
+ router.ServeHTTP(resp, req)
|
|
|
|
|
|
- if completionResp.Object != "text_completion" {
|
|
|
- t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
|
|
- }
|
|
|
+ tc.Expected(t, capturedRequest)
|
|
|
+ })
|
|
|
+ }
|
|
|
+}
|
|
|
|
|
|
- if completionResp.Choices[0].Text != "Received temperature of 1.6" {
|
|
|
- t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
|
|
|
- }
|
|
|
- },
|
|
|
- },
|
|
|
+func TestMiddlewareResponses(t *testing.T) {
|
|
|
+ type testCase struct {
|
|
|
+ Name string
|
|
|
+ Method string
|
|
|
+ Path string
|
|
|
+ TestPath string
|
|
|
+ Handler func() gin.HandlerFunc
|
|
|
+ Endpoint func(c *gin.Context)
|
|
|
+ Setup func(t *testing.T, req *http.Request)
|
|
|
+ Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
|
|
+ }
|
|
|
+
|
|
|
+ testCases := []testCase{
|
|
|
{
|
|
|
- Name: "completions handler with error",
|
|
|
+ Name: "completions handler error forwarding",
|
|
|
Method: http.MethodPost,
|
|
|
Path: "/api/generate",
|
|
|
TestPath: "/api/generate",
|