Browse code

add cache clear to the ollama runner

Bruce MacDonald, 2 months ago
parent
commit
05372c724b

+ 3 - 9
runner/llamarunner/cache.go

@@ -215,12 +215,10 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
 
 type ErrReprocessInputs struct {
 	Inputs []input
-	SlotId int
 }
 
 func (e *ErrReprocessInputs) Error() string {
-	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (slot: %v, input count: %v)",
-		e.SlotId, len(e.Inputs))
+	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
 }
 
 // ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
@@ -265,16 +263,12 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 		copy(newInputs[:numKeep], slot.Inputs[:numKeep])
 		copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
 
-		// Update the slot.Inputs to be empty since we've cleared the cache
-		// The transformer will rebuild these as the inputs are processed
+		// Reset the slot inputs since we've cleared the cache
 		slot.Inputs = []input{}
 
 		// Return the inputs that need to be reprocessed
 		// The caller will need to prepend these to the sequence's inputs queue
-		return &ErrReprocessInputs{
-			Inputs: newInputs,
-			SlotId: slot.Id,
-		}
+		return &ErrReprocessInputs{Inputs: newInputs}
 	}
 
 	return nil

+ 3 - 2
runner/llamarunner/runner.go

@@ -388,9 +388,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
-						if inr, ok := err.(*ErrReprocessInputs); ok {
+						var reprocess *ErrReprocessInputs
+						if errors.As(err, &reprocess) {
 							// Prepend these inputs to the sequence's inputs queue for reprocessing
-							seq.inputs = append(inr.Inputs, seq.inputs...)
+							seq.inputs = append(reprocess.Inputs, seq.inputs...)
 							// Continue processing as normal
 						} else {
 							return err

+ 29 - 1
runner/ollamarunner/cache.go

@@ -241,6 +241,16 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
 	return discard
 }
 
+type ErrReprocessInputs struct {
+	Inputs []input
+	SlotId int
+}
+
+func (e *ErrReprocessInputs) Error() string {
+	return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (slot: %v, input count: %v)",
+		e.SlotId, len(e.Inputs))
+}
+
 // Frees up space in the KV cache by deleting the oldest half of history and shifting
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
@@ -264,7 +274,25 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
 	if c.cache != nil {
 		err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
 		if err != nil {
-			return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
+			slog.Debug("kv cache removal failed, clearing cache and returning inputs for reprocessing",
+				"id", slot.Id, "error", err)
+
+			// Clear the entire KV cache
+			_ = c.cache.Remove(slot.Id, 0, -1)
+
+			// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
+			newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
+			copy(newInputs[:numKeep], slot.Inputs[:numKeep])
+			copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
+
+			// Reset the slot inputs since we've cleared the cache
+			slot.Inputs = []input{}
+
+			// Return error with inputs that need to be reprocessed
+			return &ErrReprocessInputs{
+				Inputs: newInputs,
+				SlotId: slot.Id,
+			}
 		}
 	}
 

+ 96 - 0
runner/ollamarunner/cache_test.go

@@ -1,10 +1,13 @@
 package ollamarunner
 
 import (
+	"errors"
+	"fmt"
 	"image"
 	"testing"
 	"time"
 
+	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model/input"
 )
 
@@ -297,3 +300,96 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+// Mock implementation of the Cache interface
+type mockCache struct {
+	shouldFail bool
+}
+
+// Implement only the methods needed for the test
+func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
+	if m.shouldFail {
+		return fmt.Errorf("mock cache removal error")
+	}
+	return nil
+}
+
+// Stub implementations for other interface methods
+func (m *mockCache) SetLayer(layer int)                                               {}
+func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)             { return nil, nil, nil }
+func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor)                         {}
+func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)          {}
+func (m *mockCache) Close()                                                           {}
+func (m *mockCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { return nil }
+func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32)                         {}
+
+func TestShiftCacheSlot(t *testing.T) {
+	tests := []struct {
+		name          string
+		numCtx        int32
+		inputs        []input
+		numKeep       int32
+		cacheErr      bool
+		wantErr       any
+		wantInputsLen int
+	}{
+		{
+			name:          "Normal shift",
+			numCtx:        10,
+			inputs:        []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
+			numKeep:       2,
+			cacheErr:      false, // No error
+			wantErr:       nil,
+			wantInputsLen: 6, // After discarding 4 tokens
+		},
+		{
+			name:          "Cache removal fails",
+			numCtx:        10,
+			inputs:        []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
+			numKeep:       2,
+			cacheErr:      true,
+			wantErr:       &ErrReprocessInputs{},
+			wantInputsLen: 0, // Original inputs should be cleared
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			mock := &mockCache{shouldFail: tt.cacheErr}
+			c := InputCache{
+				numCtx: tt.numCtx,
+				cache:  mock,
+			}
+			slot := &InputCacheSlot{
+				Id:     123,
+				Inputs: make([]input, len(tt.inputs)),
+			}
+			copy(slot.Inputs, tt.inputs)
+
+			err := c.ShiftCacheSlot(slot, tt.numKeep)
+
+			if tt.wantErr != nil {
+				if err == nil {
+					t.Errorf("Expected error but got nil")
+					return
+				}
+
+				if !errors.As(err, &tt.wantErr) {
+					t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
+				}
+
+				if errReproc, ok := err.(*ErrReprocessInputs); ok {
+					if errReproc.SlotId != slot.Id {
+						t.Errorf("ErrReprocessInputs has wrong SlotId: got %v, want %v", errReproc.SlotId, slot.Id)
+					}
+				}
+			} else if err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			if len(slot.Inputs) != tt.wantInputsLen {
+				t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
+			}
+		})
+	}
+}

+ 8 - 1
runner/ollamarunner/runner.go

@@ -356,7 +356,14 @@ func (s *Server) processBatch() error {
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
-						return err
+						var reprocess *ErrReprocessInputs
+						if errors.As(err, &reprocess) {
+							// Prepend these inputs to the sequence's inputs queue for reprocessing
+							seq.inputs = append(reprocess.Inputs, seq.inputs...)
+							// Continue processing as normal
+						} else {
+							return err
+						}
 					}
 				} else {
 					break