@@ -61,12 +61,6 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-// prompt returns true if the prompt is still being processed
-// TODO (jmorganca): clean up this logic
-func (s *Sequence) prompt() bool {
-	return s.nPast < len(s.tokens)-1
-}
-
 func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
@@ -176,14 +170,17 @@ func (s *Server) run(ctx context.Context) {
 			seq.t_start_process_prompt = time.Now()
 		}
 
+		var numTokensProcessed int
 		for j, t := range seq.tokens {
 			// todo: make this n_batch
 			if j >= s.batchSize {
 				break
 			}
-			batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
+			batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
 			seq.nPast++
+			numTokensProcessed++
 		}
+		seq.tokens = seq.tokens[numTokensProcessed:]
 		seq.iBatch = batch.NumTokens() - 1
 	}
 
@@ -199,7 +196,7 @@ func (s *Server) run(ctx context.Context) {
 		}
 
 		// don't sample prompt processing
-		if seq.prompt() {
+		if len(seq.tokens) != 0 {
 			continue
 		}
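
For clarity, below is a minimal, self-contained Go sketch of the bookkeeping this diff introduces. The toySeq and toyBatch types, the batchSize constant, and the hard-coded token values are illustrative stand-ins, not the runner's real Sequence, llama batch, or server fields; the sketch only shows that logits are requested for the final remaining token, that processed tokens are sliced off seq.tokens, and that an empty slice is what the `len(seq.tokens) != 0` sampling guard checks.

package main

import "fmt"

// toySeq and toyBatch are illustrative stand-ins for the runner's Sequence
// and llama batch types; they model only the fields this diff touches.
type toySeq struct {
	tokens []int // remaining prompt tokens (or the single sampled token)
	nPast  int   // tokens already pushed into the cache
}

type toyBatch struct {
	logits []bool // whether logits were requested for each added token
}

// Add mirrors the shape of batch.Add(token, pos, seqIds, logits) but only
// records the logits flag, which is what the diff changes.
func (b *toyBatch) Add(token, pos int, seqIds []int, wantLogits bool) {
	b.logits = append(b.logits, wantLogits)
}

func main() {
	const batchSize = 4 // stand-in for s.batchSize
	seq := &toySeq{tokens: []int{101, 102, 103}}
	batch := &toyBatch{}

	var numTokensProcessed int
	for j, t := range seq.tokens {
		if j >= batchSize {
			break
		}
		// Request logits only for the last remaining token, i.e. the end
		// of the prompt (or the one freshly generated token).
		batch.Add(t, seq.nPast, []int{0}, numTokensProcessed+1 == len(seq.tokens))
		seq.nPast++
		numTokensProcessed++
	}
	// Drop the tokens that made it into this batch; once the slice is
	// empty, prompt processing is done and the sequence may be sampled.
	seq.tokens = seq.tokens[numTokensProcessed:]

	fmt.Println("logits flags:", batch.logits)            // [false false true]
	fmt.Println("ready to sample:", len(seq.tokens) == 0) // true
}

Note that when the prompt is longer than the batch size, the loop breaks before reaching the last token, so logits are not requested in that pass and the sequence keeps being skipped by the `len(seq.tokens) != 0` check until the remaining tokens are consumed on later passes.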