瀏覽代碼

convert: mistral-3.1-2503 text component

Bruce MacDonald 2 月之前
父節點
當前提交
edac05387f
共有 3 個文件被更改,包括 64 次插入25 次删除
  1. 1 1
      convert/convert.go
  2. 61 22
      convert/convert_mistral.go
  3. 2 2
      model/models/mistral/model.go

+ 1 - 1
convert/convert.go

@@ -184,7 +184,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	switch p.Architectures[0] {
 	case "LlamaForCausalLM":
 		conv = &llamaModel{}
-	case "MistralForCausalLM":
+	case "Mistral3ForConditionalGeneration":
 		conv = &mistralModel{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}

+ 61 - 22
convert/convert_mistral.go

@@ -13,15 +13,17 @@ import (
 
 type mistralModel struct {
 	ModelParameters
-	NumHiddenLayers       uint32  `json:"num_hidden_layers"`
-	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
-	HiddenSize            uint32  `json:"hidden_size"`
-	IntermediateSize      uint32  `json:"intermediate_size"`
-	NumAttentionHeads     uint32  `json:"num_attention_heads"`
-	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
-	RopeTheta             float32 `json:"rope_theta"`
-	RMSNormEPS            float32 `json:"rms_norm_eps"`
-	HeadDim               uint32  `json:"head_dim"`
+	TextModel struct {
+		NumHiddenLayers       uint32  `json:"num_hidden_layers"`
+		MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+		HiddenSize            uint32  `json:"hidden_size"`
+		IntermediateSize      uint32  `json:"intermediate_size"`
+		NumAttentionHeads     uint32  `json:"num_attention_heads"`
+		NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+		RopeTheta             float32 `json:"rope_theta"`
+		RMSNormEPS            float32 `json:"rms_norm_eps"`
+		HeadDim               uint32  `json:"head_dim"`
+	} `json:"text_config"`
 }
 
 func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
@@ -29,17 +31,17 @@ func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
 	kv["general.architecture"] = "mistral"
 	kv["mistral.vocab_size"] = p.VocabSize
 
-	kv["mistral.block_count"] = p.NumHiddenLayers
-	kv["mistral.context_length"] = p.MaxPositionEmbeddings
-	kv["mistral.embedding_length"] = cmp.Or(p.HiddenSize)
-	kv["mistral.feed_forward_length"] = cmp.Or(p.IntermediateSize)
-	kv["mistral.attention.head_count"] = cmp.Or(p.NumAttentionHeads)
-	kv["mistral.rope.dimension_count"] = p.HiddenSize / p.NumHiddenLayers
-	kv["mistral.rope.freq_base"] = p.RopeTheta
-	kv["mistral.attention.head_count_kv"] = p.NumKeyValueHeads
-	kv["mistral.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
-	kv["mistral.attention.key_length"] = p.HeadDim
-	kv["mistral.attention.value_length"] = p.HeadDim
+	kv["mistral.block_count"] = p.TextModel.NumHiddenLayers
+	kv["mistral.context_length"] = p.TextModel.MaxPositionEmbeddings
+	kv["mistral.embedding_length"] = p.TextModel.HiddenSize
+	kv["mistral.feed_forward_length"] = p.TextModel.IntermediateSize
+	kv["mistral.attention.head_count"] = p.TextModel.NumAttentionHeads
+	kv["mistral.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
+	kv["mistral.rope.freq_base"] = p.TextModel.RopeTheta
+	kv["mistral.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
+	kv["mistral.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
+	kv["mistral.attention.key_length"] = p.TextModel.HeadDim
+	kv["mistral.attention.value_length"] = p.TextModel.HeadDim
 
 	return kv
 }
@@ -86,6 +88,43 @@ func (p *mistralModel) Replacements() []string {
 		"mlp.down_proj", "ffn_down",
 		"mlp.gate_proj", "ffn_gate",
 		"mlp.up_proj", "ffn_up",
+
+		// Language model replacements
+		"language_model.model.embed_tokens", "token_embd",
+		"language_model.model.layers", "blk",
+		"language_model.model.layers.*.input_layernorm", "attn_norm",
+		"language_model.model.layers.*.self_attn.q_proj", "attn_q",
+		"language_model.model.layers.*.self_attn.k_proj", "attn_k",
+		"language_model.model.layers.*.self_attn.v_proj", "attn_v",
+		"language_model.model.layers.*.self_attn.o_proj", "attn_output",
+		"language_model.model.layers.*.mlp.gate_proj", "ffn_gate",
+		"language_model.model.layers.*.mlp.down_proj", "ffn_down",
+		"language_model.model.layers.*.mlp.up_proj", "ffn_up",
+		"language_model.model.layers.*.post_attention_layernorm", "ffn_norm",
+		"language_model.lm_head", "output",
+		"language_model.model.norm", "output_norm",
+
+		// Vision model replacements - map to shorter prefixes
+		"vision_tower", "v",
+		"multi_modal_projector", "mm",
+
+		// Vision transformer blocks - these should be updated accordingly
+		"vision_tower.transformer.layers", "v.blk",
+		"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
+		"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
+		"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
+		"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
+		"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
+		"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
+		"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
+		"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
+		"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
+		"vision_tower.ln_pre", "v.encoder_norm",
+		"vision_tower.patch_conv", "v.patch_conv",
+
+		// Multimodal projector components
+		"multi_modal_projector.patch_merger", "mm.patch_merger",
+		"multi_modal_projector.norm", "mm.norm",
 	}
 }
 
@@ -97,9 +136,9 @@ func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]fl
 
 	var heads uint32
 	if strings.HasSuffix(name, "attn_q.weight") {
-		heads = p.NumAttentionHeads
+		heads = p.TextModel.NumAttentionHeads
 	} else if strings.HasSuffix(name, "attn_k.weight") {
-		heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
+		heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
 	} else {
 		return nil, fmt.Errorf("unknown tensor for repack: %s", name)
 	}

+ 2 - 2
model/models/mistral/model.go

@@ -42,9 +42,9 @@ func New(c ml.Config) (model.Model, error) {
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
-				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
 				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
 				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
 			},
 		),