فهرست منبع

fix drift from main

Jesse Gross 1 ماه پیش
والد
کامیت
4346c2409d
5فایلهای تغییر یافته به همراه49 افزوده شده و 22 حذف شده
  1. 4 0
      kvcache/causal_test.go
  2. 25 11
      model/models/gemma2/model.go
  3. 17 8
      model/models/gemma3/model_text.go
  4. 1 1
      model/models/gemma3/model_vision.go
  5. 2 2
      model/process_text_spm_test.go

+ 4 - 0
kvcache/causal_test.go

@@ -441,6 +441,10 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
 	panic("not implemented")
 }
 
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	panic("not implemented")
 }

+ 25 - 11
model/models/gemma2/model.go

@@ -64,6 +64,7 @@ func New(c ml.Config) (model.Model, error) {
 
 	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
 	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+	m.Cache.SetConfig(ml.CacheConfig{})
 
 	return &m, nil
 }
@@ -84,7 +85,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	if opts.largeModelScaling {
-		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads)))
+		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
 	} else {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
 	}
@@ -99,8 +100,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	cache.Put(ctx, k, v)
 	k, v, mask := cache.Get(ctx)
 
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	q = q.Permute(ctx, 0, 2, 1, 3)
+	k = k.Permute(ctx, 0, 2, 1, 3)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
 	kq := k.Mulmat(ctx, q)
@@ -144,12 +145,20 @@ type Layer struct {
 	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
 	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -170,6 +179,11 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
@@ -182,7 +196,13 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		m.Cache.SetLayer(i)
 		wc := m.Cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -192,12 +212,6 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
 	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
 	return hiddenState.Rows(ctx, outputs), nil
 }
 

+ 17 - 8
model/models/gemma3/model_text.go

@@ -66,9 +66,6 @@ func newTextModel(c ml.Config) *TextModel {
 		},
 	}
 
-	slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-
 	return &m
 }
 
@@ -145,12 +142,20 @@ type TextLayer struct {
 	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
 	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -181,7 +186,13 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
 		cache.SetLayer(i)
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
-		hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -190,7 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
 	// final logit softcap
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
-	hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
-
-	return hiddenState.Rows(ctx, outputs)
+	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
 }

+ 1 - 1
model/models/gemma3/model_vision.go

@@ -53,7 +53,7 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
 }
 
 type VisionEncoderLayer struct {
-	LayerNorm1    *nn.LayerNorm        `gguf:"layer_norm1"`
+	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
 	SelfAttention *VisionSelfAttention
 
 	LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`

+ 2 - 2
model/process_text_spm_test.go

@@ -73,7 +73,7 @@ func TestSentencePieceEncode(t *testing.T) {
 		}
 
 		for _, want := range cases {
-			ids, err := tokenizer.Encode(want)
+			ids, err := tokenizer.Encode(want, true)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -98,7 +98,7 @@ func TestSentencePieceEncode(t *testing.T) {
 		}
 
 		for _, want := range cases {
-			ids, err := tokenizer.Encode(want.token)
+			ids, err := tokenizer.Encode(want.token, true)
 			if err != nil {
 				t.Fatal(err)
 			}