// cache.go

package kvcache

import (
	"errors"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

var (
	ErrKvCacheFull  = errors.New("could not find a kv cache slot")
	ErrNotSupported = errors.New("model does not support operation")
)
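
// Callers are expected to test returned errors against these sentinels.
// A minimal illustration (the retry policy here is an assumption, not
// part of this package):
//
//	if err := cache.StartForward(ctx, batch); errors.Is(err, ErrKvCacheFull) {
//		// free space, e.g. by removing a finished sequence, then retry
//	}
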
type Cache interface {
	// ** used by model implementations **

	// SetLayer sets the active layer of the cache
	SetLayer(layer int)

	// Get returns the history of key and value tensors plus a mask
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

	// Put stores a batch of key and value in the cache
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Put(ctx ml.Context, key, value ml.Tensor)

	// SetConfig controls optimizations (mostly backend-specific) that may transform
	// the output of the cache to work better with specific kernels. If not called,
	// the backend settings will be used. This works well when calling Attention.
	//
	// The config can be overridden by models, especially if they require vanilla
	// output when implementing their own version of attention. To do this, pass
	// an empty ml.CacheConfig.
	//
	// Most models will not need to use this.
	SetConfig(ml.CacheConfig)

	// ** cache management **

	// Init sets up runtime parameters.
	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
	// dtype: The data type for storing cache entries
	// maxSequences: The maximum number of sequences stored in the cache - across all batches
	// capacity: The number of cache entries to store, per sequence
	// maxBatch: The maximum number of tokens that can occur in a single batch
	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)

	// Close closes the cache and frees resources associated with it
	Close()

	// StartForward is called before the start of the model's forward pass.
	// For each token in the coming batch, there must be a corresponding
	// entry in positions and seqs.
	StartForward(ctx ml.Context, batch input.Batch) error

	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
	CopyPrefix(srcSeq, dstSeq int, len int32)

	// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
	// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
	//
	// If an error occurs, the entire context for the sequence should be
	// removed by calling Remove(seq, 0, math.MaxInt32)
	Remove(seq int, beginIndex, endIndex int32) error
}
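
// The sketch below is an illustration (not part of this package) of how a
// model implementation might drive the Cache interface during one forward
// pass: StartForward reserves slots for the whole batch, then each layer
// selects itself with SetLayer before storing and reading tensors. The
// name runForward and the layer loop are assumptions for the example.
func runForward(ctx ml.Context, cache Cache, batch input.Batch, numLayers int) error {
	// Reserve cache entries for every token in the batch up front.
	if err := cache.StartForward(ctx, batch); err != nil {
		return err
	}
	for layer := 0; layer < numLayers; layer++ {
		// Direct the following Put/Get calls at this layer's storage.
		cache.SetLayer(layer)
		// A real model would compute key/value tensors here, then:
		//   cache.Put(ctx, key, value)
		//   keys, values, mask := cache.Get(ctx)
		// and attend over keys/values using the mask.
	}
	return nil
}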
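
// Sequence management follows the same thin pattern. These helpers are a
// hypothetical sketch: forkSequence shares an already-processed prompt
// prefix with a new sequence via CopyPrefix, and truncateSequence shows
// the error fallback that the Remove documentation calls for. maxInt32
// stands in for math.MaxInt32 to keep the sketch import-free.
const maxInt32 = int32(1<<31 - 1)

// forkSequence seeds dstSeq with the first prefixLen tokens of srcSeq.
func forkSequence(cache Cache, srcSeq, dstSeq int, prefixLen int32) {
	cache.CopyPrefix(srcSeq, dstSeq, prefixLen)
}

// truncateSequence drops tokens [keep, end) from seq; if that fails, it
// clears the whole sequence, as the Remove contract requires.
func truncateSequence(cache Cache, seq int, keep int32) {
	if err := cache.Remove(seq, keep, maxInt32); err != nil {
		_ = cache.Remove(seq, 0, maxInt32)
	}
}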