wrapper.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. package kvcache
  2. import (
  3. "math"
  4. "github.com/ollama/ollama/ml"
  5. "github.com/ollama/ollama/model/input"
  6. )
  7. // Wrapper cache is a container for multiple types of caches,
  8. // such as for the encoding and decoding portions of a model.
  9. type WrapperCache struct {
  10. // caches we are wrapping
  11. caches []Cache
  12. // cache to be used for this layer
  13. curType int
  14. }
  15. func NewWrapperCache(caches ...Cache) *WrapperCache {
  16. return &WrapperCache{
  17. caches: caches,
  18. }
  19. }
  20. func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
  21. for _, cache := range c.caches {
  22. cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
  23. }
  24. }
  25. func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
  26. for _, cache := range c.caches {
  27. cache.SetConfig(config)
  28. }
  29. }
  30. func (c *WrapperCache) Close() {
  31. for _, cache := range c.caches {
  32. cache.Close()
  33. }
  34. }
  35. func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
  36. for i, cache := range c.caches {
  37. err := cache.StartForward(ctx, batch)
  38. if err != nil {
  39. // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
  40. for j := i - 1; j >= 0; j-- {
  41. for k := range batch.Positions {
  42. _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
  43. }
  44. }
  45. return err
  46. }
  47. }
  48. c.curType = 0
  49. return nil
  50. }
  51. func (c *WrapperCache) SetLayer(layer int) {
  52. for _, cache := range c.caches {
  53. cache.SetLayer(layer)
  54. }
  55. }
  56. func (c *WrapperCache) SetLayerType(layerType int) {
  57. c.curType = layerType
  58. }
  59. func (c *WrapperCache) UnderlyingCache() Cache {
  60. return c.caches[c.curType]
  61. }
  62. func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
  63. return c.caches[c.curType].Get(ctx)
  64. }
  65. func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
  66. c.caches[c.curType].Put(ctx, key, value)
  67. }
  68. func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
  69. for _, cache := range c.caches {
  70. cache.CopyPrefix(srcSeq, dstSeq, len)
  71. }
  72. }
  73. func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
  74. // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
  75. for _, cache := range c.caches {
  76. err := cache.Remove(seq, beginIndex, endIndex)
  77. if err != nil {
  78. return err
  79. }
  80. }
  81. return nil
  82. }