@@ -55,8 +55,8 @@ type Causal struct {
 
 	shiftFn      shiftFn
 	backend      ml.Backend
-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }
 
 type cacheCell struct {
@@ -70,11 +70,23 @@ type cellRange struct {
 }
 
 func NewCausalCache(shift shiftFn) *Causal {
-	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
+	return &Causal{
+		windowSize: math.MaxInt32,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }
 
 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-	return &Causal{windowSize: windowSize, shiftFn: shift}
+	return &Causal{
+		windowSize: windowSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }
 
 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -103,7 +115,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
-	c.cacheCtx = backend.NewContext()
 }
 
 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -115,7 +126,9 @@ func (c *Causal) SetConfig(config ml.CacheConfig) {
 }
 
 func (c *Causal) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }
 
 func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -239,13 +252,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 }
 
 func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
-	for i := range c.keys {
-		if c.keys[i] == nil {
+	for i, key := range c.keys {
+		if key == nil {
 			continue
 		}
 
-		key := c.keys[i]
-
 		kHeadDim := key.Dim(0)
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)
@@ -305,7 +316,7 @@ func (c *Causal) defrag() {
 		layers++
 	}
 
-	maxMoves := ctx.MaxTensors() / (6 * layers)
+	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
 	moves := 0
 
 	var pendingSrc, pendingDst, pendingLen int
@@ -377,11 +388,6 @@ func (c *Causal) defrag() {
 }
 
 func (c *Causal) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }
 
@@ -433,13 +439,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}
 
-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
-
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContext()
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	}
+
+	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
 		}
 	}
 