|
@@ -28,7 +28,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|
|
|
|
|
key := sa.Key.Forward(ctx, hiddenState)
|
|
key := sa.Key.Forward(ctx, hiddenState)
|
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
- key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
|
|
|
+ key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
|
|
|
value := sa.Value.Forward(ctx, hiddenState)
|
|
value := sa.Value.Forward(ctx, hiddenState)
|
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|