Browse files

model: document qwen2 forward pass

Bruce MacDonald 2 months ago
parent
commit
96510b9353
1 changed file with 98 additions and 72 deletions

+ 98 - 72
model/models/qwen2/model.go

@@ -10,10 +10,15 @@ import (
 )
 
 type Options struct {
-	RopeFactors                              ml.Tensor `gguf:"rope_freqs.weight"`
-	ctxLen, hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale                 float32
-	ropeDim                                  uint32
+	RopeFactors    ml.Tensor `gguf:"rope_freqs.weight"`
+	contextLength  int
+	hiddenSize     int
+	numAttnHeads   int
+	numKVHeads     int
+	modelEpsilon   float32
+	ropeBaseFreq   float32
+	ropeFreqScale  float32
+	ropeDimensions uint32
 }
 
 type Model struct {
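
Note: New pulls every hyperparameter above from GGUF metadata, and lookups like c.Float("rope.freq_scale", 1) or c.Uint("rope.dimension_count", 64) fall back to the given default when the key is absent from the file. A minimal sketch of that defaulted-lookup pattern, assuming a hypothetical map-backed config (ml.Config's real definition is outside this diff):

package sketch

// mapConfig is a hypothetical stand-in for ml.Config that shows the
// defaulted-lookup pattern New relies on; it is not the real interface.
type mapConfig map[string]any

func (c mapConfig) Float(key string, defaultValue ...float32) float32 {
	if v, ok := c[key].(float32); ok {
		return v
	}
	if len(defaultValue) > 0 {
		return defaultValue[0] // key missing: use the caller's default
	}
	return 0
}
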
@@ -42,14 +47,14 @@ func New(c ml.Config) (model.Model, error) {
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		Options: &Options{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ctxLen:     int(c.Uint("context_length")),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count", 64),
+			hiddenSize:     int(c.Uint("embedding_length")),
+			numAttnHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads:     int(c.Uint("attention.head_count_kv")),
+			modelEpsilon:   c.Float("attention.layer_norm_rms_epsilon"),
+			contextLength:  int(c.Uint("context_length")),
+			ropeBaseFreq:   c.Float("rope.freq_base"),
+			ropeFreqScale:  c.Float("rope.freq_scale", 1),
+			ropeDimensions: c.Uint("rope.dimension_count", 64),
 		},
 	}
 
@@ -58,21 +63,24 @@ func New(c ml.Config) (model.Model, error) {
 	return m, nil
 }
 
+// Shift re-applies rotary position embeddings to cached keys when the causal attention cache shifts their positions
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	return key.RoPE(
 		ctx,
 		ml.RopeConfig{
 			PositionIDs: shift,
 			RopeFactors: m.Options.RopeFactors,
-			RopeDim:     m.Options.ropeDim,
+			RopeDim:     m.Options.ropeDimensions,
 			RopeType:    ml.RopeTypeNeoX,
-			OrigCtxLen:  m.Options.ctxLen,
-			RopeBase:    m.Options.ropeBase,
-			RopeScale:   m.Options.ropeScale,
+			OrigCtxLen:  m.Options.contextLength,
+			RopeBase:    m.Options.ropeBaseFreq,
+			RopeScale:   m.Options.ropeFreqScale,
 		},
 	), nil
 }
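
Note: both Shift and the attention path below rotate query/key dimension pairs by position-dependent angles (NeoX-style RoPE). The rotation itself runs inside the ml backend's fused kernel; as a sketch of just the angle computation, pos * base^(-2i/d), for intuition (illustrative only, not the kernel):

package sketch

import "math"

// ropeAngles returns the rotation angle for each dimension pair at a
// given position under NeoX-style RoPE: pos * base^(-2i/d).
// Frequency scaling and the rotation itself are handled by the backend.
func ropeAngles(pos, ropeDim int, base float64) []float64 {
	angles := make([]float64, ropeDim/2)
	for i := range angles {
		angles[i] = float64(pos) * math.Pow(base, -2*float64(i)/float64(ropeDim))
	}
	return angles
}
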
 
+// SelfAttention implements multi-head self-attention with separate
+// query, key, value and output projections; numKVHeads may be smaller
+// than numAttnHeads (grouped-query attention)
 type SelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
@@ -81,49 +89,59 @@ type SelfAttention struct {
 }
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, inputPositions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	// Derive the per-head dimension and build the RoPE configuration
 	batchSize := hiddenState.Dim(1)
-	headDim := opts.hiddenSize / opts.numHeads
-	rc := ml.RopeConfig{
+	headDimension := opts.hiddenSize / opts.numAttnHeads
+	ropeConfig := ml.RopeConfig{
 		PositionIDs: inputPositions,
 		RopeFactors: nil,
-		RopeDim:     opts.ropeDim,
+		RopeDim:     opts.ropeDimensions,
 		RopeType:    ml.RopeTypeNeoX,
-		OrigCtxLen:  opts.ctxLen,
-		RopeBase:    opts.ropeBase,
-		RopeScale:   opts.ropeScale,
+		OrigCtxLen:  opts.contextLength,
+		RopeBase:    opts.ropeBaseFreq,
+		RopeScale:   opts.ropeFreqScale,
 	}
 
-	q := sa.Query.Forward(ctx, hiddenState)
-
-	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, rc)
-
-	k := sa.Key.Forward(ctx, hiddenState)
-	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, rc)
-
-	v := sa.Value.Forward(ctx, hiddenState)
-	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-
-	cache.Put(ctx, k, v)
-	k, v, mask := cache.Get(ctx)
-
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-	kq := k.MulmatFullPrec(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
-
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
-
-	return sa.Output.Forward(ctx, kqv)
+	// Project and reshape query states with rotary embeddings
+	queryStates := sa.Query.Forward(ctx, hiddenState)
+	queryStates = queryStates.Reshape(ctx, headDimension, opts.numAttnHeads, batchSize)
+	queryStates = queryStates.RoPE(ctx, ropeConfig)
+
+	// Project and reshape key states with rotary embeddings
+	keyStates := sa.Key.Forward(ctx, hiddenState)
+	keyStates = keyStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
+	keyStates = keyStates.RoPE(ctx, ropeConfig)
+
+	// Project and reshape value states
+	valueStates := sa.Value.Forward(ctx, hiddenState)
+	valueStates = valueStates.Reshape(ctx, headDimension, opts.numKVHeads, batchSize)
+
+	// Update and retrieve from KV cache
+	cache.Put(ctx, keyStates, valueStates)
+	keyStates, valueStates, attentionMask := cache.Get(ctx)
+
+	// Prepare tensors for attention computation
+	queryStates = queryStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	keyStates = keyStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	valueStates = valueStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	// Compute raw attention scores, then scale and mask them
+	attentionScores := keyStates.MulmatFullPrec(ctx, queryStates)
+	attentionScores = attentionScores.Scale(ctx, 1.0/math.Sqrt(float64(headDimension)))
+	attentionScores = attentionScores.Add(ctx, attentionMask)
+	// Normalize scores into attention probabilities
+	attentionProbs := attentionScores.Softmax(ctx)
+
+	// Apply attention weights and reshape
+	weightedStates := valueStates.Mulmat(ctx, attentionProbs)
+	weightedStates = weightedStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	weightedStates = weightedStates.Reshape(ctx, opts.hiddenSize, batchSize)
+
+	// Project to output dimension
+	return sa.Output.Forward(ctx, weightedStates)
 }
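
Note: the permutes above put the tensors in head-major layout so the two matmuls batch over all heads at once; when numKVHeads < numAttnHeads, the shared KV heads are broadcast across groups of query heads. What each head computes is plain scaled dot-product attention. A single-head sketch on plain slices, assuming row-major [seq][headDim] inputs (illustrative, not the fused kernel):

package sketch

import "math"

// naiveAttention mirrors the tensor ops in SelfAttention.Forward for a
// single head: softmax(q·kᵀ/√d + mask) · v.
// q is [qLen][headDim], k and v are [kvLen][headDim], mask is
// [qLen][kvLen] with 0 for visible positions and -Inf for masked ones.
func naiveAttention(q, k, v, mask [][]float64) [][]float64 {
	d := float64(len(q[0]))
	out := make([][]float64, len(q))
	for i := range q {
		// scores[j] = q[i]·k[j] / sqrt(d) + mask[i][j]
		scores := make([]float64, len(k))
		for j := range k {
			for t := range q[i] {
				scores[j] += q[i][t] * k[j][t]
			}
			scores[j] = scores[j]/math.Sqrt(d) + mask[i][j]
		}
		// numerically stable softmax over the scores
		maxS := math.Inf(-1)
		for _, s := range scores {
			maxS = math.Max(maxS, s)
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxS)
			sum += scores[j]
		}
		// weighted sum of value rows
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			w := scores[j] / sum
			for t := range v[j] {
				out[i][t] += w * v[j][t]
			}
		}
	}
	return out
}
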
 
+// MLP implements the feed-forward network component with SwiGLU activation
 type MLP struct {
 	Up   *nn.Linear `gguf:"ffn_up"`
 	Down *nn.Linear `gguf:"ffn_down"`
@@ -131,10 +149,16 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
-	return mlp.Down.Forward(ctx, hiddenState)
+	// Apply SwiGLU activation gating
+	gateActivation := mlp.Gate.Forward(ctx, hiddenState).SILU(ctx)
+	upProjection := mlp.Up.Forward(ctx, hiddenState)
+	intermediateStates := gateActivation.Mul(ctx, upProjection)
+
+	// Project back to hidden dimension
+	return mlp.Down.Forward(ctx, intermediateStates)
 }
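
Note: the gating above is the standard SwiGLU pattern: the down projection is applied to silu(gate(x)) multiplied element-wise with up(x). The same computation on plain slices:

package sketch

import "math"

// silu is x * sigmoid(x), the activation applied to the gate projection.
func silu(x float64) float64 { return x / (1 + math.Exp(-x)) }

// swiGLU combines the gate and up projections element-wise, as in
// MLP.Forward; the down projection follows in the real model.
func swiGLU(gate, up []float64) []float64 {
	out := make([]float64, len(gate))
	for i := range gate {
		out[i] = silu(gate[i]) * up[i]
	}
	return out
}
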
 
+// Layer represents a single transformer layer combining self-attention and feed-forward components
 type Layer struct {
 	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
 	SelfAttention *SelfAttention
@@ -143,52 +167,54 @@ type Layer struct {
 }
 
 func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	// Self-attention branch with residual connection
 	residual := hiddenState
 
-	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	normalizedAttention := l.AttentionNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
+	attentionOutput := l.SelfAttention.Forward(ctx, normalizedAttention, positionIDs, cache, opts)
+	hiddenState = attentionOutput.Add(ctx, residual)
 
-	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
-
-	hiddenState = hiddenState.Add(ctx, residual)
+	// Feed-forward branch with residual connection
 	residual = hiddenState
-
-	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
-
-	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
-
-	output := hiddenState.Add(ctx, residual)
+	normalizedMLP := l.MLPNorm.Forward(ctx, hiddenState, opts.modelEpsilon)
+	mlpOutput := l.MLP.Forward(ctx, normalizedMLP, opts)
+	output := mlpOutput.Add(ctx, residual)
 
 	return output
 }
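
Note: both branches are pre-normalized: RMSNorm sees the branch input while the residual skips around it, and modelEpsilon guards the division. A sketch of the normalization itself, x * gain / sqrt(mean(x²) + eps), on plain slices (nn.RMSNorm's implementation is outside this diff):

package sketch

import "math"

// rmsNorm shows the computation nn.RMSNorm applies before each branch:
// scale every element by 1/sqrt(mean(x²) + eps), then by a learned gain.
func rmsNorm(x, gain []float64, eps float64) []float64 {
	var meanSq float64
	for _, v := range x {
		meanSq += v * v
	}
	meanSq /= float64(len(x))
	scale := 1 / math.Sqrt(meanSq+eps)
	out := make([]float64, len(x))
	for i := range x {
		out[i] = x[i] * scale * gain[i]
	}
	return out
}
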
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+	// Convert input tokens and positions to tensors
+	inputTensor, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+	positionsTensor, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	// Initial token embedding
+	hiddenStates := m.TokenEmbedding.Forward(ctx, inputTensor)
 
+	// Process through transformer layers
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positionsTensor, m.Cache, m.Options)
 	}
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-
-	hiddenState = m.Output.Forward(ctx, hiddenState)
+	// Final layer normalization and output projection
+	normalizedOutput := m.OutputNorm.Forward(ctx, hiddenStates, m.modelEpsilon)
+	logits := m.Output.Forward(ctx, normalizedOutput)
 
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	// Convert the requested output positions to a tensor
+	outputsTensor, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return hiddenState.Rows(ctx, outputs), nil
+	return logits.Rows(ctx, outputsTensor), nil
 }
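
Note: rather than returning logits for every input token, the final Rows call gathers only the positions listed in opts.Outputs (typically just the last token during generation). A sketch of that gather, assuming a [token][vocab] layout (the actual tensor layout is backend-defined):

package sketch

// selectRows mirrors the final logits.Rows call: keep one logit row per
// requested output position. The [token][vocab] layout is an assumption
// for illustration only.
func selectRows(logits [][]float32, outputs []int32) [][]float32 {
	picked := make([][]float32, len(outputs))
	for i, idx := range outputs {
		picked[i] = logits[idx]
	}
	return picked
}
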
 
 func init() {