|
@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
|
|
|
- scores := key.MulmatFullPrec(ctx, query)
|
|
|
- scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
|
|
- scores = scores.Add(ctx, mask)
|
|
|
- scores = scores.Softmax(ctx)
|
|
|
-
|
|
|
- attention := value.Mulmat(ctx, scores)
|
|
|
- attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
+ scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
|
|
+ attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
|
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
|
|
|
|
|
return sa.Output.Forward(ctx, attention)
|
|
@@ -112,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
|
|
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
|
|
|
|
|
- var key, value ml.Tensor
|
|
|
+ var key, value, mask ml.Tensor
|
|
|
if crossAttentionStates != nil {
|
|
|
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
|
|
|
|
@@ -125,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|
|
|
|
|
cache.Put(ctx, key, value)
|
|
|
} else {
|
|
|
- key, value, _ = cache.Get(ctx)
|
|
|
+ key, value, mask = cache.Get(ctx)
|
|
|
}
|
|
|
|
|
|
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
|
|
|
- scores := key.Mulmat(ctx, query)
|
|
|
- scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
|
|
- scores = scores.Softmax(ctx)
|
|
|
-
|
|
|
- attention := value.Mulmat(ctx, scores)
|
|
|
- attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
+ scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
|
|
+ attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
|
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
|
|
|
|
|
return ca.Output.Forward(ctx, attention)
|