浏览代码

use non-causal mask for inputs with images

Michael Yang 1 月之前
父节点
当前提交
9d2a20a763
共有 1 个文件被更改，包括 5 次插入和 0 次删除
  1. 5 0
      model/models/gemma3/model_text.go

+ 5 - 0
model/models/gemma3/model_text.go

@@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
 		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
 		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
 		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
 		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
 		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
+
+		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, false)
+			defer causal.SetCausal(ctx, true)
+		}
 	}
 	}
 
 
 	for i, layer := range m.Layers {
 	for i, layer := range m.Layers {