浏览代码

compat with upstream gguf

Michael Yang 1 月之前
父节点
当前提交
6b32a2d549
共有 3 个文件被更改,包括 14 次插入14 次删除
  1. 8 8
      convert/convert_gemma3.go
  2. 1 1
      model/models/gemma3/model.go
  3. 5 5
      model/models/gemma3/model_text.go

+ 8 - 8
convert/convert_gemma3.go

@@ -76,19 +76,19 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 	switch p.Architecture {
 	case "Gemma3ForCausalLM":
 		kv["gemma3.context_length"] = p.MaxPositionEmbeddings
-		kv["gemma3.text.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+		kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
 		kv["gemma3.attention.key_length"] = p.HeadDim
 		kv["gemma3.attention.value_length"] = p.HeadDim
-		kv["gemma3.text.attention.sliding_window"] = p.SlidingWindow
-		kv["gemma3.text.final_logit_softcapping"] = p.FinalLogitSoftcap
-		kv["gemma3.text.rope.local.freq_base"] = p.RopeLocalTheta
-		kv["gemma3.text.rope.global.freq_base"] = p.RopeGlobalTheta
+		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
+		kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
+		kv["gemma3.rope.local.freq_base"] = p.RopeLocalTheta
+		kv["gemma3.rope.global.freq_base"] = p.RopeGlobalTheta
 		kv["gemma3.embedding_length"] = p.HiddenSize
-		kv["gemma3.text.feed_forward_length"] = p.IntermediateSize
+		kv["gemma3.feed_forward_length"] = p.IntermediateSize
 	default:
 		kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
-		kv["gemma3.text.feed_forward_length"] = p.TextModel.IntermediateSize
-		kv["gemma3.text.attention.sliding_window"] = p.TextModel.SlidingWindow
+		kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
+		kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
 		kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
 		kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
 		kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize

+ 1 - 1
model/models/gemma3/model.go

@@ -62,7 +62,7 @@ func New(c ml.Config) (model.Model, error) {
 		TextModel:      newTextModel(c),
 	}
 
-	slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
+	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
 	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
 
 	return &m, nil

+ 5 - 5
model/models/gemma3/model_text.go

@@ -62,11 +62,11 @@ func newTextModel(c ml.Config) *TextModel {
 			numKVHeads:        int(c.Uint("attention.head_count_kv")),
 			attnKeyLen:        int(c.Uint("attention.key_length", 256)),
 			attnValLen:        int(c.Uint("attention.value_length", 256)),
-			eps:               c.Float("text.attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:     c.Float("text.rope.local.freq_base", 10000.0),
-			ropeGlobalBase:    c.Float("text.rope.global.freq_base", 1000000.0),
-			ropeScale:         c.Float("text.rope.freq_scale", 1.0),
-			finalLogitSoftcap: c.Float("text.final_logit_softcapping", 30.0),
+			eps:               c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+			ropeLocalBase:     c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase:    c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale:         c.Float("rope.freq_scale", 1.0),
+			finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
 		},
 	}