12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913 |
- // openai package provides middleware for partial compatibility with the OpenAI REST API
- package openai
- import (
- "bytes"
- "encoding/base64"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "math/rand"
- "net/http"
- "strings"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/ollama/ollama/api"
- "github.com/ollama/ollama/types/model"
- )
// Error is the OpenAI-style error object embedded in an ErrorResponse.
type Error struct {
	Message string      `json:"message"`
	Type    string      `json:"type"`
	Param   interface{} `json:"param"`
	Code    *string     `json:"code"`
}

// ErrorResponse is the top-level error payload: {"error": {...}}.
type ErrorResponse struct {
	Error Error `json:"error"`
}

// Message is a single chat message. Content is `any` because the OpenAI
// API accepts either a plain string or an array of typed content parts.
type Message struct {
	Role      string     `json:"role"`
	Content   any        `json:"content"`
	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}

// Choice is one choice of a non-streaming chat completion.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason *string `json:"finish_reason"`
}

// ChunkChoice is one delta of a streaming chat completion.
type ChunkChoice struct {
	Index        int     `json:"index"`
	Delta        Message `json:"delta"`
	FinishReason *string `json:"finish_reason"`
}

// CompleteChunkChoice is one choice of a text completion (streaming or not).
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}

// Usage reports token accounting for a completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ResponseFormat selects the response format, e.g. "json_object".
type ResponseFormat struct {
	Type string `json:"type"`
}

// EmbedRequest is the OpenAI-compatible embeddings request body.
// Input may be a string or a list; validation happens in the middleware.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}
- type ChatCompletionRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Stream bool `json:"stream"`
- MaxTokens *int `json:"max_tokens"`
- Seed *int `json:"seed"`
- Stop any `json:"stop"`
- Temperature *float64 `json:"temperature"`
- FrequencyPenalty *float64 `json:"frequency_penalty"`
- PresencePenalty *float64 `json:"presence_penalty_penalty"`
- TopP *float64 `json:"top_p"`
- ResponseFormat *ResponseFormat `json:"response_format"`
- Tools []api.Tool `json:"tools"`
- }
// ChatCompletion is a non-streaming chat.completion response object.
type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
}

// ChatCompletionChunk is one streamed chat.completion.chunk event.
type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
// CompletionRequest is the OpenAI-compatible /v1/completions request body.
// Note: unlike ChatCompletionRequest, the penalty/top_p fields are plain
// float32s, so absent fields are indistinguishable from explicit zeros.
type CompletionRequest struct {
	Model            string   `json:"model"`
	Prompt           string   `json:"prompt"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	MaxTokens        *int     `json:"max_tokens"`
	PresencePenalty  float32  `json:"presence_penalty"`
	Seed             *int     `json:"seed"`
	Stop             any      `json:"stop"`
	Stream           bool     `json:"stream"`
	Temperature      *float32 `json:"temperature"`
	TopP             float32  `json:"top_p"`
	Suffix           string   `json:"suffix"`
}

// Completion is a non-streaming text_completion response object.
type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

// CompletionChunk is one streamed text_completion event.
type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
}

// ToolCall is an OpenAI tool call; Arguments is a JSON-encoded string,
// not a nested object.
type ToolCall struct {
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

// Model is one entry of the OpenAI model list / retrieve responses.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// Embedding is one embedding vector of an EmbeddingList.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}

// ListCompletion is the OpenAI-compatible model list response.
type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

// EmbeddingList is the OpenAI-compatible embeddings response.
type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

// EmbeddingUsage reports token accounting for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
- func NewError(code int, message string) ErrorResponse {
- var etype string
- switch code {
- case http.StatusBadRequest:
- etype = "invalid_request_error"
- case http.StatusNotFound:
- etype = "not_found_error"
- default:
- etype = "api_error"
- }
- return ErrorResponse{Error{Type: etype, Message: message}}
- }
// toolCallId returns a pseudo-random identifier for a tool call in the
// OpenAI style, e.g. "call_a1b2c3d4".
func toolCallId() string {
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	// letterBytes is already lowercase, so the previous strings.ToLower
	// round trip was a redundant allocation.
	return "call_" + string(b)
}
- func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
- toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
- for i, tc := range r.Message.ToolCalls {
- toolCalls[i].ID = toolCallId()
- toolCalls[i].Type = "function"
- toolCalls[i].Function.Name = tc.Function.Name
- args, err := json.Marshal(tc.Function.Arguments)
- if err != nil {
- slog.Error("could not marshall function arguments to json", "error", err)
- continue
- }
- toolCalls[i].Function.Arguments = string(args)
- }
- return ChatCompletion{
- Id: id,
- Object: "chat.completion",
- Created: r.CreatedAt.Unix(),
- Model: r.Model,
- SystemFingerprint: "fp_ollama",
- Choices: []Choice{{
- Index: 0,
- Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
- FinishReason: func(reason string) *string {
- if len(toolCalls) > 0 {
- reason = "tool_calls"
- }
- if len(reason) > 0 {
- return &reason
- }
- return nil
- }(r.DoneReason),
- }},
- Usage: Usage{
- PromptTokens: r.PromptEvalCount,
- CompletionTokens: r.EvalCount,
- TotalTokens: r.PromptEvalCount + r.EvalCount,
- },
- }
- }
- func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
- return ChatCompletionChunk{
- Id: id,
- Object: "chat.completion.chunk",
- Created: time.Now().Unix(),
- Model: r.Model,
- SystemFingerprint: "fp_ollama",
- Choices: []ChunkChoice{{
- Index: 0,
- Delta: Message{Role: "assistant", Content: r.Message.Content},
- FinishReason: func(reason string) *string {
- if len(reason) > 0 {
- return &reason
- }
- return nil
- }(r.DoneReason),
- }},
- }
- }
- func toCompletion(id string, r api.GenerateResponse) Completion {
- return Completion{
- Id: id,
- Object: "text_completion",
- Created: r.CreatedAt.Unix(),
- Model: r.Model,
- SystemFingerprint: "fp_ollama",
- Choices: []CompleteChunkChoice{{
- Text: r.Response,
- Index: 0,
- FinishReason: func(reason string) *string {
- if len(reason) > 0 {
- return &reason
- }
- return nil
- }(r.DoneReason),
- }},
- Usage: Usage{
- PromptTokens: r.PromptEvalCount,
- CompletionTokens: r.EvalCount,
- TotalTokens: r.PromptEvalCount + r.EvalCount,
- },
- }
- }
- func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
- return CompletionChunk{
- Id: id,
- Object: "text_completion",
- Created: time.Now().Unix(),
- Model: r.Model,
- SystemFingerprint: "fp_ollama",
- Choices: []CompleteChunkChoice{{
- Text: r.Response,
- Index: 0,
- FinishReason: func(reason string) *string {
- if len(reason) > 0 {
- return &reason
- }
- return nil
- }(r.DoneReason),
- }},
- }
- }
- func toListCompletion(r api.ListResponse) ListCompletion {
- var data []Model
- for _, m := range r.Models {
- data = append(data, Model{
- Id: m.Name,
- Object: "model",
- Created: m.ModifiedAt.Unix(),
- OwnedBy: model.ParseName(m.Name).Namespace,
- })
- }
- return ListCompletion{
- Object: "list",
- Data: data,
- }
- }
- func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
- if r.Embeddings != nil {
- var data []Embedding
- for i, e := range r.Embeddings {
- data = append(data, Embedding{
- Object: "embedding",
- Embedding: e,
- Index: i,
- })
- }
- return EmbeddingList{
- Object: "list",
- Data: data,
- Model: model,
- Usage: EmbeddingUsage{
- PromptTokens: r.PromptEvalCount,
- TotalTokens: r.PromptEvalCount,
- },
- }
- }
- return EmbeddingList{}
- }
- func toModel(r api.ShowResponse, m string) Model {
- return Model{
- Id: m,
- Object: "model",
- Created: r.ModifiedAt.Unix(),
- OwnedBy: model.ParseName(m).Namespace,
- }
- }
// fromChatRequest converts an OpenAI chat completion request into an
// Ollama api.ChatRequest. It flattens structured content parts into
// individual messages, decodes base64 image data URIs, rescales sampling
// options, and maps response_format onto Ollama's "json" format.
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		switch content := msg.Content.(type) {
		case string:
			messages = append(messages, api.Message{Role: msg.Role, Content: content})
		case []any:
			// Structured content: a list of {"type": ...} parts; each part
			// becomes its own api.Message with the same role.
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					// image_url may be either {"url": "..."} or a bare string.
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}
					// Only base64 data URIs for jpeg/jpg/png are accepted;
					// the matching prefix is stripped before decoding.
					types := []string{"jpeg", "jpg", "png"}
					valid := false
					for _, t := range types {
						prefix := "data:image/" + t + ";base64,"
						if strings.HasPrefix(url, prefix) {
							url = strings.TrimPrefix(url, prefix)
							valid = true
							break
						}
					}
					if !valid {
						return nil, errors.New("invalid image input")
					}
					img, err := base64.StdEncoding.DecodeString(url)
					if err != nil {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
		default:
			// No string/array content: accept tool-call-only messages,
			// reject everything else.
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}
			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
			for i, tc := range msg.ToolCalls {
				toolCalls[i].Function.Name = tc.Function.Name
				// OpenAI encodes arguments as a JSON string; decode it here.
				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
				if err != nil {
					return nil, errors.New("invalid tool call arguments")
				}
			}
			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
		}
	}
	options := make(map[string]interface{})
	// "stop" may be a single string or a list; non-string list elements
	// are silently skipped here (unlike fromCompleteRequest).
	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}
	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}
	// Temperature and penalties are doubled to translate OpenAI's ranges
	// into the scale this middleware feeds to Ollama.
	if r.Temperature != nil {
		options["temperature"] = *r.Temperature * 2.0
	} else {
		options["temperature"] = 1.0
	}
	if r.Seed != nil {
		options["seed"] = *r.Seed
	}
	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
	}
	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty * 2.0
	}
	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}
	var format string
	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
		format = "json"
	}
	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
	}, nil
}
- func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
- options := make(map[string]any)
- switch stop := r.Stop.(type) {
- case string:
- options["stop"] = []string{stop}
- case []any:
- var stops []string
- for _, s := range stop {
- if str, ok := s.(string); ok {
- stops = append(stops, str)
- } else {
- return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
- }
- }
- options["stop"] = stops
- }
- if r.MaxTokens != nil {
- options["num_predict"] = *r.MaxTokens
- }
- if r.Temperature != nil {
- options["temperature"] = *r.Temperature * 2.0
- } else {
- options["temperature"] = 1.0
- }
- if r.Seed != nil {
- options["seed"] = *r.Seed
- }
- options["frequency_penalty"] = r.FrequencyPenalty * 2.0
- options["presence_penalty"] = r.PresencePenalty * 2.0
- if r.TopP != 0.0 {
- options["top_p"] = r.TopP
- } else {
- options["top_p"] = 1.0
- }
- return api.GenerateRequest{
- Model: r.Model,
- Prompt: r.Prompt,
- Options: options,
- Stream: &r.Stream,
- Suffix: r.Suffix,
- }, nil
- }
// BaseWriter wraps gin's ResponseWriter and provides shared error
// translation for the OpenAI-compatible response writers below.
type BaseWriter struct {
	gin.ResponseWriter
}

// ChatWriter rewrites Ollama chat responses as OpenAI chat completions.
type ChatWriter struct {
	stream bool   // emit SSE chunks when true
	id     string // completion id echoed in every payload
	BaseWriter
}

// CompleteWriter rewrites Ollama generate responses as OpenAI text completions.
type CompleteWriter struct {
	stream bool
	id     string
	BaseWriter
}

// ListWriter rewrites Ollama model lists as OpenAI model lists.
type ListWriter struct {
	BaseWriter
}

// RetrieveWriter rewrites Ollama show responses as OpenAI model objects.
type RetrieveWriter struct {
	BaseWriter
	model string
}

// EmbedWriter rewrites Ollama embed responses as OpenAI embedding lists.
type EmbedWriter struct {
	BaseWriter
	model string
}
- func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
- var serr api.StatusError
- err := json.Unmarshal(data, &serr)
- if err != nil {
- return 0, err
- }
- w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
- if err != nil {
- return 0, err
- }
- return len(data), nil
- }
- func (w *ChatWriter) writeResponse(data []byte) (int, error) {
- var chatResponse api.ChatResponse
- err := json.Unmarshal(data, &chatResponse)
- if err != nil {
- return 0, err
- }
- // chat chunk
- if w.stream {
- d, err := json.Marshal(toChunk(w.id, chatResponse))
- if err != nil {
- return 0, err
- }
- w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
- _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
- if err != nil {
- return 0, err
- }
- if chatResponse.Done {
- _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
- if err != nil {
- return 0, err
- }
- }
- return len(data), nil
- }
- // chat completion
- w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
- if err != nil {
- return 0, err
- }
- return len(data), nil
- }
- func (w *ChatWriter) Write(data []byte) (int, error) {
- code := w.ResponseWriter.Status()
- if code != http.StatusOK {
- return w.writeError(code, data)
- }
- return w.writeResponse(data)
- }
- func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
- var generateResponse api.GenerateResponse
- err := json.Unmarshal(data, &generateResponse)
- if err != nil {
- return 0, err
- }
- // completion chunk
- if w.stream {
- d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
- if err != nil {
- return 0, err
- }
- w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
- _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
- if err != nil {
- return 0, err
- }
- if generateResponse.Done {
- _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
- if err != nil {
- return 0, err
- }
- }
- return len(data), nil
- }
- // completion
- w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
- if err != nil {
- return 0, err
- }
- return len(data), nil
- }
- func (w *CompleteWriter) Write(data []byte) (int, error) {
- code := w.ResponseWriter.Status()
- if code != http.StatusOK {
- return w.writeError(code, data)
- }
- return w.writeResponse(data)
- }
- func (w *ListWriter) writeResponse(data []byte) (int, error) {
- var listResponse api.ListResponse
- err := json.Unmarshal(data, &listResponse)
- if err != nil {
- return 0, err
- }
- w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
- if err != nil {
- return 0, err
- }
- return len(data), nil
- }
- func (w *ListWriter) Write(data []byte) (int, error) {
- code := w.ResponseWriter.Status()
- if code != http.StatusOK {
- return w.writeError(code, data)
- }
- return w.writeResponse(data)
- }
- func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
- var showResponse api.ShowResponse
- err := json.Unmarshal(data, &showResponse)
- if err != nil {
- return 0, err
- }
- // retrieve completion
- w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
- if err != nil {
- return 0, err
- }
- return len(data), nil
- }
- func (w *RetrieveWriter) Write(data []byte) (int, error) {
- code := w.ResponseWriter.Status()
- if code != http.StatusOK {
- return w.writeError(code, data)
- }
- return w.writeResponse(data)
- }
- func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
- var embedResponse api.EmbedResponse
- err := json.Unmarshal(data, &embedResponse)
- if err != nil {
- return 0, err
- }
- w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
- if err != nil {
- return 0, err
- }
- return len(data), nil
- }
- func (w *EmbedWriter) Write(data []byte) (int, error) {
- code := w.ResponseWriter.Status()
- if code != http.StatusOK {
- return w.writeError(code, data)
- }
- return w.writeResponse(data)
- }
- func ListMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- w := &ListWriter{
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
- }
- c.Writer = w
- c.Next()
- }
- }
- func RetrieveMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- var b bytes.Buffer
- if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
- return
- }
- c.Request.Body = io.NopCloser(&b)
- // response writer
- w := &RetrieveWriter{
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
- model: c.Param("model"),
- }
- c.Writer = w
- c.Next()
- }
- }
- func CompletionsMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- var req CompletionRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
- return
- }
- var b bytes.Buffer
- genReq, err := fromCompleteRequest(req)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
- return
- }
- if err := json.NewEncoder(&b).Encode(genReq); err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
- return
- }
- c.Request.Body = io.NopCloser(&b)
- w := &CompleteWriter{
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
- stream: req.Stream,
- id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
- }
- c.Writer = w
- c.Next()
- }
- }
- func EmbeddingsMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- var req EmbedRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
- return
- }
- if req.Input == "" {
- req.Input = []string{""}
- }
- if req.Input == nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
- return
- }
- if v, ok := req.Input.([]any); ok && len(v) == 0 {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
- return
- }
- var b bytes.Buffer
- if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
- return
- }
- c.Request.Body = io.NopCloser(&b)
- w := &EmbedWriter{
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
- model: req.Model,
- }
- c.Writer = w
- c.Next()
- }
- }
- func ChatMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- var req ChatCompletionRequest
- err := c.ShouldBindJSON(&req)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
- return
- }
- if len(req.Messages) == 0 {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
- return
- }
- var b bytes.Buffer
- chatReq, err := fromChatRequest(req)
- if err != nil {
- c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
- return
- }
- if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
- c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
- return
- }
- c.Request.Body = io.NopCloser(&b)
- w := &ChatWriter{
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
- stream: req.Stream,
- id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
- }
- c.Writer = w
- c.Next()
- }
- }
|