|
@@ -24,17 +24,11 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
|
|
|
key := sa.Key.Forward(ctx, hiddenState)
|
|
|
value := sa.Value.Forward(ctx, hiddenState)
|
|
|
|
|
|
- query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
|
|
|
- key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
|
|
|
- value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
+ query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
|
|
+ key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
|
|
+ value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
|
|
|
|
|
- scores := key.Mulmat(ctx, query)
|
|
|
- scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
|
|
- scores = scores.Softmax(ctx)
|
|
|
-
|
|
|
- attention := value.Mulmat(ctx, scores)
|
|
|
- attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
|
|
|
- attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
+ attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
|
|
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
|
|
|
|
|
hiddenState = sa.Output.Forward(ctx, attention)
|