Преглед изворни кода

ollamarunner: Check for minBatch of context space when shifting

Models can specify that a group of inputs need to be handled a single
batch. However, context shifting didn't respect this and could trigger
a break anyways. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set less than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.
Jesse Gross пре 1 месец
родитељ
комит
bf24498b1e
1 измењених фајлова са 17 додато и 11 уклоњено
  1. 17 11
      runner/ollamarunner/runner.go

+ 17 - 11
runner/ollamarunner/runner.go

@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		params.numKeep = int32(len(inputs))
 	}
 
+	// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
+	// otherwise we might truncate or split the batch against the model's wishes
+
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
@@ -366,17 +369,6 @@ func (s *Server) processBatch() error {
 		batchSize := s.batchSize
 
 		for j, inp := range seq.inputs {
-			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
-				if len(seq.pendingInputs) == 0 {
-					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
-					if err != nil {
-						return err
-					}
-				} else {
-					break
-				}
-			}
-
 			// If we are required to put following inputs into a single batch then extend the
 			// batch size. Since we are only extending the size the minimum amount possible, this
 			// will cause a break if we have pending inputs.
@@ -389,6 +381,20 @@ func (s *Server) processBatch() error {
 				break
 			}
 
+			// If the sum of our working set (already processed tokens, tokens we added to this
+			// batch, required following tokens) exceeds the context size, then trigger a shift
+			// now so we don't have to do one later when we can't break the batch.
+			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
+				if len(seq.pendingInputs) != 0 {
+					break
+				}
+
+				err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+				if err != nil {
+					return err
+				}
+			}
+
 			options.Inputs = append(options.Inputs, inp.Token)
 			if inp.Multimodal != nil {
 				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})