Explorar o código

models: Prune unused outputs earlier in the forward pass

Currently Rows is called as the last step in a model computation
to get the values for the output tokens. However, if we move it
earlier in the process then we can trim out computations that
never get used. This is similar to how models are defined in
llama.cpp.

Changing the model definition in this way improves token generation
performance by approximately 8%.
Jesse Gross hai 2 meses
pai
achega
5c5535c064
Modificáronse 3 ficheiros con 43 adicións e 20 borrados
  1. 21 9
      model/models/llama/model.go
  2. 2 4
      model/models/mllama/model.go
  3. 20 7
      model/models/mllama/model_text.go

+ 21 - 9
model/models/llama/model.go

@@ -120,11 +120,19 @@ type Layer struct {
 	MLP           *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -144,22 +152,26 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
-	}
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	hiddenState = m.Output.Forward(ctx, hiddenState)
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
 
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	return m.Output.Forward(ctx, hiddenState), nil
 }
 
 func init() {

+ 2 - 4
model/models/mllama/model.go

@@ -93,15 +93,13 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	// TODO: attention mask, cross attention mask
-	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
-
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	// TODO: attention mask, cross attention mask
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {

+ 20 - 7
model/models/mllama/model_text.go

@@ -74,11 +74,19 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP     *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -145,7 +153,7 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
 }
 
-func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -161,14 +169,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 }
 
 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
 }
 
 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
 		layerType := selfAttentionLayer
 		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -179,7 +187,12 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
 		cache.SetLayerType(layerType)
 
 		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
+			var lastLayerOutputs ml.Tensor
+			if i == len(d.Layers)-1 {
+				lastLayerOutputs = outputs
+			}
+
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
 		}
 	}
 
@@ -205,9 +218,9 @@ type TextModel struct {
 	*TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }