|
@@ -339,8 +339,8 @@ func getDefaultSessionDuration() time.Duration {
|
|
|
return defaultSessionDuration
|
|
|
}
|
|
|
|
|
|
-func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
- var req api.EmbeddingRequest
|
|
|
+func (s *Server) EmbedHandler(c *gin.Context) {
|
|
|
+ var req api.EmbedRequest
|
|
|
err := c.ShouldBindJSON(&req)
|
|
|
switch {
|
|
|
case errors.Is(err, io.EOF):
|
|
@@ -389,39 +389,101 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- switch {
|
|
|
- // single embedding
|
|
|
- case len(req.Prompt) > 0:
|
|
|
- slog.Info("embedding request", "prompt", req.Prompt)
|
|
|
- embeddings, err := runner.llama.Embedding(c.Request.Context(), []string{req.Prompt})
|
|
|
- if err != nil {
|
|
|
- slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
+ embeddings := [][]float64{}
|
|
|
+
|
|
|
+ switch reqEmbed := req.Input.(type) {
|
|
|
+ case string:
|
|
|
+ if reqEmbed == "" {
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
|
|
return
|
|
|
}
|
|
|
-
|
|
|
- resp := api.EmbeddingResponse{Embedding: embeddings[0]}
|
|
|
- c.JSON(http.StatusOK, resp)
|
|
|
- // batch embeddings
|
|
|
- case len(req.PromptBatch) > 0:
|
|
|
- embeddings, err := runner.llama.Embedding(c.Request.Context(), req.PromptBatch)
|
|
|
- if err != nil {
|
|
|
- slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
|
|
|
- c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
+ embeddings, err = runner.llama.Embedding(c.Request.Context(), []string{reqEmbed})
|
|
|
+ case []string:
|
|
|
+ if reqEmbed == nil {
|
|
|
+ c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
|
|
|
return
|
|
|
}
|
|
|
+ embeddings, err = runner.llama.Embedding(c.Request.Context(), reqEmbed)
|
|
|
+ default:
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
|
+ }
|
|
|
|
|
|
- resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
|
|
|
- c.JSON(http.StatusOK, resp)
|
|
|
+ if err != nil {
|
|
|
+ slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
+ return
|
|
|
+ }
|
|
|
|
|
|
- // empty prompt loads the model
|
|
|
- default:
|
|
|
- if req.PromptBatch != nil {
|
|
|
- c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
|
|
|
- } else {
|
|
|
- c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
|
|
|
+ resp := api.EmbedResponse{Embeddings: embeddings}
|
|
|
+ c.JSON(http.StatusOK, resp)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|
|
+ var req api.EmbeddingRequest
|
|
|
+ err := c.ShouldBindJSON(&req)
|
|
|
+ switch {
|
|
|
+ case errors.Is(err, io.EOF):
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
|
|
+ return
|
|
|
+ case err != nil:
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if req.Model == "" {
|
|
|
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ model, err := GetModel(req.Model)
|
|
|
+ if err != nil {
|
|
|
+ var pErr *fs.PathError
|
|
|
+ if errors.As(err, &pErr) {
|
|
|
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
|
|
+ return
|
|
|
}
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ opts, err := modelOptions(model, req.Options)
|
|
|
+ if err != nil {
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var sessionDuration time.Duration
|
|
|
+ if req.KeepAlive == nil {
|
|
|
+ sessionDuration = getDefaultSessionDuration()
|
|
|
+ } else {
|
|
|
+ sessionDuration = req.KeepAlive.Duration
|
|
|
}
|
|
|
+
|
|
|
+ rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
|
+ var runner *runnerRef
|
|
|
+ select {
|
|
|
+ case runner = <-rCh:
|
|
|
+ case err = <-eCh:
|
|
|
+ handleErrorResponse(c, err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // an empty request loads the model
|
|
|
+ if req.Prompt == "" {
|
|
|
+ c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
|
|
|
+ return
|
|
|
+ }
|
|
|
+ embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
|
|
|
+ if err != nil {
|
|
|
+ slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
|
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ resp := api.EmbeddingResponse{
|
|
|
+ Embedding: embedding,
|
|
|
+ }
|
|
|
+ c.JSON(http.StatusOK, resp)
|
|
|
}
|
|
|
|
|
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
|
@@ -1005,7 +1067,8 @@ func (s *Server) GenerateRoutes() http.Handler {
|
|
|
r.POST("/api/pull", s.PullModelHandler)
|
|
|
r.POST("/api/generate", s.GenerateHandler)
|
|
|
r.POST("/api/chat", s.ChatHandler)
|
|
|
- r.POST("/api/embeddings", s.EmbeddingsHandler)
|
|
|
+ r.POST("/api/embed", s.EmbedHandler)
|
|
|
+ r.POST("/api/embeddings", s.EmbeddingsHandler) // legacy
|
|
|
r.POST("/api/create", s.CreateModelHandler)
|
|
|
r.POST("/api/push", s.PushModelHandler)
|
|
|
r.POST("/api/copy", s.CopyModelHandler)
|