소스 검색

Get embeddings working

Truncation doesn't pass, but the other embeddings tests pass
Daniel Hiltgen 9 달 전
부모
커밋
b2f8a6120c
1개의 변경된 파일29개의 추가작업 그리고 14개의 파일을 삭제
  1. 29 14
      llama/runner/runner.go

+ 29 - 14
llama/runner/runner.go

@@ -56,7 +56,7 @@ func (s *Sequence) prompt() bool {
 }
 }
 
 
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
-	tokens, err := s.lc.Model().Tokenize(prompt, false, true)
+	tokens, err := s.lc.Model().Tokenize(prompt, embedding, true)
 	if err != nil {
 	if err != nil {
 		panic(err)
 		panic(err)
 	}
 	}
@@ -353,11 +353,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 }
 }
 
 
 type EmbeddingRequest struct {
 type EmbeddingRequest struct {
-	Prompt string `json:"prompt"`
+	Content []string `json:"content"`
 }
 }
 
 
 type EmbeddingResponse struct {
 type EmbeddingResponse struct {
-	Embedding []float32 `json:"embedding"`
+	Embedding [][]float32 `json:"embedding"`
 }
 }
 
 
 // TODO (jmorganca): is it safe to do this concurrently with decoding?
 // TODO (jmorganca): is it safe to do this concurrently with decoding?
@@ -370,22 +370,37 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 
 
-	seq := s.NewSequence(req.Prompt, 0, nil, nil, true)
+	slog.Debug("embedding request", "content", req.Content)
+	seqs := make([]*Sequence, len(req.Content))
+	embeddings := make([][]float32, len(req.Content))
+	var processed int
+	for i, content := range req.Content {
+		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+	}
 
 
-	s.mu.Lock()
-	for i, sq := range s.seqs {
-		if sq == nil {
-			s.seqs[i] = seq
-			s.cond.Signal()
-			break
+	// TODO - refactor to go routines to add seq's and drain the responses
+	// so we don't stall until each set is iterated through
+	for processed < len(seqs) {
+		s.mu.Lock()
+		for i, sq := range s.seqs {
+			if processed >= len(seqs) {
+				break
+			}
+			if sq == nil {
+				s.seqs[i] = seqs[processed]
+				processed += 1
+			}
 		}
 		}
-	}
-	s.mu.Unlock()
+		s.cond.Signal()
+		s.mu.Unlock()
 
 
-	embedding := <-seq.embedding
+		for i := range processed {
+			embeddings[i] = <-seqs[i].embedding
+		}
+	}
 
 
 	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
 	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
-		Embedding: embedding,
+		Embedding: embeddings,
 	}); err != nil {
 	}); err != nil {
 		log.Println("Failed to encode result:", err)
 		log.Println("Failed to encode result:", err)
 		return
 		return