|
@@ -315,20 +315,20 @@ func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
|
|
}
|
|
|
|
|
|
// Stub implementations for other interface methods
|
|
|
-func (m *mockCache) SetLayer(layer int) {}
|
|
|
-func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
|
|
-func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
|
|
-func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {}
|
|
|
-func (m *mockCache) Close() {}
|
|
|
-func (m *mockCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { return nil }
|
|
|
-func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
|
|
-func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
|
|
+func (m *mockCache) SetLayer(layer int) {}
|
|
|
+func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
|
|
+func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
|
|
+func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {}
|
|
|
+func (m *mockCache) Close() {}
|
|
|
+func (m *mockCache) StartForward(ctx ml.Context, opts input.Options) error { return nil }
|
|
|
+func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
|
|
+func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
|
|
|
|
|
func TestShiftCacheSlot(t *testing.T) {
|
|
|
tests := []struct {
|
|
|
name string
|
|
|
numCtx int32
|
|
|
- inputs []input
|
|
|
+ inputs []input.Input
|
|
|
numKeep int32
|
|
|
cacheErr bool
|
|
|
wantErr any
|
|
@@ -337,7 +337,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|
|
{
|
|
|
name: "Normal shift",
|
|
|
numCtx: 10,
|
|
|
- inputs: []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
|
|
|
+ inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
|
|
numKeep: 2,
|
|
|
cacheErr: false, // No error
|
|
|
wantErr: nil,
|
|
@@ -346,7 +346,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|
|
{
|
|
|
name: "Cache removal fails",
|
|
|
numCtx: 10,
|
|
|
- inputs: []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
|
|
|
+ inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
|
|
numKeep: 2,
|
|
|
cacheErr: true,
|
|
|
wantErr: &ErrReprocessInputs{},
|
|
@@ -363,7 +363,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|
|
}
|
|
|
slot := &InputCacheSlot{
|
|
|
Id: 123,
|
|
|
- Inputs: make([]input, len(tt.inputs)),
|
|
|
+ Inputs: make([]input.Input, len(tt.inputs)),
|
|
|
}
|
|
|
copy(slot.Inputs, tt.inputs)
|
|
|
|