Forráskód Böngészése

runner.go: Fix embeddings endpoint

The embeddings endpoint only takes a single input and provides a
single output, instead of multiple as the current implementation
expected. Fixing this also allows the implementation to be simplified
and a few embedding-specific issues to be addressed.
Jesse Gross 8 hónapja
szülő
commit
46a7c682f2
2 módosított fájl, 25 hozzáadás és 32 törlés
  1. 1 1
      llama/llama.go
  2. 24 31
      llama/runner/runner.go

+ 1 - 1
llama/llama.go

@@ -429,7 +429,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 }
 }
 
 
 func (s *SamplingContext) Free() {
 func (s *SamplingContext) Free() {
-	if s.c != nil {
+	if s != nil {
 		C.llama_sampling_cfree(s.c)
 		C.llama_sampling_cfree(s.c)
 	}
 	}
 }
 }

+ 24 - 31
llama/runner/runner.go

@@ -88,9 +88,15 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
 	if params.numKeep < 0 {
 	if params.numKeep < 0 {
 		params.numKeep = len(tokens)
 		params.numKeep = len(tokens)
 	}
 	}
-	// Subtracting 4 ensures that at least 1 token can be discarded during shift
-	params.numKeep = min(params.numKeep, s.numCtx-4)
-	params.numKeep += s.bosToken
+
+	if !params.embedding {
+		// Subtracting 4 ensures that at least 1 token can be discarded during shift
+		params.numKeep = min(params.numKeep, s.numCtx-4)
+		params.numKeep += s.bosToken
+	} else {
+		// Embeddings are 1 shot - just truncate to the context window, without ever shifting
+		params.numKeep = min(params.numKeep, s.numCtx)
+	}
 
 
 	// truncate to fit in context window
 	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
 	if len(tokens) > s.numCtx {
@@ -523,14 +529,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 }
 }
 
 
 type EmbeddingRequest struct {
 type EmbeddingRequest struct {
-	Content []string `json:"content"`
+	Content string `json:"content"`
 }
 }
 
 
 type EmbeddingResponse struct {
 type EmbeddingResponse struct {
-	Embedding [][]float32 `json:"embedding"`
+	Embedding []float32 `json:"embedding"`
 }
 }
 
 
-// TODO (jmorganca): is it safe to do this concurrently with decoding?
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	var req EmbeddingRequest
 	var req EmbeddingRequest
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
 	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -541,36 +546,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 
 
 	slog.Debug("embedding request", "content", req.Content)
 	slog.Debug("embedding request", "content", req.Content)
-	seqs := make([]*Sequence, len(req.Content))
-	embeddings := make([][]float32, len(req.Content))
-	var processed int
-	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
-	}
-
-	// TODO - refactor to go routines to add seq's and drain the responses
-	// so we don't stall until each set is iterated through
-	for processed < len(seqs) {
-		s.mu.Lock()
-		for i, sq := range s.seqs {
-			if processed >= len(seqs) {
-				break
-			}
-			if sq == nil {
-				s.seqs[i] = seqs[processed]
-				processed += 1
-			}
-		}
-		s.cond.Signal()
-		s.mu.Unlock()
 
 
-		for i := range processed {
-			embeddings[i] = <-seqs[i].embedding
+	seq := s.NewSequence(req.Content, NewSequenceParams{embedding: true})
+
+	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	s.mu.Lock()
+	for i, sq := range s.seqs {
+		if sq == nil {
+			s.seqs[i] = seq
+			s.cond.Signal()
+			break
 		}
 		}
 	}
 	}
+	s.mu.Unlock()
+
+	embedding := <-seq.embedding
 
 
 	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
 	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
-		Embedding: embeddings,
+		Embedding: embedding,
 	}); err != nil {
 	}); err != nil {
 		log.Println("Failed to encode result:", err)
 		log.Println("Failed to encode result:", err)
 		return
 		return