wrapper.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package kvcache
  2. import (
  3. "math"
  4. "github.com/ollama/ollama/ml"
  5. )
  6. // Wrapper cache is a container for multiple types of caches,
  7. // such as for the encoding and decoding portions of a model.
  8. type WrapperCache struct {
  9. // caches we are wrapping
  10. caches []Cache
  11. // cache to be used for this layer
  12. curType int
  13. }
  14. func NewWrapperCache(caches ...Cache) *WrapperCache {
  15. return &WrapperCache{
  16. caches: caches,
  17. }
  18. }
  19. func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
  20. for _, cache := range c.caches {
  21. cache.Init(backend, dtype, capacity)
  22. }
  23. }
  24. func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
  25. for _, cache := range c.caches {
  26. cache.SetConfig(config)
  27. }
  28. }
  29. func (c *WrapperCache) Close() {
  30. for _, cache := range c.caches {
  31. cache.Close()
  32. }
  33. }
  34. func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
  35. for i, cache := range c.caches {
  36. err := cache.StartForward(ctx, positions, seqs)
  37. if err != nil {
  38. // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
  39. for j := i - 1; j >= 0; j-- {
  40. for k := range positions {
  41. _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
  42. }
  43. }
  44. return err
  45. }
  46. }
  47. c.curType = 0
  48. return nil
  49. }
  50. func (c *WrapperCache) SetLayer(layer int) {
  51. for _, cache := range c.caches {
  52. cache.SetLayer(layer)
  53. }
  54. }
  55. func (c *WrapperCache) SetLayerType(layerType int) {
  56. c.curType = layerType
  57. }
  58. func (c *WrapperCache) UnderlyingCache() Cache {
  59. return c.caches[c.curType]
  60. }
  61. func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
  62. return c.caches[c.curType].Get(ctx)
  63. }
  64. func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
  65. c.caches[c.curType].Put(ctx, key, value)
  66. }
  67. func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
  68. for _, cache := range c.caches {
  69. cache.CopyPrefix(srcSeq, dstSeq, len)
  70. }
  71. }
  72. func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
  73. // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
  74. for _, cache := range c.caches {
  75. err := cache.Remove(seq, beginIndex, endIndex)
  76. if err != nil {
  77. return err
  78. }
  79. }
  80. return nil
  81. }