|
@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|
|
type Causal struct {
|
|
|
DType ml.DType
|
|
|
Capacity int32
|
|
|
+ causal bool
|
|
|
windowSize int32
|
|
|
|
|
|
// config controls mostly backend-specific optimizations
|
|
@@ -42,6 +43,12 @@ type Causal struct {
|
|
|
// locations in the cache that are needed for this batch
|
|
|
curCellRange cellRange
|
|
|
|
|
|
+ // curSequences is the sequences corresponding to this pass's entries in the cache
|
|
|
+ curSequences []int
|
|
|
+
|
|
|
+ // curPositions is the positions corresponding to this pass's entries in the cache
|
|
|
+ curPositions []int32
|
|
|
+
|
|
|
// ** cache metadata **
|
|
|
|
|
|
// for each possible location in the cache, stores the position and set of sequences
|
|
@@ -71,6 +78,7 @@ type cellRange struct {
|
|
|
|
|
|
func NewCausalCache(shift shiftFn) *Causal {
|
|
|
return &Causal{
|
|
|
+ causal: true,
|
|
|
windowSize: math.MaxInt32,
|
|
|
shiftFn: shift,
|
|
|
ctxs: make(map[int]ml.Context),
|
|
@@ -81,6 +89,7 @@ func NewCausalCache(shift shiftFn) *Causal {
|
|
|
|
|
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|
|
return &Causal{
|
|
|
+ causal: true,
|
|
|
windowSize: windowSize,
|
|
|
shiftFn: shift,
|
|
|
ctxs: make(map[int]ml.Context),
|
|
@@ -133,6 +142,8 @@ func (c *Causal) Close() {
|
|
|
|
|
|
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
|
|
c.curBatchSize = len(positions)
|
|
|
+ c.curSequences = seqs
|
|
|
+ c.curPositions = positions
|
|
|
|
|
|
var err error
|
|
|
c.curLoc, err = c.findStartLoc()
|
|
@@ -171,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
|
|
|
c.cellRanges[seq] = seqRange
|
|
|
}
|
|
|
|
|
|
- c.curMask, err = c.buildMask(ctx, positions, seqs)
|
|
|
+ c.curMask, err = c.buildMask(ctx)
|
|
|
|
|
|
return err
|
|
|
}
|
|
@@ -212,7 +223,7 @@ func roundUp(length, pad int) int {
|
|
|
// Builds a mask of history x batch indicating whether for each token in the batch the
|
|
|
// token in the history should apply. This is based on both the sequence and causality (the
|
|
|
// position of the history is not ahead of the token in the batch).
|
|
|
-func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
|
|
+func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|
|
// Align and pad the two dimensions as required by the backend
|
|
|
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
|
|
|
|
@@ -224,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
|
|
|
|
|
|
for i := range c.curBatchSize {
|
|
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
|
|
- if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
|
|
- c.cells[j].pos < positions[i]-c.windowSize {
|
|
|
+ if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
|
|
+ (c.causal && c.cells[j].pos > c.curPositions[i]) ||
|
|
|
+ c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
|
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
|
|
}
|
|
|
}
|
|
@@ -391,6 +403,26 @@ func (c *Causal) SetLayer(layer int) {
|
|
|
c.curLayer = layer
|
|
|
}
|
|
|
|
|
|
+// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
|
|
+// This state carries over to future forward passes. The default value is true.
|
|
|
+//
|
|
|
+// ctx may be set to nil if this is called from outside of a forward pass, for
|
|
|
+// example, when initializing the cache.
|
|
|
+func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
|
|
|
+ if c.causal != causal {
|
|
|
+ c.causal = causal
|
|
|
+
|
|
|
+ if ctx != nil {
|
|
|
+ var err error
|
|
|
+ c.curMask, err = c.buildMask(ctx)
|
|
|
+ if err != nil {
|
|
|
+ // This error should never occur because we have previously built a mask with the same shape
|
|
|
+ panic(fmt.Errorf("SetCausal: %w", err))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
|
|
key := c.keys[c.curLayer]
|
|
|
value := c.values[c.curLayer]
|