Browse Source

use non-causal mask for inputs with images

Michael Yang 1 month ago
parent
commit
9d2a20a763
1 changed files with 5 additions and 0 deletions
  1. 5 0
      model/models/gemma3/model_text.go

+ 5 - 0
model/models/gemma3/model_text.go

@@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
 		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
 		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
+
+		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, false)
+			defer causal.SetCausal(ctx, true)
+		}
 	}
 
 	for i, layer := range m.Layers {