浏览代码

use non-causal mask only for image positions

Michael Yang 1 月之前
父节点
当前提交
e95278932b
共有 2 个文件被更改,包括 18 次插入10 次删除
  1. 12 8
      kvcache/causal.go
  2. 6 2
      model/models/gemma3/model_text.go

+ 12 - 8
kvcache/causal.go

@@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 type Causal struct {
 	DType      ml.DType
 	Capacity   int32
-	causal     bool
 	windowSize int32
 
+	opts CausalOptions
+
 	// config controls mostly backend-specific optimizations
 	config *ml.CacheConfig
 
@@ -79,7 +80,6 @@ type cellRange struct {
 
 func NewCausalCache(shift shiftFn) *Causal {
 	return &Causal{
-		causal:     true,
 		windowSize: math.MaxInt32,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
 
 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	return &Causal{
-		causal:     true,
 		windowSize: windowSize,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
+		enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-				(c.causal && c.cells[j].pos > c.curPositions[i]) ||
+				(enabled && c.cells[j].pos > c.curPositions[i]) ||
 				c.cells[j].pos < c.curPositions[i]-c.windowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
@@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) {
 	c.curLayer = layer
 }
 
+type CausalOptions struct {
+	// Enabled controls whether the causal mask is generated for a particular position.
+	Except []int32
+}
+
 // SetCausal enables or disables causal mask generation for subsequent calls to Get.
 // This state carries over to future forward passes. The default value is true.
 //
 // ctx may be set to nil if this is called from outside of a forward pass, for
 // example, when initializing the cache.
-func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
-	if c.causal != causal {
-		c.causal = causal
-
+func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
+	if !slices.Equal(c.opts.Except, opts.Except) {
+		c.opts = opts
 		if ctx != nil {
 			var err error
 			c.curMask, err = c.buildMask(ctx)

+ 6 - 2
model/models/gemma3/model_text.go

@@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
 
 		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
-			causal.SetCausal(ctx, false)
-			defer causal.SetCausal(ctx, true)
+			except := make([]int32, visionOutputs.Dim(1))
+			for i := 0; i < visionOutputs.Dim(1); i++ {
+				except[i] = int32(offset + i)
+			}
+
+			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
 		}
 	}