Pārlūkot izejas kodu

OpenAI v1/completions: allow stop token list (#5551)

* stop token parsing fix

* add stop test
royjhan 9 mēneši atpakaļ
vecāks
revīzija
4918fae535
2 mainītis faili ar 20 papildinājumiem un 5 dzēšanām
  1. 9 5
      openai/openai.go
  2. 11 0
      openai/openai_test.go

+ 9 - 5
openai/openai.go

@@ -338,12 +338,16 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 	switch stop := r.Stop.(type) {
 	case string:
 		options["stop"] = []string{stop}
-	case []string:
-		options["stop"] = stop
-	default:
-		if r.Stop != nil {
-			return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
+	case []any:
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			} else {
+				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
+			}
 		}
+		options["stop"] = stops
 	}
 
 	if r.MaxTokens != nil {

+ 11 - 0
openai/openai_test.go

@@ -79,6 +79,7 @@ func TestMiddlewareRequests(t *testing.T) {
 					Model:       "test-model",
 					Prompt:      "Hello",
 					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
 				}
 
 				bodyBytes, _ := json.Marshal(body)
@@ -99,6 +100,16 @@ func TestMiddlewareRequests(t *testing.T) {
 				if genReq.Options["temperature"] != 1.6 {
 					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
 				}
+
+				stopTokens, ok := genReq.Options["stop"].([]any)
+
+				if !ok {
+					t.Fatalf("expected stop tokens to be a list")
+				}
+
+				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
+					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
+				}
 			},
 		},
 	}