|
@@ -1,10 +1,13 @@
|
|
package ollamarunner
|
|
package ollamarunner
|
|
|
|
|
|
import (
|
|
import (
|
|
|
|
+ "errors"
|
|
|
|
+ "fmt"
|
|
"image"
|
|
"image"
|
|
"testing"
|
|
"testing"
|
|
"time"
|
|
"time"
|
|
|
|
|
|
|
|
+ "github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/model/input"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
)
|
|
|
|
|
|
@@ -297,3 +300,96 @@ func TestShiftDiscard(t *testing.T) {
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+// Mock implementation of the Cache interface
|
|
|
|
+type mockCache struct {
|
|
|
|
+ shouldFail bool
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// Implement only the methods needed for the test
|
|
|
|
+func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
|
|
|
+ if m.shouldFail {
|
|
|
|
+ return fmt.Errorf("mock cache removal error")
|
|
|
|
+ }
|
|
|
|
+ return nil
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// Stub implementations for other interface methods
|
|
|
|
+func (m *mockCache) SetLayer(layer int) {}
|
|
|
|
+func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
|
|
|
+func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
|
|
|
+func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {}
|
|
|
|
+func (m *mockCache) Close() {}
|
|
|
|
+func (m *mockCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { return nil }
|
|
|
|
+func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
|
|
|
+
|
|
|
|
+func TestShiftCacheSlot(t *testing.T) {
|
|
|
|
+ tests := []struct {
|
|
|
|
+ name string
|
|
|
|
+ numCtx int32
|
|
|
|
+ inputs []input
|
|
|
|
+ numKeep int32
|
|
|
|
+ cacheErr bool
|
|
|
|
+ wantErr any
|
|
|
|
+ wantInputsLen int
|
|
|
|
+ }{
|
|
|
|
+ {
|
|
|
|
+ name: "Normal shift",
|
|
|
|
+ numCtx: 10,
|
|
|
|
+ inputs: []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
|
|
|
|
+ numKeep: 2,
|
|
|
|
+ cacheErr: false, // No error
|
|
|
|
+ wantErr: nil,
|
|
|
|
+ wantInputsLen: 6, // After discarding 4 tokens
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ name: "Cache removal fails",
|
|
|
|
+ numCtx: 10,
|
|
|
|
+ inputs: []input{{token: 1}, {token: 2}, {token: 3}, {token: 4}, {token: 5}, {token: 6}, {token: 7}, {token: 8}, {token: 9}, {token: 10}},
|
|
|
|
+ numKeep: 2,
|
|
|
|
+ cacheErr: true,
|
|
|
|
+ wantErr: &ErrReprocessInputs{},
|
|
|
|
+ wantInputsLen: 0, // Original inputs should be cleared
|
|
|
|
+ },
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for _, tt := range tests {
|
|
|
|
+ t.Run(tt.name, func(t *testing.T) {
|
|
|
|
+ mock := &mockCache{shouldFail: tt.cacheErr}
|
|
|
|
+ c := InputCache{
|
|
|
|
+ numCtx: tt.numCtx,
|
|
|
|
+ cache: mock,
|
|
|
|
+ }
|
|
|
|
+ slot := &InputCacheSlot{
|
|
|
|
+ Id: 123,
|
|
|
|
+ Inputs: make([]input, len(tt.inputs)),
|
|
|
|
+ }
|
|
|
|
+ copy(slot.Inputs, tt.inputs)
|
|
|
|
+
|
|
|
|
+ err := c.ShiftCacheSlot(slot, tt.numKeep)
|
|
|
|
+
|
|
|
|
+ if tt.wantErr != nil {
|
|
|
|
+ if err == nil {
|
|
|
|
+ t.Errorf("Expected error but got nil")
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if !errors.As(err, &tt.wantErr) {
|
|
|
|
+ t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if errReproc, ok := err.(*ErrReprocessInputs); ok {
|
|
|
|
+ if errReproc.SlotId != slot.Id {
|
|
|
|
+ t.Errorf("ErrReprocessInputs has wrong SlotId: got %v, want %v", errReproc.SlotId, slot.Id)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ } else if err != nil {
|
|
|
|
+ t.Errorf("Unexpected error: %v", err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if len(slot.Inputs) != tt.wantInputsLen {
|
|
|
|
+ t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+ }
|
|
|
|
+}
|