wrapper.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package kvcache
  2. import (
  3. "math"
  4. "github.com/ollama/ollama/ml"
  5. )
  6. // Wrapper cache is a container for multiple types of caches,
  7. // such as for the encoding and decoding portions of a model.
  8. type WrapperCache struct {
  9. // caches we are wrapping
  10. caches []Cache
  11. // cache to be used for this layer
  12. curType int
  13. }
  14. func NewWrapperCache(caches ...Cache) *WrapperCache {
  15. return &WrapperCache{
  16. caches: caches,
  17. }
  18. }
  19. func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
  20. for _, cache := range c.caches {
  21. cache.Init(backend, dtype, capacity)
  22. }
  23. }
  24. func (c *WrapperCache) Close() {
  25. for _, cache := range c.caches {
  26. cache.Close()
  27. }
  28. }
  29. func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
  30. for i, cache := range c.caches {
  31. err := cache.StartForward(ctx, positions, seqs)
  32. if err != nil {
  33. // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
  34. for j := i - 1; j >= 0; j-- {
  35. for k := range positions {
  36. _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
  37. }
  38. }
  39. return err
  40. }
  41. }
  42. c.curType = 0
  43. return nil
  44. }
  45. func (c *WrapperCache) SetLayer(layer int) {
  46. for _, cache := range c.caches {
  47. cache.SetLayer(layer)
  48. }
  49. }
  50. func (c *WrapperCache) SetLayerType(layerType int) {
  51. c.curType = layerType
  52. }
  53. func (c *WrapperCache) UnderlyingCache() Cache {
  54. return c.caches[c.curType]
  55. }
  56. func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
  57. return c.caches[c.curType].Get(ctx)
  58. }
  59. func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
  60. c.caches[c.curType].Put(ctx, key, value)
  61. }
  62. func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
  63. for _, cache := range c.caches {
  64. cache.CopyPrefix(srcSeq, dstSeq, len)
  65. }
  66. }
  67. func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
  68. // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
  69. for _, cache := range c.caches {
  70. err := cache.Remove(seq, beginIndex, endIndex)
  71. if err != nil {
  72. return err
  73. }
  74. }
  75. return nil
  76. }