cache.go 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. package cache
  2. import (
  3. "github.com/ollama/ollama/ml"
  4. )
  // Options carries the per-call settings passed to Cache.Put.
  type Options struct {
  	// Position is the offset, counted along dimension 2 of the incoming
  	// tensors, at which Put starts writing. Simple.Put multiplies it by the
  	// tensor's Stride(2) to obtain the write offset into its buffers.
  	Position int
  }
  8. type Cache interface {
  9. Sub(i int) Cache
  10. Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
  11. }
  12. type Simple struct {
  13. DType ml.DType
  14. Capacity int
  15. keys, values []ml.Tensor
  16. }
  17. func (c *Simple) Sub(i int) Cache {
  18. if i >= len(c.keys) {
  19. c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
  20. c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
  21. }
  22. return &Simple{
  23. keys: c.keys[i : i+1],
  24. values: c.values[i : i+1],
  25. Capacity: c.Capacity,
  26. DType: c.DType,
  27. }
  28. }
  29. func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
  30. if c.keys[0] == nil || c.values[0] == nil {
  31. c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
  32. c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
  33. }
  34. ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
  35. ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
  36. n := min(c.Capacity, int(key.Dim(2))+opts.Position)
  37. key = c.keys[0].View(ctx, 0,
  38. int(key.Dim(0)), int(key.Stride(1)),
  39. int(key.Dim(1)), int(key.Stride(2)),
  40. n,
  41. )
  42. value = c.values[0].View(ctx, 0,
  43. int(value.Dim(0)), int(value.Stride(1)),
  44. int(value.Dim(1)), int(value.Stride(2)),
  45. n,
  46. )
  47. // TODO shift context if necessary
  48. return key, value
  49. }