소스 검색

openai: fix "presence_penalty" typo and add test (#6665)

frob 7 달 전
부모
커밋
fe91d7fff1
2개의 변경된 파일45개의 추가작업 그리고 3개의 파일을 삭제
  1. 1 1
      openai/openai.go
  2. 44 2
      openai/openai_test.go

+ 1 - 1
openai/openai.go

@@ -79,7 +79,7 @@ type ChatCompletionRequest struct {
 	Stop             any             `json:"stop"`
 	Temperature      *float64        `json:"temperature"`
 	FrequencyPenalty *float64        `json:"frequency_penalty"`
-	PresencePenalty  *float64        `json:"presence_penalty_penalty"`
+	PresencePenalty  *float64        `json:"presence_penalty"`
 	TopP             *float64        `json:"top_p"`
 	ResponseFormat   *ResponseFormat `json:"response_format"`
 	Tools            []api.Tool      `json:"tools"`

+ 44 - 2
openai/openai_test.go

@@ -22,7 +22,10 @@ const (
 	image  = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
 )
 
-var False = false
+var (
+	False = false
+	True  = true
+)
 
 func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
 	return func(c *gin.Context) {
@@ -70,6 +73,44 @@ func TestChatMiddleware(t *testing.T) {
 				Stream: &False,
 			},
 		},
+		{
+			name: "chat handler with options",
+			body: `{
+				"model": "test-model",
+				"messages": [
+					{"role": "user", "content": "Hello"}
+				],
+				"stream":            true,
+				"max_tokens":        999,
+				"seed":              123,
+				"stop":              ["\n", "stop"],
+				"temperature":       3.0,
+				"frequency_penalty": 4.0,
+				"presence_penalty":  5.0,
+				"top_p":             6.0,
+				"response_format":   {"type": "json_object"}
+			}`,
+			req: api.ChatRequest{
+				Model: "test-model",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "Hello",
+					},
+				},
+				Options: map[string]any{
+					"num_predict":       999.0, // float because JSON doesn't distinguish between float and int
+					"seed":              123.0,
+					"stop":              []any{"\n", "stop"},
+					"temperature":       6.0,
+					"frequency_penalty": 8.0,
+					"presence_penalty":  10.0,
+					"top_p":             6.0,
+				},
+				Format: "json",
+				Stream: &True,
+			},
+		},
 		{
 			name: "chat handler with image content",
 			body: `{
@@ -186,6 +227,8 @@ func TestChatMiddleware(t *testing.T) {
 			req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
 			req.Header.Set("Content-Type", "application/json")
 
+			defer func() { capturedRequest = nil }()
+
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
@@ -202,7 +245,6 @@ func TestChatMiddleware(t *testing.T) {
 			if !reflect.DeepEqual(tc.err, errResp) {
 				t.Fatal("errors did not match")
 			}
-			capturedRequest = nil
 		})
 	}
 }