Bruce MacDonald 1 month ago
parent
commit
cb10c99297
2 changed files with 15 additions and 15 deletions
  1. 3 3
      runner/ollamarunner/cache.go
  2. 12 12
      runner/ollamarunner/cache_test.go

+ 3 - 3
runner/ollamarunner/cache.go

@@ -242,7 +242,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
 }
 
 type ErrReprocessInputs struct {
-	Inputs []input
+	Inputs []input.Input
 }
 
 func (e *ErrReprocessInputs) Error() string {
@@ -279,12 +279,12 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
 			_ = c.cache.Remove(slot.Id, 0, -1)
 
 			// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
-			newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
+			newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
 			copy(newInputs[:numKeep], slot.Inputs[:numKeep])
 			copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
 
 			// Reset the slot inputs since we've cleared the cache
-			slot.Inputs = []input{}
+			slot.Inputs = []input.Input{}
 
 			// Return error with inputs that need to be reprocessed
 			return &ErrReprocessInputs{Inputs: newInputs}

+ 12 - 12
runner/ollamarunner/cache_test.go

@@ -315,20 +315,20 @@ func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
 }
 
 // Stub implementations for other interface methods
-func (m *mockCache) SetLayer(layer int)                                               {}
-func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)             { return nil, nil, nil }
-func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor)                         {}
-func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)          {}
-func (m *mockCache) Close()                                                           {}
-func (m *mockCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { return nil }
-func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32)                         {}
-func (m *mockCache) SetConfig(ml.CacheConfig)                                         {}
+func (m *mockCache) SetLayer(layer int)                                      {}
+func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)    { return nil, nil, nil }
+func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor)                {}
+func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {}
+func (m *mockCache) Close()                                                  {}
+func (m *mockCache) StartForward(ctx ml.Context, opts input.Options) error   { return nil }
+func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32)                {}
+func (m *mockCache) SetConfig(ml.CacheConfig)                                {}
 
 func TestShiftCacheSlot(t *testing.T) {
 	tests := []struct {
 		name          string
 		numCtx        int32
-		inputs        []input
+		inputs        []input.Input
 		numKeep       int32
 		cacheErr      bool
 		wantErr       any
@@ -337,7 +337,7 @@ func TestShiftCacheSlot(t *testing.T) {
 		{
 			name:          "Normal shift",
 			numCtx:        10,
-			inputs:        []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
+			inputs:        []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
 			numKeep:       2,
 			cacheErr:      false, // No error
 			wantErr:       nil,
@@ -346,7 +346,7 @@ func TestShiftCacheSlot(t *testing.T) {
 		{
 			name:          "Cache removal fails",
 			numCtx:        10,
-			inputs:        []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
+			inputs:        []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
 			numKeep:       2,
 			cacheErr:      true,
 			wantErr:       &ErrReprocessInputs{},
@@ -363,7 +363,7 @@ func TestShiftCacheSlot(t *testing.T) {
 			}
 			slot := &InputCacheSlot{
 				Id:     123,
-				Inputs: make([]input, len(tt.inputs)),
+				Inputs: make([]input.Input, len(tt.inputs)),
 			}
 			copy(slot.Inputs, tt.inputs)