|
@@ -61,6 +61,11 @@ type ResponseFormat struct {
|
|
|
Type string `json:"type"`
|
|
|
}
|
|
|
|
|
|
+type EmbedRequest struct {
|
|
|
+ Input any `json:"input"`
|
|
|
+ Model string `json:"model"`
|
|
|
+}
|
|
|
+
|
|
|
type ChatCompletionRequest struct {
|
|
|
Model string `json:"model"`
|
|
|
Messages []Message `json:"messages"`
|
|
@@ -134,11 +139,23 @@ type Model struct {
|
|
|
OwnedBy string `json:"owned_by"`
|
|
|
}
|
|
|
|
|
|
+type Embedding struct {
|
|
|
+ Object string `json:"object"`
|
|
|
+ Embedding []float32 `json:"embedding"`
|
|
|
+ Index int `json:"index"`
|
|
|
+}
|
|
|
+
|
|
|
type ListCompletion struct {
|
|
|
Object string `json:"object"`
|
|
|
Data []Model `json:"data"`
|
|
|
}
|
|
|
|
|
|
+type EmbeddingList struct {
|
|
|
+ Object string `json:"object"`
|
|
|
+ Data []Embedding `json:"data"`
|
|
|
+ Model string `json:"model"`
|
|
|
+}
|
|
|
+
|
|
|
func NewError(code int, message string) ErrorResponse {
|
|
|
var etype string
|
|
|
switch code {
|
|
@@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
|
|
+ if r.Embeddings != nil {
|
|
|
+ var data []Embedding
|
|
|
+ for i, e := range r.Embeddings {
|
|
|
+ data = append(data, Embedding{
|
|
|
+ Object: "embedding",
|
|
|
+ Embedding: e,
|
|
|
+ Index: i,
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ return EmbeddingList{
|
|
|
+ Object: "list",
|
|
|
+ Data: data,
|
|
|
+ Model: model,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return EmbeddingList{}
|
|
|
+}
|
|
|
+
|
|
|
func toModel(r api.ShowResponse, m string) Model {
|
|
|
return Model{
|
|
|
Id: m,
|
|
@@ -465,6 +503,11 @@ type RetrieveWriter struct {
|
|
|
model string
|
|
|
}
|
|
|
|
|
|
+type EmbedWriter struct {
|
|
|
+ BaseWriter
|
|
|
+ model string
|
|
|
+}
|
|
|
+
|
|
|
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
|
|
var serr api.StatusError
|
|
|
err := json.Unmarshal(data, &serr)
|
|
@@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
|
return w.writeResponse(data)
|
|
|
}
|
|
|
|
|
|
+func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
|
|
+ var embedResponse api.EmbedResponse
|
|
|
+ err := json.Unmarshal(data, &embedResponse)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
|
+ err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return len(data), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (w *EmbedWriter) Write(data []byte) (int, error) {
|
|
|
+ code := w.ResponseWriter.Status()
|
|
|
+ if code != http.StatusOK {
|
|
|
+ return w.writeError(code, data)
|
|
|
+ }
|
|
|
+
|
|
|
+ return w.writeResponse(data)
|
|
|
+}
|
|
|
+
|
|
|
func ListMiddleware() gin.HandlerFunc {
|
|
|
return func(c *gin.Context) {
|
|
|
w := &ListWriter{
|
|
@@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc {
|
|
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
|
}
|
|
|
|
|
|
+ c.Writer = w
|
|
|
+ c.Next()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func EmbeddingsMiddleware() gin.HandlerFunc {
|
|
|
+ return func(c *gin.Context) {
|
|
|
+ var req EmbedRequest
|
|
|
+ err := c.ShouldBindJSON(&req)
|
|
|
+ if err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.Input == "" {
|
|
|
+ req.Input = []string{""}
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.Input == nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var b bytes.Buffer
|
|
|
+ if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.Request.Body = io.NopCloser(&b)
|
|
|
+
|
|
|
+ w := &EmbedWriter{
|
|
|
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
+ model: req.Model,
|
|
|
+ }
|
|
|
+
|
|
|
c.Writer = w
|
|
|
|
|
|
c.Next()
|