Ver Fonte

set num_ctx through extra body

ParthSareen há 3 meses atrás
pai
commit
35e97db03b
1 ficheiros alterados com 32 adições e 14 exclusões
  1. 32 14
      openai/openai.go

+ 32 - 14
openai/openai.go

@@ -80,19 +80,21 @@ type StreamOptions struct {
 }
 
 type ChatCompletionRequest struct {
-	Model            string          `json:"model"`
-	Messages         []Message       `json:"messages"`
-	Stream           bool            `json:"stream"`
-	StreamOptions    *StreamOptions  `json:"stream_options"`
-	MaxTokens        *int            `json:"max_tokens"`
-	Seed             *int            `json:"seed"`
-	Stop             any             `json:"stop"`
-	Temperature      *float64        `json:"temperature"`
-	FrequencyPenalty *float64        `json:"frequency_penalty"`
-	PresencePenalty  *float64        `json:"presence_penalty"`
-	TopP             *float64        `json:"top_p"`
-	ResponseFormat   *ResponseFormat `json:"response_format"`
-	Tools            []api.Tool      `json:"tools"`
+	Model               string          `json:"model"`
+	Messages            []Message       `json:"messages"`
+	Stream              bool            `json:"stream"`
+	StreamOptions       *StreamOptions  `json:"stream_options"`
+	MaxCompletionTokens *int            `json:"max_completion_tokens"`
+	MaxTokens           *int            `json:"max_tokens" deprecated:"use max_completion_tokens instead"`
+	Seed                *int            `json:"seed"`
+	Stop                any             `json:"stop"`
+	Temperature         *float64        `json:"temperature"`
+	FrequencyPenalty    *float64        `json:"frequency_penalty"`
+	PresencePenalty     *float64        `json:"presence_penalty"`
+	TopP                *float64        `json:"top_p"`
+	ResponseFormat      *ResponseFormat `json:"response_format"`
+	Tools               []api.Tool      `json:"tools"`
+	NumCtx              *int            `json:"num_ctx"`
 }
 
 type ChatCompletion struct {
@@ -475,8 +477,24 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		options["stop"] = stops
 	}
 
+	// Deprecated: use MaxCompletionTokens instead. When MaxTokens is set, it overwrites MaxCompletionTokens.
 	if r.MaxTokens != nil {
-		options["num_predict"] = *r.MaxTokens
+		r.MaxCompletionTokens = r.MaxTokens
+	}
+
+	if r.NumCtx != nil {
+		options["num_ctx"] = *r.NumCtx
+	}
+
+	DEFAULT_NUM_CTX := 2048
+	if r.MaxCompletionTokens != nil {
+		options["num_predict"] = *r.MaxCompletionTokens
+
+		if numCtx, ok := options["num_ctx"].(int); ok && *r.MaxCompletionTokens > numCtx {
+			options["num_ctx"] = *r.MaxCompletionTokens
+		} else if *r.MaxCompletionTokens > DEFAULT_NUM_CTX {
+			options["num_ctx"] = DEFAULT_NUM_CTX
+		}
 	}
 
 	if r.Temperature != nil {