// cache.go
package kvcache

import (
	"errors"

	"github.com/ollama/ollama/ml"
)
  6. var (
  7. ErrKvCacheFull = errors.New("could not find a kv cache slot")
  8. ErrNotSupported = errors.New("model does not support operation")
  9. )
  10. type Cache interface {
  11. // ** used by model implementations **
  12. // SetLayer sets the active layer of the cache
  13. SetLayer(layer int)
  14. // Get returns the history of key and value tensors plus a mask
  15. //
  16. // The shape of the tensors is documented in the specific
  17. // cache implementation used.
  18. Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
  19. // Put stores a batch of key and value in the cache
  20. //
  21. // The shape of the tensors is documented in the specific
  22. // cache implementation used.
  23. Put(ctx ml.Context, key, value ml.Tensor)
  24. // ** cache management **
  25. // Init sets up runtime parameters
  26. Init(backend ml.Backend, dtype ml.DType, capacity int32)
  27. // Close closes the cache and frees resources associated with it
  28. Close()
  29. // StartForward is called before the start of the model's forward pass.
  30. // For each token in the coming batch, there must be a corresponding
  31. // entry in positions and seqs.
  32. StartForward(ctx ml.Context, positions []int32, seqs []int) error
  33. // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
  34. CopyPrefix(srcSeq, dstSeq int, len int32)
  35. // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
  36. // endIndex to math.MaxInt32 to remove everything starting at beginIndex.
  37. //
  38. // If an error occurs, the entire context for the sequence should be
  39. // removed by calling Remove(seq, 0, math.MaxInt32)
  40. Remove(seq int, beginIndex, endIndex int32) error
  41. }