@@ -22,6 +22,9 @@ type Causal struct {
 	Capacity   int32
 	windowSize int32
 
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -75,14 +78,34 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 }
 
 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding == 0 {
+		c.config.CachePadding = 1
+	}
+
 	c.DType = dtype
-	c.Capacity = capacity
-	c.cells = make([]cacheCell, capacity)
+	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
+	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *Causal) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *Causal) Close() {
 	c.cacheCtx.Close()
 }
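For reference, a hypothetical wiring sketch of the new config path (not part of the patch): SetConfig lets a model pin the cache behavior before Init; otherwise Init falls back to whatever the backend reports, and CachePadding defaults to 1. The constructor name NewCausalCache, ml.DTypeF16, and the concrete field values below are assumptions for illustration only.

// Hypothetical sketch, same package as Causal; names such as NewCausalCache and
// ml.DTypeF16 are assumed for illustration and may differ in the real code.
func setupCache(backend ml.Backend) *Causal {
	c := NewCausalCache(nil) // shift function omitted for brevity

	// Optional: a model may pin the cache behavior before Init. If it does not,
	// Init asks the backend (when it implements ml.BackendCacheConfig), and
	// CachePadding falls back to 1.
	c.SetConfig(ml.CacheConfig{CachePadding: 32, PermutedV: true})

	// Capacity is rounded up to the padding: 2000 -> 2016 with CachePadding = 32.
	c.Init(backend, ml.DTypeF16, 2000)

	return c
}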
@@ -157,36 +180,73 @@ func (c *Causal) findStartLoc() (int, error) {
 	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
 }
 
+func roundDown(length, pad int) int {
+	return (length / pad) * pad
+}
+
+func roundUp(length, pad int) int {
+	return ((length + pad - 1) / pad) * pad
+}
+
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
-	// TODO(jessegross): This does not do padding, which is required for flash attention
-	len := c.curCellRange.max - c.curCellRange.min + 1
-	mask := make([]float32, c.curBatchSize*len)
+	// TODO(jessegross): This does not do mask padding, which is required for flash attention
+	// Align and pad the cache range as required by the backend
+	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
+	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
+
+	length := c.curCellRange.max - c.curCellRange.min + 1
+	mask := make([]float32, c.curBatchSize*length)
 
 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
 				c.cells[j].pos < positions[i]-c.windowSize {
-				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
 	}
 
-	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+	return ctx.FromFloatSlice(mask, length, c.curBatchSize)
 }
 
-func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
-	for _, obj := range objs {
-		if obj == nil {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+	for i := range c.keys {
+		if c.keys[i] == nil {
 			continue
 		}
 
-		srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
-		dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
+		key := c.keys[i]
+
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+
+		value := c.values[i]
+		var vSrcView, vDstView ml.Tensor
+		if c.config.PermutedV {
+			vHeadDim := value.Dim(1)
+			elemSize := value.Stride(0)
+
+			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+		} else {
+			vHeadDim := value.Dim(0)
+			rowSize := value.Stride(2)
 
-		ctx.Forward(srcView.Copy(ctx, dstView))
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+		}
+
+		ctx.Forward(
+			kSrcView.Copy(ctx, kDstView),
+			vSrcView.Copy(ctx, vDstView),
+		)
 	}
 }
 
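A standalone illustration of the cell-range padding added to buildMask above; only roundDown and roundUp are taken from the patch, the padding value and range are made up.

package main

import "fmt"

func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

func main() {
	// With an assumed CachePadding of 32 and a current cell range of [37, 70]:
	pad := 32
	min, max := 37, 70

	min = roundDown(min, pad)     // 32
	max = roundUp(max+1, pad) - 1 // 95

	// The masked history length becomes a multiple of the padding (64 here),
	// which is the alignment the backend requested via CachePadding.
	fmt.Println(min, max, max-min+1)
}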
@@ -238,8 +298,7 @@ func (c *Causal) defrag() {
 							pendingLen++
 							break
 						} else {
-							moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-							moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+							c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 							moves++
 						}
 					}
@@ -263,8 +322,7 @@ func (c *Causal) defrag() {
 	}
 
 	if pendingLen > 0 {
-		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 		moves++
 	}
 
@@ -305,35 +363,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]
 
-	key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
-		key.Dim(0), key.Stride(1),
-		key.Dim(1), key.Stride(2),
-		c.curMask.Dim(0),
-	)
+	kHeadDim := key.Dim(0)
+	numKVHeads := key.Dim(1)
+	rowSize := key.Stride(2)
+	cachedSize := c.curMask.Dim(0)
 
-	value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
-		value.Dim(0), value.Stride(1),
-		value.Dim(1), value.Stride(2),
-		c.curMask.Dim(0),
+	key = key.View(ctx, rowSize*c.curCellRange.min,
+		kHeadDim, key.Stride(1),
+		numKVHeads, key.Stride(2),
+		cachedSize,
 	)
 
+	if c.config.PermutedV {
+		vHeadDim := value.Dim(1)
+		elemSize := value.Stride(0)
+
+		value = value.View(ctx, elemSize*c.curCellRange.min,
+			cachedSize, value.Stride(1),
+			vHeadDim, value.Stride(2),
+			numKVHeads,
+		)
+	} else {
+		vHeadDim := value.Dim(0)
+		rowSize := value.Stride(2)
+
+		value = value.View(ctx, rowSize*c.curCellRange.min,
+			vHeadDim, value.Stride(1),
+			numKVHeads, value.Stride(2),
+			cachedSize,
+		)
+	}
+
 	return key, value, c.curMask
 }
 
 func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-	if c.curBatchSize != key.Dim(2) {
-		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
+	kHeadDim := key.Dim(0)
+	vHeadDim := value.Dim(0)
+	numKVHeads := key.Dim(1)
+	batchSize := key.Dim(2)
+
+	if c.curBatchSize != batchSize {
+		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}
 
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
-		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+
+		if c.config.PermutedV {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+		} else {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+		}
 	}
 
-	ctx.Forward(
-		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
-		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
-	)
+	rowSize := c.keys[c.curLayer].Stride(2)
+	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+
+	if c.config.PermutedV {
+		elemSize := c.values[c.curLayer].Stride(0)
+
+		value = value.Permute(ctx, 1, 2, 0, 3)
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+	} else {
+		rowSize := c.values[c.curLayer].Stride(2)
+
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+	}
 }
 
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
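To make the two value layouts in Get and Put concrete, here is a small self-contained sketch of how an element's flat offset differs between the default layout and the PermutedV layout. The sizes are illustrative only; the offset formulas just restate the views used in the hunk above.

package main

import "fmt"

func main() {
	// Illustrative sizes only.
	const capacity, vHeadDim, numKVHeads = 8, 4, 2

	// Element (dim d, head h, cache cell t).
	d, h, t := 1, 0, 5

	// Default layout [vHeadDim, numKVHeads, Capacity]: the cell index is the
	// slowest dimension, so Put can write a whole batch as one contiguous block
	// at Stride(2)*curLoc.
	def := d + h*vHeadDim + t*vHeadDim*numKVHeads

	// PermutedV layout [Capacity, vHeadDim, numKVHeads]: the cell index is the
	// fastest dimension, so the values of one (dim, head) pair are contiguous
	// across cells; Put permutes the incoming tensor and strides by Capacity
	// between rows, matching the Stride(0)-based views above.
	perm := t + d*capacity + h*capacity*vHeadDim

	fmt.Println(def, perm) // 41 13
}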
@@ -389,9 +485,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 			continue
 		}
 
-		key = key.View(ctx, key.Stride(2)*seqRange.min,
-			key.Dim(0), key.Stride(1),
-			key.Dim(1), key.Stride(2),
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		key = key.View(ctx, rowSize*seqRange.min,
+			kHeadDim, key.Stride(1),
+			numKVHeads, key.Stride(2),
 			size,
 		)
 