|
@@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
|
type Causal struct {
|
|
type Causal struct {
|
|
DType ml.DType
|
|
DType ml.DType
|
|
Capacity int32
|
|
Capacity int32
|
|
- causal bool
|
|
|
|
windowSize int32
|
|
windowSize int32
|
|
|
|
|
|
|
|
+ opts CausalOptions
|
|
|
|
+
|
|
// config controls mostly backend-specific optimizations
|
|
// config controls mostly backend-specific optimizations
|
|
config *ml.CacheConfig
|
|
config *ml.CacheConfig
|
|
|
|
|
|
@@ -79,7 +80,6 @@ type cellRange struct {
|
|
|
|
|
|
func NewCausalCache(shift shiftFn) *Causal {
|
|
func NewCausalCache(shift shiftFn) *Causal {
|
|
return &Causal{
|
|
return &Causal{
|
|
- causal: true,
|
|
|
|
windowSize: math.MaxInt32,
|
|
windowSize: math.MaxInt32,
|
|
shiftFn: shift,
|
|
shiftFn: shift,
|
|
ctxs: make(map[int]ml.Context),
|
|
ctxs: make(map[int]ml.Context),
|
|
@@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
|
|
|
|
|
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
|
return &Causal{
|
|
return &Causal{
|
|
- causal: true,
|
|
|
|
windowSize: windowSize,
|
|
windowSize: windowSize,
|
|
shiftFn: shift,
|
|
shiftFn: shift,
|
|
ctxs: make(map[int]ml.Context),
|
|
ctxs: make(map[int]ml.Context),
|
|
@@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|
mask := make([]float32, batchSize*length)
|
|
mask := make([]float32, batchSize*length)
|
|
|
|
|
|
for i := range c.curBatchSize {
|
|
for i := range c.curBatchSize {
|
|
|
|
+ enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
|
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
|
- (c.causal && c.cells[j].pos > c.curPositions[i]) ||
|
|
|
|
|
|
+ (enabled && c.cells[j].pos > c.curPositions[i]) ||
|
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
|
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
|
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
|
}
|
|
}
|
|
@@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) {
|
|
c.curLayer = layer
|
|
c.curLayer = layer
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+type CausalOptions struct {
|
|
|
|
+ // Enabled controls whether the causal mask is generated for a particular position.
|
|
|
|
+ Except []int32
|
|
|
|
+}
|
|
|
|
+
|
|
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
|
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
|
// This state carries over to future forward passes. The default value is true.
|
|
// This state carries over to future forward passes. The default value is true.
|
|
//
|
|
//
|
|
// ctx may be set to nil if this is called from outside of a forward pass, for
|
|
// ctx may be set to nil if this is called from outside of a forward pass, for
|
|
// example, when initializing the cache.
|
|
// example, when initializing the cache.
|
|
-func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
|
|
|
|
- if c.causal != causal {
|
|
|
|
- c.causal = causal
|
|
|
|
-
|
|
|
|
|
|
+func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
|
|
|
+ if !slices.Equal(c.opts.Except, opts.Except) {
|
|
|
|
+ c.opts = opts
|
|
if ctx != nil {
|
|
if ctx != nil {
|
|
var err error
|
|
var err error
|
|
c.curMask, err = c.buildMask(ctx)
|
|
c.curMask, err = c.buildMask(ctx)
|