|
@@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
|
c.curBatchSize = len(opts.Positions)
|
|
c.curBatchSize = len(opts.Positions)
|
|
c.curSequences = opts.Sequences
|
|
c.curSequences = opts.Sequences
|
|
c.curPositions = opts.Positions
|
|
c.curPositions = opts.Positions
|
|
|
|
+ c.opts.Except = nil
|
|
|
|
|
|
var err error
|
|
var err error
|
|
c.curLoc, err = c.findStartLoc()
|
|
c.curLoc, err = c.findStartLoc()
|
|
@@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
|
mask := make([]float32, batchSize*length)
|
|
mask := make([]float32, batchSize*length)
|
|
|
|
|
|
for i := range c.curBatchSize {
|
|
for i := range c.curBatchSize {
|
|
- enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
|
|
|
|
|
|
+ enabled := !slices.Contains(c.opts.Except, i)
|
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
|
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
|
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
|
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
|
@@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) {
|
|
}
|
|
}
|
|
|
|
|
|
type CausalOptions struct {
|
|
type CausalOptions struct {
|
|
- // Enabled controls whether the causal mask is generated for a particular position.
|
|
|
|
- Except []int32
|
|
|
|
|
|
+ // Enabled controls whether the causal mask is generated for a particular index in a batch
|
|
|
|
+ Except []int
|
|
}
|
|
}
|
|
|
|
|
|
-// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
|
|
|
-// This state carries over to future forward passes. The default value is true.
|
|
|
|
-//
|
|
|
|
-// ctx may be set to nil if this is called from outside of a forward pass, for
|
|
|
|
-// example, when initializing the cache.
|
|
|
|
|
|
+// SetCausal disables causal mask generation for a particular range of indicies in
|
|
|
|
+// the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
|
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
|
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
|
if !slices.Equal(c.opts.Except, opts.Except) {
|
|
if !slices.Equal(c.opts.Except, opts.Except) {
|
|
c.opts = opts
|
|
c.opts = opts
|