|
@@ -10,10 +10,11 @@ import (
|
|
|
)
|
|
|
|
|
|
type TextSelfAttention struct {
|
|
|
- Query *nn.Linear `gguf:"attn_q"`
|
|
|
- Key *nn.Linear `gguf:"attn_k"`
|
|
|
- Value *nn.Linear `gguf:"attn_v"`
|
|
|
- Output *nn.Linear `gguf:"attn_output"`
|
|
|
+ Query *nn.Linear `gguf:"attn_q"`
|
|
|
+ Key *nn.Linear `gguf:"attn_k"`
|
|
|
+ Value *nn.Linear `gguf:"attn_v"`
|
|
|
+ Output *nn.Linear `gguf:"attn_output"`
|
|
|
+ RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
|
|
}
|
|
|
|
|
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
|
@@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|
|
|
|
|
query := sa.Query.Forward(ctx, hiddenState)
|
|
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
|
|
- query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
|
|
+ query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
|
|
|
|
|
key := sa.Key.Forward(ctx, hiddenState)
|
|
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
|
- key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
|
|
+ key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
|
|
|
|
|
value := sa.Value.Forward(ctx, hiddenState)
|
|
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
@@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|
|
}
|
|
|
|
|
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
|
- // This will only get called for layers in the causal cache, which are just the self attention layers
|
|
|
- return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
|
|
+ if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
|
|
+ return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ return key, nil
|
|
|
}
|
|
|
|
|
|
type TextMLP struct {
|
|
@@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
|
|
|
}
|
|
|
|
|
|
type TextModelOptions struct {
|
|
|
- RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
|
|
-
|
|
|
hiddenSize, numHeads, numKVHeads int
|
|
|
eps, ropeBase, ropeScale float32
|
|
|
ropeDim uint32
|