Przeglądaj źródła

OpenAI: Add Suffix to `v1/completions` (#5611)

* add suffix

* remove todo

* remove TODO

* add to test

* rm outdated prompt tokens info md

* fix test

* fix test
royjhan 9 miesięcy temu
rodzic
commit
0d41623b52
3 zmienionych plików z 7 dodań i 6 usunięć
  1. 0 4
      docs/openai.md
  2. 2 2
      openai/openai.go
  3. 5 0
      openai/openai_test.go

+ 0 - 4
docs/openai.md

@@ -103,10 +103,6 @@ curl http://localhost:11434/v1/chat/completions \
 - [ ] `user`
 - [ ] `n`
 
-#### Notes
-
-- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
-
 ## Models
 
 Before using a model, pull it locally `ollama pull`:

+ 2 - 2
openai/openai.go

@@ -111,6 +111,7 @@ type CompletionRequest struct {
 	Stream           bool     `json:"stream"`
 	Temperature      *float32 `json:"temperature"`
 	TopP             float32  `json:"top_p"`
+	Suffix           string   `json:"suffix"`
 }
 
 type Completion struct {
@@ -188,7 +189,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -234,7 +234,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -475,6 +474,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		Prompt:  r.Prompt,
 		Options: options,
 		Stream:  &r.Stream,
+		Suffix:  r.Suffix,
 	}, nil
 }
 

+ 5 - 0
openai/openai_test.go

@@ -85,6 +85,7 @@ func TestMiddlewareRequests(t *testing.T) {
 					Prompt:      "Hello",
 					Temperature: &temp,
 					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
 				}
 
 				bodyBytes, _ := json.Marshal(body)
@@ -115,6 +116,10 @@ func TestMiddlewareRequests(t *testing.T) {
 				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
 					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
 				}
+
+				if genReq.Suffix != "suffix" {
+					t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
+				}
 			},
 		},
 		{