@@ -43,6 +43,12 @@ type ChunkChoice struct {
 	FinishReason *string `json:"finish_reason"`
 }
 
+type CompleteChunkChoice struct {
+	Text         string  `json:"text"`
+	Index        int     `json:"index"`
+	FinishReason *string `json:"finish_reason"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -86,6 +92,39 @@ type ChatCompletionChunk struct {
 	Choices []ChunkChoice `json:"choices"`
 }
 
+// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
+type CompletionRequest struct {
+	Model            string   `json:"model"`
+	Prompt           string   `json:"prompt"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	MaxTokens        *int     `json:"max_tokens"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	Seed             *int     `json:"seed"`
+	Stop             any      `json:"stop"`
+	Stream           bool     `json:"stream"`
+	Temperature      *float32 `json:"temperature"`
+	TopP             float32  `json:"top_p"`
+}
+
+type Completion struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Usage             Usage                 `json:"usage,omitempty"`
+}
+
+type CompletionChunk struct {
+	Id                string                `json:"id"`
+	Object            string                `json:"object"`
+	Created           int64                 `json:"created"`
+	Choices           []CompleteChunkChoice `json:"choices"`
+	Model             string                `json:"model"`
+	SystemFingerprint string                `json:"system_fingerprint"`
+}
+
 type Model struct {
 	Id     string `json:"id"`
 	Object string `json:"object"`
@@ -158,6 +197,54 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
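+// toCompletion converts a final api.GenerateResponse into an OpenAI-compatible Completion.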
+func toCompletion(id string, r api.GenerateResponse) Completion {
+	return Completion{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
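+// toCompleteChunk converts a streaming api.GenerateResponse into an OpenAI-compatible CompletionChunk.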
+func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
+	return CompletionChunk{
+		Id:                id,
+		Object:            "text_completion",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []CompleteChunkChoice{{
+			Text:  r.Response,
+			Index: 0,
+			FinishReason: func(reason string) *string {
+				if len(reason) > 0 {
+					return &reason
+				}
+				return nil
+			}(r.DoneReason),
+		}},
+	}
+}
+
 func toListCompletion(r api.ListResponse) ListCompletion {
 	var data []Model
 	for _, m := range r.Models {
@@ -195,7 +282,7 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	switch stop := r.Stop.(type) {
 	case string:
 		options["stop"] = []string{stop}
-	case []interface{}:
+	case []any:
 		var stops []string
 		for _, s := range stop {
 			if str, ok := s.(string); ok {
@@ -247,6 +334,60 @@ func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 }
 
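+// fromCompleteRequest maps an OpenAI CompletionRequest onto an api.GenerateRequest, translating sampling options to their Ollama equivalents.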
+func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
+	options := make(map[string]any)
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []any:
+		// JSON arrays decode into []any, so collect the string elements here.
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			}
+		}
+		options["stop"] = stops
+	default:
+		if r.Stop != nil {
+			return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", r.Stop)
+		}
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+	}
+
+	options["frequency_penalty"] = r.FrequencyPenalty * 2.0
+
+	options["presence_penalty"] = r.PresencePenalty * 2.0
+
+	if r.TopP != 0.0 {
+		options["top_p"] = r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	return api.GenerateRequest{
+		Model:   r.Model,
+		Prompt:  r.Prompt,
+		Options: options,
+		Stream:  &r.Stream,
+	}, nil
+}
+
 type BaseWriter struct {
 	gin.ResponseWriter
 }
@@ -257,6 +398,13 @@ type ChatWriter struct {
 	BaseWriter
 }
 
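+// CompleteWriter rewrites native generate responses as OpenAI completion payloads as they are written.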
+type CompleteWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
 type ListWriter struct {
 	BaseWriter
 }
@@ -331,6 +479,56 @@ func (w *ChatWriter) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
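+// writeResponse re-encodes a native generate response as an OpenAI completion object, or as a server-sent event chunk when streaming.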
+func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
+	var generateResponse api.GenerateResponse
+	err := json.Unmarshal(data, &generateResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// completion chunk
+	if w.stream {
+		d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if generateResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *CompleteWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
 func (w *ListWriter) writeResponse(data []byte) (int, error) {
 	var listResponse api.ListResponse
 	err := json.Unmarshal(data, &listResponse)
@@ -416,6 +614,43 @@ func RetrieveMiddleware() gin.HandlerFunc {
 	}
 }
 
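+// CompletionsMiddleware adapts OpenAI /v1/completions requests and responses to Ollama's native generate endpoint.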
+func CompletionsMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req CompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		var b bytes.Buffer
+		genReq, err := fromCompleteRequest(req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
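+		// Swap the translated payload into the request body so the native generate handler can serve it.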
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &CompleteWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			stream:     req.Stream,
+			id:         fmt.Sprintf("cmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
 func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest