|
@@ -37,7 +37,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
|
|
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
|
|
|
- scores := key.Mulmat(ctx, query)
|
|
|
+ scores := key.MulmatFullPrec(ctx, query)
|
|
|
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
|
|
|
|
|
if mask != nil {
|