Browse Source

pr feedback

Bruce MacDonald 1 month ago
parent
commit
9c23f11850
3 changed files with 26 additions and 19 deletions
  1. 24 17
      runner/llamarunner/cache.go
  2. 0 1
      runner/ollamarunner/cache.go
  3. 2 1
      runner/ollamarunner/runner.go

+ 24 - 17
runner/llamarunner/cache.go

@@ -230,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
 	}
 
-	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
+	inputLen := len(slot.Inputs)
+	discard := c.ShiftDiscard(inputLen, numKeep)
 
 	if discard <= 0 {
 		return nil
@@ -239,37 +240,43 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
+	var shiftFailed bool
+
 	if c.lc.KvCacheCanShift() {
+		// For models that support shifting, attempt to shift the KV cache
 		if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
-			return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
-		}
-		c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
-
-		for i := numKeep + discard; i < len(slot.Inputs); i++ {
-			slot.Inputs[i-discard] = slot.Inputs[i]
+			shiftFailed = true
+			slog.Debug("kv cache removal failed, clearing cache and returning inputs for reprocessing", "id", slot.Id)
+		} else {
+			c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
 		}
-		slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
 	} else {
-		slog.Debug("kv cache cannot shift, clearing cache and truncating history")
+		// For models that don't support shifting
+		shiftFailed = true
+		slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
+	}
 
+	if shiftFailed {
 		// Clear the entire KV cache
-		if !c.lc.KvCacheSeqRm(slot.Id, 0, -1) {
-			return fmt.Errorf("unable to remove kv cache entries (id: %v)", slot.Id)
-		}
+		_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
 
-		// Update the slot.Inputs to match what would happen with a shift operation
-		// Keep the first numKeep tokens and the tokens after the discard
-		newInputs := make([]input, numKeep+len(slot.Inputs)-(numKeep+discard))
+		// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
+		newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
 		copy(newInputs[:numKeep], slot.Inputs[:numKeep])
 		copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
 
 		// Reset the slot inputs since we've cleared the cache
 		slot.Inputs = []input{}
 
-		// Return the inputs that need to be reprocessed
-		// The caller will need to prepend these to the sequence's inputs queue
+		// Return error with inputs that need to be reprocessed
 		return &ErrReprocessInputs{Inputs: newInputs}
 	}
 
+	// Standard shift succeeded - update input array
+	for i := numKeep + discard; i < inputLen; i++ {
+		slot.Inputs[i-discard] = slot.Inputs[i]
+	}
+	slot.Inputs = slot.Inputs[:inputLen-discard]
+
 	return nil
 }

+ 0 - 1
runner/ollamarunner/cache.go

@@ -268,7 +268,6 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
 	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
 		"keep", numKeep, "discard", discard)
 
-	// TODO (jessegross): KV cache removal can fail for certain types of models
 	if c.cache != nil {
 		err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
 		if err != nil {

+ 2 - 1
runner/ollamarunner/runner.go

@@ -360,7 +360,8 @@ func (s *Server) processBatch() error {
 						if errors.As(err, &reprocess) {
 							// Prepend these inputs to the sequence's inputs queue for reprocessing
 							seq.inputs = append(reprocess.Inputs, seq.inputs...)
-							// Continue processing as normal
+							// Return early to restart processing with the new inputs at the beginning
+							return nil
 						} else {
 							return err
 						}