// Source file: cache.go

package kvcache

import (
	"errors"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)
  7. var (
  8. ErrKvCacheFull = errors.New("could not find a kv cache slot")
  9. ErrNotSupported = errors.New("model does not support operation")
  10. )
  11. type Cache interface {
  12. // ** used by model implementations **
  13. // SetLayer sets the active layer of the cache
  14. SetLayer(layer int)
  15. // Get returns the history of key and value tensors plus a mask
  16. //
  17. // The shape of the tensors is documented in the specific
  18. // cache implementation used.
  19. Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
  20. // Put stores a batch of key and value in the cache
  21. //
  22. // The shape of the tensors is documented in the specific
  23. // cache implementation used.
  24. Put(ctx ml.Context, key, value ml.Tensor)
  25. // SetConfig controls optimizations (mostly backend-specific) that may transform
  26. // the output of the cache to work better with specific kernels. If not called,
  27. // the backend settings will be used. This works well when calling Attention.
  28. //
  29. // The config can be overridden by models, especially if they require vanilla
  30. // output when implementing their own version of attention. To do this, pass
  31. // an empty ml.CacheConfig.
  32. //
  33. // Most models will not need to use this.
  34. SetConfig(ml.CacheConfig)
  35. // ** cache management **
  36. // Init sets up runtime parameters
  37. Init(backend ml.Backend, dtype ml.DType, capacity int32)
  38. // Close closes the cache and frees resources associated with it
  39. Close()
  40. // StartForward is called before the start of the model's forward pass.
  41. // For each token in the coming batch, there must be a corresponding
  42. // entry in positions and seqs.
  43. StartForward(ctx ml.Context, batch input.Batch) error
  44. // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
  45. CopyPrefix(srcSeq, dstSeq int, len int32)
  46. // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
  47. // endIndex to math.MaxInt32 to remove everything starting at beginIndex.
  48. //
  49. // If an error occurs, the entire context for the sequence should be
  50. // removed by calling Remove(seq, 0, math.MaxInt32)
  51. Remove(seq int, beginIndex, endIndex int32) error
  52. }