Browse Source

engine: error on embeddings; not currently implemented

Michael Yang 1 month ago
parent
commit
ec46f3286c
2 changed files with 9 additions and 66 deletions
  1. 7 62
      runner/ollamarunner/runner.go
  2. 2 4
      server/routes.go

+ 7 - 62
runner/ollamarunner/runner.go

@@ -691,65 +691,6 @@ type EmbeddingResponse struct {
 	Embedding []float32 `json:"embedding"`
 }
 
-func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	var req EmbeddingRequest
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
-		return
-	}
-
-	w.Header().Set("Content-Type", "application/json")
-
-	slog.Debug("embedding request", "content", req.Content)
-
-	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
-	if err != nil {
-		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
-		return
-	}
-
-	// Ensure there is a place to put the sequence, released when removed from s.seqs
-	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
-		if errors.Is(err, context.Canceled) {
-			slog.Info("aborting embeddings request due to client closing the connection")
-		} else {
-			slog.Error("Failed to acquire semaphore", "error", err)
-		}
-		return
-	}
-
-	s.mu.Lock()
-	found := false
-	for i, sq := range s.seqs {
-		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
-			if err != nil {
-				s.mu.Unlock()
-				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
-				return
-			}
-			s.seqs[i] = seq
-			s.cond.Signal()
-			found = true
-			break
-		}
-	}
-	s.mu.Unlock()
-
-	if !found {
-		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
-		return
-	}
-
-	embedding := <-seq.embedding
-
-	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
-		Embedding: embedding,
-	}); err != nil {
-		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
-	}
-}
-
 type HealthResponse struct {
 	Status   string  `json:"status"`
 	Progress float32 `json:"progress"`
@@ -927,9 +868,13 @@ func Execute(args []string) error {
 	defer listener.Close()
 
 	mux := http.NewServeMux()
-	mux.HandleFunc("/embedding", server.embeddings)
-	mux.HandleFunc("/completion", server.completion)
-	mux.HandleFunc("/health", server.health)
+	// TODO: support embeddings
+	mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
+	})
+
+	mux.HandleFunc("POST /completion", server.completion)
+	mux.HandleFunc("GET /health", server.health)
 
 	httpServer := http.Server{
 		Handler: mux,

+ 2 - 4
server/routes.go

@@ -483,8 +483,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	}
 
 	if err := g.Wait(); err != nil {
-		slog.Error("embedding generation failed", "error", err)
-		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
 		return
 	}
 
@@ -545,8 +544,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 
 	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
-		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
-		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
 		return
 	}