tensor.go 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. package cache
  2. import (
  3. "github.com/ollama/ollama/ml"
  4. )
  5. type TensorCache struct {
  6. curLayer int
  7. cacheCtx ml.Context
  8. keys, values []ml.Tensor
  9. }
  10. func NewTensorCache(backend ml.Backend) *TensorCache {
  11. return &TensorCache{
  12. cacheCtx: backend.NewContext(),
  13. }
  14. }
  15. func (c *TensorCache) Close() {
  16. c.cacheCtx.Close()
  17. }
  18. func (c *TensorCache) Sub(i int) *TensorCache {
  19. if i >= len(c.keys) {
  20. c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
  21. c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
  22. }
  23. c.curLayer = i
  24. return c
  25. }
  26. func (c *TensorCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
  27. return c.keys[c.curLayer], c.values[c.curLayer], nil
  28. }
  29. func (c *TensorCache) Put(ctx ml.Context, key, value ml.Tensor) {
  30. if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
  31. c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
  32. c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
  33. }
  34. ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
  35. ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
  36. }