@@ -45,6 +45,9 @@ type Sequence struct {
 	// prompt inputs left to evaluate
 	inputs []input
 
+	// inputs that have been added to a batch but not yet submitted to Decode
+	pendingInputs []input
+
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 
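
The new pendingInputs field is a staging area: an input lives there from the moment it is placed into a batch until Decode confirms the batch, at which point it is promoted into seq.cache.Inputs (see the final hunk below). A minimal sketch of the quantity the rest of the patch reasons about, using the patch's types; the helper itself is hypothetical:

    // effectiveLen reports how many KV-cache positions a sequence will occupy
    // once the in-flight batch lands: committed entries plus staged ones.
    // Hypothetical helper for illustration; not part of the patch.
    func effectiveLen(seq *Sequence) int {
        return len(seq.cache.Inputs) + len(seq.pendingInputs)
    }
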
@@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		var numInputsProcessed int
-		shifted := false
-
 		for i, input := range seq.inputs {
-			if len(seq.cache.Inputs)+1 > s.cache.numCtx {
-				if !shifted {
-					s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					shifted = true
+			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+				if len(seq.pendingInputs) == 0 {
+					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+					if err != nil {
+						return err
+					}
 				} else {
 					break
 				}
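
Two behavioral changes land in this hunk. ShiftCacheSlot's error is now propagated instead of dropped, and the one-shot shifted flag is replaced by a check on pendingInputs: shifting the cache while a batch is partially filled would invalidate the positions already assigned to in-flight inputs, so the loop stops and lets the current batch drain instead. A self-contained sketch of that rule (the helper is hypothetical; the logic mirrors the condition above):

    // canGrow sketches the overflow rule: shift only when nothing is in
    // flight, otherwise stop filling the batch and let Decode drain it.
    func canGrow(cached, pending, numCtx int) (shift, stop bool) {
        switch {
        case cached+pending+1 <= numCtx:
            return false, false // the next input still fits
        case pending == 0:
            return true, false // safe to shift: no in-flight positions to invalidate
        default:
            return false, true // defer until pendingInputs is committed
        }
    }
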
@@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
-			seq.cache.Inputs = append(seq.cache.Inputs, input)
-			numInputsProcessed++
-		}
-
-		if numInputsProcessed > 0 {
-			seq.inputs = seq.inputs[numInputsProcessed:]
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
+			seq.pendingInputs = append(seq.pendingInputs, input)
 			seq.iBatch = batch.NumTokens() - 1
 		}
+
+		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
 	if batch == nil || batch.NumTokens() == 0 {
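
The position passed to batch.Add now counts both committed and staged inputs, so two entries in the same batch can no longer claim the same KV-cache position. That also retires the numInputsProcessed counter: pendingInputs records exactly what was consumed, so trimming seq.inputs after the loop reduces to a single slice expression. Roughly, on plain slices (hypothetical helper, not part of the patch):

    // stage mirrors the loop above: move up to n inputs from remaining into
    // pending, then trim remaining by however many were actually staged.
    func stage(remaining, pending []input, n int) ([]input, []input) {
        take := min(n, len(remaining)) // min is a builtin as of Go 1.21
        pending = append(pending, remaining[:take]...)
        return remaining[take:], pending
    }
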
@@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
+		// After calling Decode, pending inputs are now in the cache
+		if len(seq.pendingInputs) > 0 {
+			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
+			seq.pendingInputs = []input{}
+		}
+
 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
 			continue
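
This last hunk is what makes the bookkeeping transactional: pendingInputs is folded into seq.cache.Inputs only after Decode succeeds, so a failed or abandoned batch never leaves the cache view claiming tokens that were not actually decoded. As a standalone outline (illustrative only; types follow the patch):

    // commitPending promotes staged inputs to committed cache contents,
    // matching the block added above.
    func commitPending(seq *Sequence) {
        if len(seq.pendingInputs) > 0 {
            seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
            seq.pendingInputs = []input{}
        }
    }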