|
@@ -265,29 +265,38 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
truncate = false
|
|
|
}
|
|
|
|
|
|
+ inputCheck := true
|
|
|
+
|
|
|
+ if req.Images != nil {
|
|
|
+ inputCheck = false
|
|
|
+ }
|
|
|
+
|
|
|
var input []string
|
|
|
|
|
|
- switch i := req.Input.(type) {
|
|
|
- case string:
|
|
|
- if len(i) > 0 {
|
|
|
- input = append(input, i)
|
|
|
- }
|
|
|
- case []any:
|
|
|
- for _, v := range i {
|
|
|
- if _, ok := v.(string); !ok {
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
- return
|
|
|
+ if inputCheck {
|
|
|
+
|
|
|
+ switch i := req.Input.(type) {
|
|
|
+ case string:
|
|
|
+ if len(i) > 0 {
|
|
|
+ input = append(input, i)
|
|
|
}
|
|
|
- input = append(input, v.(string))
|
|
|
+ case []any:
|
|
|
+ for _, v := range i {
|
|
|
+ if _, ok := v.(string); !ok {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ input = append(input, v.(string))
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
+ return
|
|
|
}
|
|
|
- default:
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
- return
|
|
|
- }
|
|
|
|
|
|
- if len(input) == 0 {
|
|
|
- c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
- return
|
|
|
+ if len(input) == 0 {
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
|
@@ -326,7 +335,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
|
|
|
input[i] = s
|
|
|
}
|
|
|
- embeddings, err := r.Embed(c.Request.Context(), input)
|
|
|
+
|
|
|
+ images := make([]llm.ImageData, len(req.Images))
|
|
|
+ for i := range req.Images {
|
|
|
+ images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
|
|
|
+ }
|
|
|
+
|
|
|
+ embeddings, err := r.Embed(c.Request.Context(), input, images)
|
|
|
|
|
|
if err != nil {
|
|
|
slog.Error("embedding generation failed", "error", err)
|
|
@@ -384,7 +399,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
|
|
+ embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|