123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- package cache
- import (
- "github.com/ollama/ollama/ml"
- )
// Options carries the per-call settings passed to Cache.Put.
type Options struct {
	// Position selects where the incoming key/value tensors are placed in
	// the cache. Simple.Put multiplies it by the tensor's Stride(2) to pick
	// the write location and adds it to Dim(2) to size the returned views.
	Position int
}
- type Cache interface {
- Sub(i int) Cache
- Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
- }
- type Simple struct {
- DType ml.DType
- Capacity int
- keys, values []ml.Tensor
- }
- func (c *Simple) Sub(i int) Cache {
- if i >= len(c.keys) {
- c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
- c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
- }
- return &Simple{
- keys: c.keys[i : i+1],
- values: c.values[i : i+1],
- Capacity: c.Capacity,
- DType: c.DType,
- }
- }
- func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
- if c.keys[0] == nil || c.values[0] == nil {
- c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
- c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
- }
- ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
- ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
- n := min(c.Capacity, int(key.Dim(2))+opts.Position)
- key = c.keys[0].View(ctx, 0,
- int(key.Dim(0)), int(key.Stride(1)),
- int(key.Dim(1)), int(key.Stride(2)),
- n,
- )
- value = c.values[0].View(ctx, 0,
- int(value.Dim(0)), int(value.Stride(1)),
- int(value.Dim(1)), int(value.Stride(2)),
- n,
- )
- // TODO shift context if necessary
- return key, value
- }
|