// tensor.go
package cache

import (
	"github.com/ollama/ollama/ml"
)
  5. type TensorCache struct {
  6. curLayer int
  7. cacheCtx ml.Context
  8. keys, values []ml.Tensor
  9. }
  10. func NewTensorCache(backend ml.Backend) *TensorCache {
  11. return &TensorCache{
  12. // TODO(jessegross): This context is not sized appropriately
  13. cacheCtx: backend.NewContext(),
  14. }
  15. }
  16. func (c *TensorCache) Close() {
  17. c.cacheCtx.Close()
  18. }
  19. func (c *TensorCache) Sub(i int) *TensorCache {
  20. if i >= len(c.keys) {
  21. c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
  22. c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
  23. }
  24. c.curLayer = i
  25. return c
  26. }
  27. func (c *TensorCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
  28. return c.keys[c.curLayer], c.values[c.curLayer], nil
  29. }
  30. func (c *TensorCache) Put(ctx ml.Context, key, value ml.Tensor) {
  31. if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
  32. c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
  33. c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
  34. }
  35. ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
  36. ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
  37. }