@@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
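The pruning added above hinges on `Rows`, which is used here to gather whole rows of the hidden-state matrix by an index tensor, so only the selected token positions continue into the residual add and everything after it. A minimal plain-Go sketch of that gather semantics, with toy dimensions and no dependence on the `ml` package:

```go
package main

import "fmt"

// gatherRows returns only the selected rows of a [seqLen][hidden] matrix,
// informally mirroring what hiddenState.Rows(ctx, outputs) is used for above.
func gatherRows(hidden [][]float32, outputs []int32) [][]float32 {
	pruned := make([][]float32, 0, len(outputs))
	for _, idx := range outputs {
		pruned = append(pruned, hidden[idx])
	}
	return pruned
}

func main() {
	// Four token positions, hidden size 3 (toy numbers).
	hidden := [][]float32{
		{0.1, 0.2, 0.3},
		{0.4, 0.5, 0.6},
		{0.7, 0.8, 0.9},
		{1.0, 1.1, 1.2},
	}

	// During generation we typically only need logits for the last position.
	outputs := []int32{3}

	fmt.Println(gatherRows(hidden, outputs)) // [[1 1.1 1.2]]
}
```

In the real graph this happens on tensors rather than slices, but the effect is the same: downstream of the final self-attention, work is only done for the positions listed in `outputs`.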
@@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
 }
 
-func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 }
 
 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
 }
 
 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
 		layerType := selfAttentionLayer
 		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 		cache.SetLayerType(layerType)
 
 		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
+			var lastLayerOutputs ml.Tensor
+			if i == len(d.Layers)-1 {
+				lastLayerOutputs = outputs
+			}
+
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
 		}
 	}
 
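Only the last layer receives a non-nil `outputs`: earlier layers still run over every position so each layer's KV cache is populated for the whole batch, and only the final layer's hidden states feed the logits. (The cross-attention variant simply discards the extra argument, as its updated signature above shows.) Which positions end up in `outputs` is the caller's decision and is not part of this diff; below is a hypothetical sketch of one way the index list could be assembled for a packed batch, assuming only each sequence's last token is sampled:

```go
package main

import "fmt"

// logitPositions is a hypothetical caller-side helper (not part of the diff):
// given how many tokens each sequence contributes to a packed batch, it returns
// the flat positions whose logits are actually needed — the last token of each
// sequence. During single-token decode every sequence contributes one token,
// so this degenerates to one index per sequence.
func logitPositions(seqLens []int32) []int32 {
	positions := make([]int32, 0, len(seqLens))
	var offset int32
	for _, n := range seqLens {
		positions = append(positions, offset+n-1)
		offset += n
	}
	return positions
}

func main() {
	fmt.Println(logitPositions([]int32{5, 3})) // [4 7]
}
```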
@@ -205,9 +218,9 @@ type TextModel struct {
 	*TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }
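With the rows pruned inside the last decoder layer, `OutputNorm` and the vocabulary projection in `TextModel.Forward` only see the selected positions, which is where most of the savings comes from. A rough back-of-the-envelope sketch of why that matters, using illustrative sizes rather than anything read from the model config:

```go
package main

import "fmt"

func main() {
	// Illustrative sizes only, not taken from any real model config.
	const hidden, vocab = 4096, 128256

	promptLen := 512 // tokens in the prompt batch
	needed := 1      // positions we actually sample logits from

	before := promptLen * hidden * vocab // multiply-adds in the output projection, unpruned
	after := needed * hidden * vocab     // multiply-adds after pruning to `outputs`

	fmt.Printf("output projection work: %d -> %d (%dx less)\n", before, after, before/after)
}
```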