Procházet zdrojové kódy

runner.go: Move pieces[] into sequence

pieces[] caches responses that have been generated but not yet returned,
and is currently passed around between several functions. Move it into
the Sequence struct (as pendingResponses), where it logically belongs.
Jesse Gross před 8 měsíci
rodič
revize
d022cfc9e6
1 změnil soubory, kde provedl 30 přidání a 31 odebrání
  1. 30 31
      llama/runner/runner.go

+ 30 - 31
llama/runner/runner.go

@@ -35,6 +35,10 @@ type Sequence struct {
 	// tokens left to evaluate
 	// tokens left to evaluate
 	tokens []int
 	tokens []int
 
 
+	// tokens that have been generated but not returned yet (e.g. for stop sequences)
+	// TODO (jmorganca): simplify this
+	pendingResponses []string
+
 	// channel to send responses over
 	// channel to send responses over
 	responses chan string
 	responses chan string
 
 
@@ -105,16 +109,17 @@ func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence
 	}
 	}
 
 
 	return &Sequence{
 	return &Sequence{
-		tokens:          tokens,
-		n_prompt_tokens: len(tokens),
-		numPredict:      params.numPredict,
-		responses:       make(chan string, 1),
-		quit:            make(chan bool, 1),
-		embedding:       make(chan []float32, 1),
-		samplingCtx:     sc,
-		embeddingOnly:   params.embedding,
-		stop:            params.stop,
-		numKeep:         params.numKeep,
+		tokens:           tokens,
+		n_prompt_tokens:  len(tokens),
+		numPredict:       params.numPredict,
+		pendingResponses: make([]string, 0),
+		responses:        make(chan string, 1),
+		quit:             make(chan bool, 1),
+		embedding:        make(chan []float32, 1),
+		samplingCtx:      sc,
+		embeddingOnly:    params.embedding,
+		stop:             params.stop,
+		numKeep:          params.numKeep,
 	}
 	}
 }
 }
 
 
@@ -201,34 +206,30 @@ func incompleteUnicode(token string) bool {
 	return incomplete
 	return incomplete
 }
 }
 
 
-func (s *Server) removeSequence(seqIndex int, pieces *[][]string, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason string) {
 	seq := s.seqs[seqIndex]
 	seq := s.seqs[seqIndex]
 
 
 	seq.doneReason = reason
 	seq.doneReason = reason
 	close(seq.responses)
 	close(seq.responses)
 	close(seq.embedding)
 	close(seq.embedding)
-	(*pieces)[seqIndex] = []string{}
+	seq.pendingResponses = []string{}
 	seq.samplingCtx.Free()
 	seq.samplingCtx.Free()
 	s.lc.KvCacheSeqRm(seqIndex, 0, -1)
 	s.lc.KvCacheSeqRm(seqIndex, 0, -1)
 	s.seqs[seqIndex] = nil
 	s.seqs[seqIndex] = nil
 }
 }
 
 
 func (s *Server) run(ctx context.Context) {
 func (s *Server) run(ctx context.Context) {
-	// build up stop sequences as we recognize them
-	// TODO (jmorganca): simplify this
-	pieces := make([][]string, s.parallel)
-
 	for {
 	for {
 		select {
 		select {
 		case <-ctx.Done():
 		case <-ctx.Done():
 			return
 			return
 		default:
 		default:
-			pieces = s.processBatch(pieces)
+			s.processBatch()
 		}
 		}
 	}
 	}
 }
 }
 
 
-func (s *Server) processBatch(pieces [][]string) [][]string {
+func (s *Server) processBatch() {
 	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
 	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
 	defer batch.Free()
 	defer batch.Free()
 
 
@@ -247,7 +248,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 
 
 		// if past the num predict limit
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
 		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
-			s.removeSequence(i, &pieces, "limit")
+			s.removeSequence(i, "limit")
 			continue
 			continue
 		}
 		}
 
 
@@ -274,7 +275,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 	}
 	}
 
 
 	if batch.NumTokens() == 0 {
 	if batch.NumTokens() == 0 {
-		return pieces
+		return
 	}
 	}
 
 
 	err := s.lc.Decode(batch)
 	err := s.lc.Decode(batch)
@@ -301,7 +302,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			}
 			}
 
 
 			seq.embedding <- embd
 			seq.embedding <- embd
-			s.removeSequence(i, &pieces, "")
+			s.removeSequence(i, "")
 			continue
 			continue
 		}
 		}
 
 
@@ -329,14 +330,14 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			// seq.responses <- piece
 			// seq.responses <- piece
 
 
 			// TODO: end the sequence instead of quitting the pool
 			// TODO: end the sequence instead of quitting the pool
-			s.removeSequence(i, &pieces, "stop")
+			s.removeSequence(i, "stop")
 			continue
 			continue
 		}
 		}
 
 
 		seq.tokens = []int{token}
 		seq.tokens = []int{token}
 
 
-		pieces[i] = append(pieces[i], piece)
-		sequence := strings.Join(pieces[i], "")
+		seq.pendingResponses = append(seq.pendingResponses, piece)
+		sequence := strings.Join(seq.pendingResponses, "")
 
 
 		if incompleteUnicode(sequence) {
 		if incompleteUnicode(sequence) {
 			continue
 			continue
@@ -345,7 +346,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 		if ok, stop := findStop(sequence, seq.stop); ok {
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Info("hit stop token", "stop", seq.stop)
 			slog.Info("hit stop token", "stop", seq.stop)
 
 
-			truncated := truncateStop(pieces[i], stop)
+			truncated := truncateStop(seq.pendingResponses, stop)
 
 
 			for _, p := range truncated {
 			for _, p := range truncated {
 				select {
 				select {
@@ -355,7 +356,7 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 				}
 				}
 			}
 			}
 
 
-			s.removeSequence(i, &pieces, "stop")
+			s.removeSequence(i, "stop")
 			continue
 			continue
 		}
 		}
 
 
@@ -363,19 +364,17 @@ func (s *Server) processBatch(pieces [][]string) [][]string {
 			continue
 			continue
 		}
 		}
 
 
-		for _, p := range pieces[i] {
+		for _, p := range seq.pendingResponses {
 			select {
 			select {
 			case seq.responses <- p:
 			case seq.responses <- p:
 			case <-seq.quit:
 			case <-seq.quit:
-				s.removeSequence(i, &pieces, "connection")
+				s.removeSequence(i, "connection")
 				break
 				break
 			}
 			}
 		}
 		}
 
 
-		pieces[i] = []string{}
+		seq.pendingResponses = []string{}
 	}
 	}
-
-	return pieces
 }
 }
 
 
 type Options struct {
 type Options struct {