@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		params.numKeep = int32(len(inputs))
 	}
 
+	// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
+	// otherwise we might truncate or split the batch against the model's wishes
+
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
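To make the clamp concrete: below is a minimal standalone sketch, not part of the patch, using invented values for numCtx and numKeep. It only illustrates why capping numKeep at numCtx-1 guarantees that ShiftCacheSlot always has at least one token it is allowed to discard.

package main

import "fmt"

func main() {
	// Hypothetical values, not taken from the server: a prompt that
	// fills the entire context window.
	numCtx := int32(8)
	numKeep := int32(8)

	// The clamp from the hunk above: without it, numKeep == numCtx and
	// a shift would have zero tokens it could discard.
	numKeep = min(numKeep, numCtx-1)

	fmt.Println(numCtx - numKeep) // 1: at least one input can be discarded
}
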
@@ -366,17 +369,6 @@ func (s *Server) processBatch() error {
 		batchSize := s.batchSize
 
 		for j, inp := range seq.inputs {
-			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
-				if len(seq.pendingInputs) == 0 {
-					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					if err != nil {
-						return err
-					}
-				} else {
-					break
-				}
-			}
-
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
 			// will cause a break if we have pending inputs.
@@ -389,6 +381,20 @@ func (s *Server) processBatch() error {
 				break
 			}
 
+			// If the sum of our working set (already processed tokens, tokens we added to this
+			// batch, required following tokens) exceeds the context size, then trigger a shift
+			// now so we don't have to do one later when we can't break the batch.
+			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
+				if len(seq.pendingInputs) != 0 {
+					break
+				}
+
+				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+				if err != nil {
+					return err
+				}
+			}
+
 			options.Inputs = append(options.Inputs, inp.Token)
 			if inp.Multimodal != nil {
 				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
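
The moved check is easier to see with numbers. Below is a minimal standalone sketch, assuming invented values for the context size, the cached and pending token counts, and minBatch; it mirrors the shift-or-break decision the new code makes per input, and shows why testing against minBatch (rather than the old +1) shifts the cache before an unsplittable run of inputs begins instead of in the middle of it.

package main

import "fmt"

// shiftOrBreak mirrors the condition from the final hunk:
// len(seq.cache.Inputs) + len(seq.pendingInputs) + minBatch > s.cache.numCtx.
// All values here are hypothetical stand-ins for the server's state.
func shiftOrBreak(cached, pending, minBatch, numCtx int32) string {
	if cached+pending+minBatch > numCtx {
		if pending != 0 {
			// Tokens are already queued in this batch; we cannot shift
			// the cache underneath them, so end the batch and retry.
			return "break"
		}
		// Nothing is pending yet, so it is safe to shift now, before the
		// inputs that must share a batch are added.
		return "shift"
	}
	return "continue"
}

func main() {
	// A 16-token context, 10 tokens cached, and an 8-token run that must
	// land in a single batch: shift immediately. The old +1 check would
	// not have fired (10+0+1 <= 16) until part of the run was queued.
	fmt.Println(shiftOrBreak(10, 0, 8, 16)) // shift

	// Same state but with tokens already pending: finish the batch first.
	fmt.Println(shiftOrBreak(10, 3, 8, 16)) // break
}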