|
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|
// The mask is of shape history size, batch size
|
|
// The mask is of shape history size, batch size
|
|
type Causal struct {
|
|
type Causal struct {
|
|
DType ml.DType
|
|
DType ml.DType
|
|
- Capacity int32
|
|
|
|
windowSize int32
|
|
windowSize int32
|
|
|
|
|
|
opts CausalOptions
|
|
opts CausalOptions
|
|
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
|
|
|
|
|
+func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
|
if c.config == nil {
|
|
if c.config == nil {
|
|
var config ml.CacheConfig
|
|
var config ml.CacheConfig
|
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
|
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
|
@@ -119,9 +118,11 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
|
c.config.MaskDType = ml.DTypeF32
|
|
c.config.MaskDType = ml.DTypeF32
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ cacheSize := maxSequences * capacity
|
|
|
|
+ cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
|
|
|
+ c.cells = make([]cacheCell, cacheSize)
|
|
|
|
+
|
|
c.DType = dtype
|
|
c.DType = dtype
|
|
- c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
|
|
|
- c.cells = make([]cacheCell, c.Capacity)
|
|
|
|
c.cellRanges = make(map[int]cellRange)
|
|
c.cellRanges = make(map[int]cellRange)
|
|
c.backend = backend
|
|
c.backend = backend
|
|
}
|
|
}
|
|
@@ -210,7 +211,7 @@ func (c *Causal) findStartLoc() (int, error) {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
|
|
|
|
|
+ return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
|
}
|
|
}
|
|
|
|
|
|
func roundDown(length, pad int) int {
|
|
func roundDown(length, pad int) int {
|
|
@@ -265,7 +266,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|
return maskTensor, nil
|
|
return maskTensor, nil
|
|
}
|
|
}
|
|
|
|
|
|
-func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|
|
|
|
|
+func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
|
for i, key := range c.keys {
|
|
for i, key := range c.keys {
|
|
if key == nil {
|
|
if key == nil {
|
|
continue
|
|
continue
|
|
@@ -275,8 +276,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|
numKVHeads := key.Dim(1)
|
|
numKVHeads := key.Dim(1)
|
|
rowSize := key.Stride(2)
|
|
rowSize := key.Stride(2)
|
|
|
|
|
|
- kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
|
|
|
- kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
|
|
|
|
|
+ kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
|
|
|
|
+ kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
|
|
|
|
|
|
value := c.values[i]
|
|
value := c.values[i]
|
|
var vSrcView, vDstView ml.Tensor
|
|
var vSrcView, vDstView ml.Tensor
|
|
@@ -284,14 +285,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
|
vHeadDim := value.Dim(1)
|
|
vHeadDim := value.Dim(1)
|
|
elemSize := value.Stride(0)
|
|
elemSize := value.Stride(0)
|
|
|
|
|
|
- vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
|
|
|
- vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
|
|
|
|
|
+ vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
|
|
|
+ vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
|
|
} else {
|
|
} else {
|
|
vHeadDim := value.Dim(0)
|
|
vHeadDim := value.Dim(0)
|
|
rowSize := value.Stride(2)
|
|
rowSize := value.Stride(2)
|
|
|
|
|
|
- vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
|
|
|
- vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
|
|
|
|
|
+ vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
|
|
|
|
+ vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
|
|
}
|
|
}
|
|
|
|
|
|
ctx.Forward(
|
|
ctx.Forward(
|
|
@@ -480,14 +481,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|
}
|
|
}
|
|
|
|
|
|
if _, ok := c.keys[c.curLayer]; !ok {
|
|
if _, ok := c.keys[c.curLayer]; !ok {
|
|
- c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
|
|
|
|
|
+ c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
|
|
}
|
|
}
|
|
|
|
|
|
if _, ok := c.values[c.curLayer]; !ok {
|
|
if _, ok := c.values[c.curLayer]; !ok {
|
|
if c.config.PermutedV {
|
|
if c.config.PermutedV {
|
|
- c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
|
|
|
|
|
+ c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
|
|
} else {
|
|
} else {
|
|
- c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
|
|
|
|
|
+ c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
@@ -498,7 +499,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
|
elemSize := c.values[c.curLayer].Stride(0)
|
|
elemSize := c.values[c.curLayer].Stride(0)
|
|
|
|
|
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
|
value = value.Permute(ctx, 1, 2, 0, 3)
|
|
- ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
|
|
|
|
|
+ ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
|
|
} else {
|
|
} else {
|
|
rowSize := c.values[c.curLayer].Stride(2)
|
|
rowSize := c.values[c.curLayer].Stride(2)
|
|
|
|
|