cache.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. package kvcache
  2. import (
  3. "errors"
  4. "github.com/ollama/ollama/ml"
  5. )
  6. var (
  7. ErrKvCacheFull = errors.New("could not find a kv cache slot")
  8. ErrNotSupported = errors.New("model does not support operation")
  9. )
  10. type Cache interface {
  11. // ** used by model implementations **
  12. // SetLayer sets the active layer of the cache
  13. SetLayer(layer int)
  14. // Get returns the history of key and value tensors plus a mask
  15. //
  16. // The shape of the tensors is documented in the specific
  17. // cache implementation used.
  18. Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
  19. // Put stores a batch of key and value in the cache
  20. //
  21. // The shape of the tensors is documented in the specific
  22. // cache implementation used.
  23. Put(ctx ml.Context, key, value ml.Tensor)
  24. // SetConfig controls optimizations (mostly backend-specific) that may transform
  25. // the output of the cache to work better with specific kernels. If not called,
  26. // the backend settings will be used. This works well when calling Attention.
  27. //
  28. // The config can be overridden by models, especially if they require vanilla
  29. // output when implementing their own version of attention. To do this, pass
  30. // an empty ml.CacheConfig.
  31. //
  32. // Most models will not need to use this.
  33. SetConfig(ml.CacheConfig)
  34. // ** cache management **
  35. // Init sets up runtime parameters
  36. Init(backend ml.Backend, dtype ml.DType, capacity int32)
  37. // Close closes the cache and frees resources associated with it
  38. Close()
  39. // StartForward is called before the start of the model's forward pass.
  40. // For each token in the coming batch, there must be a corresponding
  41. // entry in positions and seqs.
  42. StartForward(ctx ml.Context, positions []int32, seqs []int) error
  43. // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
  44. CopyPrefix(srcSeq, dstSeq int, len int32)
  45. // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
  46. // endIndex to math.MaxInt32 to remove everything starting at beginIndex.
  47. //
  48. // If an error occurs, the entire context for the sequence should be
  49. // removed by calling Remove(seq, 0, math.MaxInt32)
  50. Remove(seq int, beginIndex, endIndex int32) error
  51. }