package kvcache

import (
	"errors"
	"fmt"
	"log/slog"
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
)

// shiftFn adjusts the positions encoded in cached keys (for example by
// re-applying RoPE with a per-position delta) when cache entries are
// logically moved to new positions.
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)

// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens.
//
// The tensors are of shape embed dim, kv heads, batch size.
// The mask is of shape history size, batch size.
type Causal struct {
	DType      ml.DType
	Capacity   int32
	windowSize int32

	// config controls mostly backend-specific optimizations
	config *ml.CacheConfig

	// ** current forward pass **

	// the active layer for Get and Put
	curLayer int

	// starting location for data storage for this batch
	curLoc int

	// size of the current batch
	curBatchSize int

	// mask of the cache as used by this batch
	curMask ml.Tensor

	// locations in the cache that are needed for this batch
	curCellRange cellRange

	// ** cache metadata **

	// for each possible location in the cache, stores the position and set of sequences
	// that reference the data there
	cells []cacheCell

	// maps from sequence to the range of locations where it is stored in the cache
	cellRanges map[int]cellRange

	// ** cache data storage **

	shiftFn      shiftFn
	backend      ml.Backend
	cacheCtx     ml.Context
	keys, values []ml.Tensor
}
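
// As a concrete illustration of the shapes above (hypothetical numbers, not
// taken from any particular model): with a key head dimension of 128, 8 KV
// heads, and a batch of 4 tokens, each layer's Put receives a key tensor of
// shape (128, 8, 4), and the mask returned by Get has shape (history size, 4).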

type cacheCell struct {
	pos       int32
	sequences []int
}

type cellRange struct {
	min int
	max int
}

// NewCausalCache creates a cache where each token attends to all previous
// tokens in the same sequence.
func NewCausalCache(shift shiftFn) *Causal {
	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
}

// NewSWACache creates a sliding window attention cache where each token only
// attends to the previous windowSize tokens.
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
	return &Causal{windowSize: windowSize, shiftFn: shift}
}
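
// A minimal usage sketch. The shift callback body is a placeholder: how the
// re-rotation is performed (and the applyRoPEShift helper) is an assumption
// of this sketch, not an API defined by this package.
//
//	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
//		// Hypothetical helper that re-applies RoPE to key using the
//		// per-position deltas in shift.
//		return applyRoPEShift(ctx, layer, key, shift), nil
//	})
//	cache.Init(backend, ml.DTypeF16, 2048)
//	defer cache.Close()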

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	if c.config == nil {
		var config ml.CacheConfig
		if cc, ok := backend.(ml.BackendCacheConfig); ok {
			config = cc.CacheConfig()
		}
		c.config = &config
	}

	if c.config.CachePadding == 0 {
		c.config.CachePadding = 1
	}

	c.DType = dtype
	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
	c.cells = make([]cacheCell, c.Capacity)
	c.cellRanges = make(map[int]cellRange)
	c.backend = backend
	c.cacheCtx = backend.NewContext()
}

// SetConfig sets the cache configuration. It panics if a config has already
// been set, whether by the model or read from the backend during Init.
func (c *Causal) SetConfig(config ml.CacheConfig) {
	if c.config != nil {
		panic("config cannot be changed after being previously set, either by the model or backend")
	}

	c.config = &config
}

// Close frees the context that owns the cache's tensor storage.
func (c *Causal) Close() {
	c.cacheCtx.Close()
}

func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	c.curBatchSize = len(positions)

	var err error
	c.curLoc, err = c.findStartLoc()
	if errors.Is(err, ErrKvCacheFull) {
		c.defrag()
		c.curLoc, err = c.findStartLoc()
	}
	if err != nil {
		return err
	}

	c.curCellRange = newRange()
	for i, pos := range positions {
		seq := seqs[i]

		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

		seqRange, ok := c.cellRanges[seq]
		if !ok {
			seqRange = newRange()
		}

		if c.curLoc+i > seqRange.max {
			seqRange.max = c.curLoc + i
		}
		if seqRange.max > c.curCellRange.max {
			c.curCellRange.max = seqRange.max
		}

		if c.curLoc+i < seqRange.min {
			seqRange.min = c.curLoc + i
		}
		if seqRange.min < c.curCellRange.min {
			c.curCellRange.min = seqRange.min
		}

		c.cellRanges[seq] = seqRange
	}

	c.curMask, err = c.buildMask(ctx, positions, seqs)

	return err
}
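
// For example, a batch carrying the tokens at positions 5 and 6 of sequence 0
// plus the first token of sequence 1 would be started as (values are
// illustrative only):
//
//	err := cache.StartForward(ctx, []int32{5, 6, 0}, []int{0, 0, 1})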

func newRange() cellRange {
	return cellRange{
		min: math.MaxInt,
		max: 0,
	}
}

// Find the first contiguous block of at least curBatchSize empty cells
func (c *Causal) findStartLoc() (int, error) {
	var start, count int
	for i := range c.cells {
		if len(c.cells[i].sequences) == 0 {
			count++
			if count >= c.curBatchSize {
				return start, nil
			}
		} else {
			start = i + 1
			count = 0
		}
	}

	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
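
// For instance, with occupancy [used, used, free, free, free] and a batch
// size of 2, the scan above returns 2: the first index at which two
// consecutive free cells begin.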

func roundDown(length, pad int) int {
	return (length / pad) * pad
}

func roundUp(length, pad int) int {
	return ((length + pad - 1) / pad) * pad
}
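
// For example, with pad = 32: roundDown(40, 32) == 32 and roundUp(40, 32) == 64,
// while values already on a boundary are unchanged (roundUp(64, 32) == 64).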

// Builds a mask of history x batch indicating, for each token in the batch,
// whether each token in the history should be attended to. This is based on
// both the sequence and causality (the position in the history must not be
// ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
	// TODO(jessegross): This does not do mask padding, which is required for flash attention

	// Align and pad the cache range as required by the backend
	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

	length := c.curCellRange.max - c.curCellRange.min + 1
	mask := make([]float32, c.curBatchSize*length)

	for i := range c.curBatchSize {
		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
				c.cells[j].pos < positions[i]-c.windowSize {
				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
			}
		}
	}

	return ctx.FromFloatSlice(mask, length, c.curBatchSize)
}
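
// A worked example (illustrative numbers): suppose the visible cache range
// holds positions [0, 1, 2] of sequence 0 and the batch is the single token
// at position 1 of the same sequence. The row produced above is
//
//	[0, 0, -Inf]
//
// since positions 0 and 1 are causally visible and position 2 is in the
// future. With windowSize = 1, the first entry would also become -Inf for a
// query at position 2, because 0 < 2-1.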

func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
	for i := range c.keys {
		if c.keys[i] == nil {
			continue
		}

		key := c.keys[i]

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)

		value := c.values[i]
		var vSrcView, vDstView ml.Tensor
		if c.config.PermutedV {
			vHeadDim := value.Dim(1)
			elemSize := value.Stride(0)

			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
		} else {
			vHeadDim := value.Dim(0)
			rowSize := value.Stride(2)

			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
		}

		ctx.Forward(
			kSrcView.Copy(ctx, kDstView),
			vSrcView.Copy(ctx, vDstView),
		)
	}
}

func (c *Causal) defrag() {
	slog.Debug("defragmenting kv cache")

	// Defrag strategy:
	// - Search for empty holes at the beginning of the cache,
	//   filling them with active data starting at the end
	// - If there are contiguous elements that need to be moved,
	//   combine them into a single operation by holding new moves
	//   until we see that the next one is non-contiguous
	// - Fill up the context with the maximum number of operations it
	//   can hold then compute that and continue with a new context
	//
	// We could try to optimize placement by grouping blocks from
	// the same sequences together but most likely the next forward
	// pass will disrupt this anyways, so the real world benefit
	// seems limited at this time.
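	//
	// As a worked example of the coalescing step (illustrative indices):
	// with occupancy [free, free, used, used, used], cell 4 is first moved
	// into hole 0 and then cell 3 into hole 1. The second move satisfies
	// src == pendingSrc-pendingLen && dst == pendingDst+pendingLen, so the
	// two moves merge into a single moveCells(ctx, 3, 0, 2) call.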
	ctx := c.backend.NewContext()

	// For every move, 6 tensors are required per layer (2 views and a
	// copy for each of k and v).
	layers := 0
	for _, key := range c.keys {
		if key == nil {
			continue
		}
		layers++
	}

	maxMoves := ctx.MaxTensors() / (6 * layers)
	moves := 0

	var pendingSrc, pendingDst, pendingLen int
	src := len(c.cells) - 1

	for dst := 0; dst < src; dst++ {
		if len(c.cells[dst].sequences) == 0 {
			for ; src > dst; src-- {
				if len(c.cells[src].sequences) != 0 {
					c.cells[dst] = c.cells[src]
					c.cells[src] = cacheCell{}

					if pendingLen > 0 {
						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
							pendingSrc = src
							pendingLen++
							break
						} else {
							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
							moves++
						}
					}

					pendingSrc = src
					pendingDst = dst
					pendingLen = 1

					break
				}
			}
		}

		if moves >= maxMoves {
			ctx.Compute()
			ctx.Close()
			ctx = c.backend.NewContext()

			moves = 0
		}
	}

	if pendingLen > 0 {
		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
		moves++
	}

	if moves > 0 {
		ctx.Compute()
	}
	ctx.Close()

	// Reset range metadata
	for seq := range c.cellRanges {
		seqRange := newRange()

		for i, cell := range c.cells {
			if slices.Contains(cell.sequences, seq) {
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}

		c.cellRanges[seq] = seqRange
	}
}

func (c *Causal) SetLayer(layer int) {
	if layer >= len(c.keys) {
		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
	}

	c.curLayer = layer
}

func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	key := c.keys[c.curLayer]
	value := c.values[c.curLayer]

	kHeadDim := key.Dim(0)
	numKVHeads := key.Dim(1)
	rowSize := key.Stride(2)
	cachedSize := c.curMask.Dim(0)

	key = key.View(ctx, rowSize*c.curCellRange.min,
		kHeadDim, key.Stride(1),
		numKVHeads, key.Stride(2),
		cachedSize,
	)

	if c.config.PermutedV {
		vHeadDim := value.Dim(1)
		elemSize := value.Stride(0)

		value = value.View(ctx, elemSize*c.curCellRange.min,
			cachedSize, value.Stride(1),
			vHeadDim, value.Stride(2),
			numKVHeads,
		)
	} else {
		vHeadDim := value.Dim(0)
		rowSize := value.Stride(2)

		value = value.View(ctx, rowSize*c.curCellRange.min,
			vHeadDim, value.Stride(1),
			numKVHeads, value.Stride(2),
			cachedSize,
		)
	}

	return key, value, c.curMask
}

func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
	kHeadDim := key.Dim(0)
	vHeadDim := value.Dim(0)
	numKVHeads := key.Dim(1)
	batchSize := key.Dim(2)

	if c.curBatchSize != batchSize {
		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v, layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
	}

	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))

		if c.config.PermutedV {
			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
		} else {
			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
		}
	}

	rowSize := c.keys[c.curLayer].Stride(2)
	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))

	if c.config.PermutedV {
		elemSize := c.values[c.curLayer].Stride(0)

		value = value.Permute(ctx, 1, 2, 0, 3)
		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
	} else {
		rowSize := c.values[c.curLayer].Stride(2)

		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
	}
}
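
// A sketch of the per-layer flow inside a model's forward pass. The attention
// call is a stand-in for whatever the model uses, not an API defined here:
//
//	cache.SetLayer(layer)
//	cache.Put(ctx, k, v)
//	key, value, mask := cache.Get(ctx)
//	// attention(query, key, value, mask) -- hypothetical; the mask's -Inf
//	// entries zero out masked history after softmax.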

func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
	seqRange := newRange()

	for i := range c.cells {
		// Remove the contents of dstSeq so that we only have the copied prefix;
		// metadata will be reset at the end
		if slices.Contains(c.cells[i].sequences, dstSeq) {
			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
		}

		if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
			c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
			if i < seqRange.min {
				seqRange.min = i
			}
			if i > seqRange.max {
				seqRange.max = i
			}
		}
	}

	c.cellRanges[dstSeq] = seqRange
}

func (c *Causal) shift(seq int, beginIndex, offset int32) error {
	if c.shiftFn == nil {
		return ErrNotSupported
	}

	ctx := c.backend.NewContext()
	defer ctx.Close()

	seqRange := c.cellRanges[seq]
	size := seqRange.max - seqRange.min + 1

	offsets := make([]int32, size)
	for i := range offsets {
		cell := c.cells[seqRange.min+i]

		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
			offsets[i] = offset
		}
	}

	kShift, err := ctx.FromIntSlice(offsets, len(offsets))
	if err != nil {
		return err
	}

	for i, key := range c.keys {
		if key == nil {
			continue
		}

		kHeadDim := key.Dim(0)
		numKVHeads := key.Dim(1)
		rowSize := key.Stride(2)

		key = key.View(ctx, rowSize*seqRange.min,
			kHeadDim, key.Stride(1),
			numKVHeads, key.Stride(2),
			size,
		)

		roped, err := c.shiftFn(ctx, i, key, kShift)
		if err != nil {
			return err
		}

		ctx.Forward(roped.Copy(ctx, key))
	}

	ctx.Compute()

	return nil
}

// Remove drops the tokens of seq whose positions fall in [beginIndex, endIndex)
// from the cache. Tokens at or after endIndex are shifted back by
// endIndex-beginIndex so the sequence stays contiguous; passing
// endIndex == math.MaxInt32 truncates the sequence without shifting.
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
	var offset int32
	if endIndex != math.MaxInt32 {
		offset = beginIndex - endIndex
	}

	seqRange := newRange()

	for i := range c.cells {
		if slices.Contains(c.cells[i].sequences, seq) {
			if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
			} else {
				if c.cells[i].pos >= endIndex {
					if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
						// TODO(jessegross): Need to be careful about data shared between sequences
						return errors.New("shifting on cells shared by multiple sequences not yet implemented")
					}
					c.cells[i].pos += offset
				}
				if i < seqRange.min {
					seqRange.min = i
				}
				if i > seqRange.max {
					seqRange.max = i
				}
			}
		}
	}

	if seqRange == newRange() {
		delete(c.cellRanges, seq)
		return nil
	}

	c.cellRanges[seq] = seqRange

	if endIndex != math.MaxInt32 {
		err := c.shift(seq, endIndex+offset, offset)
		if err != nil {
			return err
		}
	}

	return nil
}
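
// For example (positions are illustrative), editing a sequence that holds
// positions 0..9:
//
//	// Drop positions 3, 4, and 5; tokens 6..9 become positions 3..6 and
//	// their cached keys are re-rotated via the shift callback.
//	err := cache.Remove(seq, 3, 6)
//
//	// Truncate everything from position 7 onward; no shift is needed.
//	err = cache.Remove(seq, 7, math.MaxInt32)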