
models/gemma3: remove final logit softcap (#9692)

Softcap isn't in the whitepaper/implementation for the language model, so we should remove it. There is no discernible difference in output with it removed.
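
For context, the removed final-logit softcap rescales each logit x to cap * tanh(x / cap), where cap comes from the final_logit_softcapping config key (30.0 by default in the removed code). Below is a minimal standalone Go sketch of that transform over plain float64 slices, not the ml.Tensor version that lived in Forward:

package main

import (
	"fmt"
	"math"
)

// softcap squashes each logit into (-capVal, capVal) via capVal * tanh(x/capVal),
// mirroring the Scale -> Tanh -> Scale sequence removed from Forward.
func softcap(logits []float64, capVal float64) []float64 {
	out := make([]float64, len(logits))
	for i, x := range logits {
		out[i] = capVal * math.Tanh(x/capVal)
	}
	return out
}

func main() {
	logits := []float64{-5.2, 0.3, 12.7, 41.0}
	// 30.0 matches the final_logit_softcapping default in the removed config read.
	fmt.Println(softcap(logits, 30.0)) // large logits saturate toward ±30
}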
Bruce MacDonald, 1 month ago
parent commit a70820daa0
1 file changed with 10 additions and 17 deletions
  1. model/models/gemma3/model_text.go  +10 −17

+10 −17  model/models/gemma3/model_text.go

@@ -15,7 +15,6 @@ type TextOptions struct {
 	attnKeyLen, attnValLen           int
 	eps, ropeScale                   float32
 	ropeLocalBase, ropeGlobalBase    float32
-	finalLogitSoftcap                float32
 	largeModelScaling                bool
 }
 
@@ -57,16 +56,15 @@ func newTextModel(c ml.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextOptions: &TextOptions{
-			hiddenSize:        int(c.Uint("embedding_length")),
-			numHeads:          int(c.Uint("attention.head_count")),
-			numKVHeads:        int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:        int(c.Uint("attention.key_length", 256)),
-			attnValLen:        int(c.Uint("attention.value_length", 256)),
-			eps:               c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:     c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase:    c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:         c.Float("rope.freq_scale", 1.0),
-			finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
+			hiddenSize:     int(c.Uint("embedding_length")),
+			numHeads:       int(c.Uint("attention.head_count")),
+			numKVHeads:     int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
+			attnValLen:     int(c.Uint("attention.value_length", 256)),
+			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale:      c.Float("rope.freq_scale", 1.0),
 		},
 	}
 
@@ -245,10 +243,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	hiddenState = m.Output.Forward(ctx, hiddenState)
-
-	// final logit softcap
-	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
-	hiddenState = hiddenState.Tanh(ctx)
-	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
+	return m.Output.Forward(ctx, hiddenState)
 }
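
A side note on why the removal is hard to notice, at least under greedy decoding: tanh is strictly increasing, so softcapping never reorders the logits and the top token is identical with or without it; only relative magnitudes (and thus temperature sampling) shift slightly. A small sketch of that ranking invariance, with argmax as a hypothetical helper that is not part of this codebase:

package main

import (
	"fmt"
	"math"
)

// argmax returns the index of the largest value (hypothetical helper).
func argmax(xs []float64) int {
	best := 0
	for i, x := range xs {
		if x > xs[best] {
			best = i
		}
	}
	return best
}

func main() {
	logits := []float64{2.1, 17.5, -3.0, 17.4}
	capped := make([]float64, len(logits))
	for i, x := range logits {
		capped[i] = 30.0 * math.Tanh(x/30.0) // the removed softcap, cap = 30.0
	}
	// tanh is monotonic, so greedy decoding picks the same token either way.
	fmt.Println(argmax(logits) == argmax(capped)) // true
}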