소스 검색

gemma2 ftw

Patrick Devine 2 달 전
부모
커밋
10e06d0a45
6개의 변경된 파일59개의 추가작업 그리고 38개의 파일을 삭제
  1. 1 1
      ml/backend.go
  2. 6 3
      ml/backend/ggml/ggml.go
  3. 38 23
      model/models/gemma2/model.go
  4. 4 3
      model/models/llama/model.go
  5. 4 3
      model/models/mllama/model_text.go
  6. 6 5
      model/process_text_spm.go

+ 1 - 1
ml/backend.go

@@ -77,7 +77,7 @@ type Tensor interface {
 	Scale(ctx Context, s float64) Tensor
 
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
 
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor

+ 6 - 3
ml/backend/ggml/ggml.go

@@ -596,10 +596,13 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 }
 
 const (
-	ropeTypeNorm C.int = iota
+	ropeTypeNorm C.int = 0
+	ropeTypeNeox C.int = 2
+	ropeTypeMrope C.int = 8
+	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{}
 	}
@@ -613,8 +616,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		t: C.ggml_rope_ext(
 			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
+			C.int(ropeType),
 			131072,       // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
 			C.float(ropeBase),
 			C.float(ropeScale),
 			0.,  // YaRN ext_factor

+ 38 - 23
model/models/gemma2/model.go

@@ -10,11 +10,11 @@ import (
 )
 
 type Options struct {
-	RopeFactors                      ml.Tensor `gguf:"rope_freqs.weight"`
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
 	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	attnLogitSoftcap                 float32
+	finalLogitSoftcap                float32
 }
 
 type Model struct {
@@ -43,14 +43,16 @@ func New(c ml.Config) (model.Model, error) {
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
 		Options: &Options{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			attnKeyLen: int(c.Uint("attention.key_length")),
-			attnValLen: int(c.Uint("attention.value_length")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base", 10000.0),
-			ropeScale:  c.Float("rope.freq_scale", 1.0),
+			hiddenSize:        int(c.Uint("embedding_length")),
+			numHeads:          int(c.Uint("attention.head_count")),
+			numKVHeads:        int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:        int(c.Uint("attention.key_length")),
+			attnValLen:        int(c.Uint("attention.value_length")),
+			eps:               c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:          c.Float("rope.freq_base", 10000.0),
+			ropeScale:         c.Float("rope.freq_scale", 1.0),
+			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
+			finalLogitSoftcap: c.Float("final_logit_softcapping"),
 		},
 	}
 
@@ -69,18 +71,18 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, uint32(headDim), opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	// todo: this should be 1.0/math.Sqrt(float64(headDim)) for 27B models
-	//q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
+	q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -93,7 +95,12 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
 	kq := k.Mulmat(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+
+	// logit softcap
+	kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
+	kq = kq.Tanh(ctx)
+	kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
+
 	kq = kq.Add(ctx, mask)
 	kq = kq.Softmax(ctx)
 
@@ -105,7 +112,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
 }
 
 type MLP struct {
@@ -115,15 +122,17 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).Tanh(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
 
 type Layer struct {
-	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
-	SelfAttention *SelfAttention
-	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP           *MLP
+	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention     *SelfAttention
+	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
+	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP               *MLP
+	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
 func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
@@ -131,11 +140,13 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
 	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
 	return hiddenState.Add(ctx, residual)
 }
 
@@ -144,7 +155,6 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	if err != nil {
 		return nil, err
 	}
-	inputs = inputs.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
 	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
@@ -152,7 +162,7 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	}
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	ctx.Forward(hiddenState)
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
 	for i, layer := range m.Layers {
 		cacheType := i % 2
@@ -165,6 +175,11 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	hiddenState = m.Output.Forward(ctx, hiddenState)
 
+	// final logit softcap
+	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
+	hiddenState = hiddenState.Tanh(ctx)
+	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
+
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err

+ 4 - 3
model/models/llama/model.go

@@ -67,14 +67,15 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -99,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, uint32(0), m.Options.ropeBase, m.Options.ropeScale), nil
 }
 
 type MLP struct {

+ 4 - 3
model/models/mllama/model_text.go

@@ -19,14 +19,15 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -52,7 +53,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
-	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 }
 
 type TextMLP struct {

+ 6 - 5
model/process_text_spm.go

@@ -3,6 +3,7 @@ package model
 import (
 	"fmt"
 	"iter"
+	"log/slog"
 	"strings"
 	//"unicode/utf8"
 
@@ -220,13 +221,13 @@ type candidate struct {
 func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
 	var sb strings.Builder
 	for _, id := range ids {
-		for _, r := range spm.vocab.Decode(id) {
-			// todo - do we need to introspect the chars here?
-			if err := sb.WriteByte(byte(r)); err != nil {
-				return "", err
-			}
+		data := spm.vocab.Decode(id)
+		data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
+		if _, err := sb.WriteString(data); err != nil {
+			return "", err
 		}
 	}
 
+	slog.Debug("decoded", "ids", ids, "text", sb.String())
 	return sb.String(), nil
 }