소스 검색

use non-causal mask for inputs with images

Michael Yang 2 달 전
부모
커밋
9d2a20a763
1개의 변경된 파일5개의 추가작업 그리고 0개의 파일을 삭제
  1. 5 0
      model/models/gemma3/model_text.go

+ 5 - 0
model/models/gemma3/model_text.go

@@ -181,6 +181,11 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
 		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
 		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
+
+		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, false)
+			defer causal.SetCausal(ctx, true)
+		}
 	}
 
 	for i, layer := range m.Layers {