
llama.go: Advance through tokens when processing multiple batches

If the number of input tokens exceeds the size of the batch, multiple
batches will be submitted, but they will all contain the first tokens.
This change advances through the input tokens as they are processed, so
that each batch submits the next set of tokens (see the sketch after
the diff below).
Jesse Gross 8 months ago
parent
commit 8aa97b5e83
1 changed file with 5 additions and 8 deletions

llama/runner/runner.go (+5 -8)

@@ -61,12 +61,6 @@ type Sequence struct {
 	n_prompt_tokens        int
 }
 
-// prompt returns true if the prompt is still being processed
-// TODO (jmorganca): clean up this logic
-func (s *Sequence) prompt() bool {
-	return s.nPast < len(s.tokens)-1
-}
-
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
@@ -176,14 +170,17 @@ func (s *Server) run(ctx context.Context) {
 					seq.t_start_process_prompt = time.Now()
 				}
 
+				var numTokensProcessed int
 				for j, t := range seq.tokens {
 					// todo: make this n_batch
 					if j >= s.batchSize {
 						break
 					}
-					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
+					batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
 					seq.nPast++
+					numTokensProcessed++
 				}
+				seq.tokens = seq.tokens[numTokensProcessed:]
 				seq.iBatch = batch.NumTokens() - 1
 			}
 
@@ -199,7 +196,7 @@ func (s *Server) run(ctx context.Context) {
 				}
 
 				// don't sample prompt processing
-				if seq.prompt() {
+				if len(seq.tokens) != 0 {
 					continue
 				}
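
As a rough, standalone illustration of the flow this change produces (a minimal sketch with simplified, hypothetical types, not the runner's actual Sequence or Server): each pass fills at most one batch, only the final prompt token requests logits, and the consumed tokens are sliced off so the next batch continues where the previous one stopped.

package main

import "fmt"

// sequence is a simplified stand-in for the runner's Sequence type.
type sequence struct {
	tokens []int // prompt tokens still waiting to be processed
	nPast  int   // next position in the context/KV cache
}

func main() {
	seq := &sequence{tokens: []int{10, 11, 12, 13, 14, 15, 16}}
	batchSize := 3

	for len(seq.tokens) > 0 {
		var numTokensProcessed int
		for j, t := range seq.tokens {
			if j >= batchSize {
				break
			}
			// only the last prompt token needs logits for sampling
			wantLogits := numTokensProcessed+1 == len(seq.tokens)
			fmt.Printf("add token=%d pos=%d logits=%v\n", t, seq.nPast, wantLogits)
			seq.nPast++
			numTokensProcessed++
		}
		// advance past the tokens this batch consumed so the next
		// batch starts with the following token
		seq.tokens = seq.tokens[numTokensProcessed:]
	}
}

With a batch size of 3 and 7 prompt tokens, this produces batches of 3, 3, and 1 tokens at positions 0-2, 3-5, and 6, with logits requested only for the final token, which is what the change above arranges in runner.go.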