
models: Prune unused outputs earlier in the forward pass

Currently, Rows is called as the last step of a model's forward
computation to gather the values for the output tokens. However, if
we move it earlier in the process, we can prune computations whose
results are never used. This is similar to how models are defined in
llama.cpp.

Changing the model definition in this way improves token generation
performance by approximately 8%.
Jesse Gross, 2 months ago
commit 5c5535c064
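The sketch below illustrates the idea from the commit message. It is not the real ml.Tensor API: it uses plain Go slices and a made-up project helper purely as stand-ins. The point is that gathering the needed rows before the output projection means the expensive hidden-to-vocab matmul only runs for the token positions that actually need logits (usually just the last token during generation). In the actual change, the same reordering is expressed by passing the outputs tensor into the final layer's Forward so that Rows runs before OutputNorm and Output.

package main

import "fmt"

// rows gathers the given row indices out of a [tokens][dim] matrix,
// loosely mirroring what Tensor.Rows does in the diffs below.
func rows(m [][]float32, idx []int) [][]float32 {
	out := make([][]float32, len(idx))
	for i, r := range idx {
		out[i] = m[r]
	}
	return out
}

// project is a toy stand-in for the output projection (hidden -> vocab logits).
// Its cost scales with the number of rows it is asked to project.
func project(hidden, w [][]float32) [][]float32 {
	logits := make([][]float32, len(hidden))
	for i, h := range hidden {
		row := make([]float32, len(w))
		for j, wj := range w {
			var sum float32
			for k := range h {
				sum += h[k] * wj[k]
			}
			row[j] = sum
		}
		logits[i] = row
	}
	return logits
}

func main() {
	hidden := [][]float32{{1, 2}, {3, 4}, {5, 6}, {7, 8}} // 4 tokens, dim 2
	w := [][]float32{{1, 0}, {0, 1}, {1, 1}}              // toy vocab of size 3
	outputs := []int{3}                                   // only the last token needs logits

	// Old order: project every token, then keep the rows we want.
	before := rows(project(hidden, w), outputs)

	// New order: keep the rows we want, then project only those.
	after := project(rows(hidden, outputs), w)

	fmt.Println(before, after) // identical results, a quarter of the projection work here
}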

model/models/llama/model.go  +21 -9

@@ -120,11 +120,19 @@ type Layer struct {
 	MLP           *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
-	}
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	hiddenState = m.Output.Forward(ctx, hiddenState)
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
 
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	return m.Output.Forward(ctx, hiddenState), nil
 }
 
 func init() {

model/models/mllama/model.go  +2 -4

@@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	// TODO: attention mask, cross attention mask
-	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
-
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	// TODO: attention mask, cross attention mask
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {

model/models/mllama/model_text.go  +20 -7

@@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP     *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
 }
 
-func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 }
 
 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
 }
 
 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
 		layerType := selfAttentionLayer
 		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 		cache.SetLayerType(layerType)
 
 		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
+			var lastLayerOutputs ml.Tensor
+			if i == len(d.Layers)-1 {
+				lastLayerOutputs = outputs
+			}
+
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
 		}
 	}
 
@@ -205,9 +218,9 @@ type TextModel struct {
 	*TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }