ソースを参照

Disable causal attention based on batch index

Currently we are using positions, which are relative to a
sequence and may not be unique.
Jesse Gross 1 ヶ月 前
コミット
a8e83a7654
2 ファイル変更10 行追加12 行削除
  1. 6 8
      kvcache/causal.go
  2. 4 4
      model/models/gemma3/model_text.go

+ 6 - 8
kvcache/causal.go

@@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	c.curBatchSize = len(opts.Positions)
 	c.curSequences = opts.Sequences
 	c.curPositions = opts.Positions
+	c.opts.Except = nil
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
-		enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
+		enabled := !slices.Contains(c.opts.Except, i)
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
@@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) {
 }
 
 type CausalOptions struct {
-	// Enabled controls whether the causal mask is generated for a particular position.
-	Except []int32
+	// Enabled controls whether the causal mask is generated for a particular index in a batch
+	Except []int
 }
 
-// SetCausal enables or disables causal mask generation for subsequent calls to Get.
-// This state carries over to future forward passes. The default value is true.
-//
-// ctx may be set to nil if this is called from outside of a forward pass, for
-// example, when initializing the cache.
+// SetCausal disables causal mask generation for a particular range of indicies in
+// the current batch for subsequent calls to Get. The state resets for the next forward pass.
 func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
 	if !slices.Equal(c.opts.Except, opts.Except) {
 		c.opts = opts

+ 4 - 4
model/models/gemma3/model_text.go

@@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
 	var embedding ml.Tensor
 	var src, dst, length int
-	var except []int32
+	var except []int
 
 	for _, image := range multimodal {
 		imageToken := image.Multimodal.(imageToken)
@@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu
 			length = 1
 		}
 
-		except = append(except, positions[imageDst])
+		except = append(except, imageDst)
 	}
 
 	if embedding != nil {
@@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
-	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
 
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)