|
@@ -392,6 +392,29 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
sessionDuration = req.KeepAlive.Duration
|
|
sessionDuration = req.KeepAlive.Duration
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ switch reqEmbed := req.Input.(type) {
|
|
|
|
+ case string:
|
|
|
|
+ if reqEmbed == "" {
|
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ case []any:
|
|
|
|
+ if reqEmbed == nil {
|
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for _, v := range reqEmbed {
|
|
|
|
+ if _, ok := v.(string); !ok {
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
var runner *runnerRef
|
|
var runner *runnerRef
|
|
select {
|
|
select {
|
|
@@ -424,10 +447,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
|
|
|
switch reqEmbed := req.Input.(type) {
|
|
switch reqEmbed := req.Input.(type) {
|
|
case string:
|
|
case string:
|
|
- if reqEmbed == "" {
|
|
|
|
- c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
|
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
|
if err != nil {
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
@@ -435,24 +454,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|
}
|
|
}
|
|
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
|
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
|
case []any:
|
|
case []any:
|
|
- if reqEmbed == nil {
|
|
|
|
- c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
reqEmbedArray := make([]string, len(reqEmbed))
|
|
reqEmbedArray := make([]string, len(reqEmbed))
|
|
for i, v := range reqEmbed {
|
|
for i, v := range reqEmbed {
|
|
- if s, ok := v.(string); ok {
|
|
|
|
- s, err = checkFit(s, *req.Truncate)
|
|
|
|
- if err != nil {
|
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
- return
|
|
|
|
- }
|
|
|
|
- reqEmbedArray[i] = s
|
|
|
|
- } else {
|
|
|
|
- c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
|
|
|
+ s, err := checkFit(v.(string), *req.Truncate)
|
|
|
|
+ if err != nil {
|
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
return
|
|
}
|
|
}
|
|
|
|
+ reqEmbedArray[i] = s
|
|
}
|
|
}
|
|
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
|
|
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
|
|
default:
|
|
default:
|