فهرست منبع

mistral3 arch

Bruce MacDonald 1 ماه پیش
والد
کامیت
3b4ad00a4b
4فایلهای تغییر یافته به همراه87 افزوده شده و 44 حذف شده
  1. 1 1
      convert/convert.go
  2. 64 23
      convert/convert_mistral.go
  3. 11 11
      model/models/gemma3/model_text.go
  4. 11 9
      model/models/mistral/model.go

+ 1 - 1
convert/convert.go

@@ -185,7 +185,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	case "LlamaForCausalLM":
 		conv = &llamaModel{}
 	case "Mistral3ForConditionalGeneration":
-		conv = &mistralModel{}
+		conv = &mistral3Model{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}
 	case "GemmaForCausalLM":

+ 64 - 23
convert/convert_mistral.go

@@ -11,8 +11,11 @@ import (
 	"github.com/ollama/ollama/fs/ggml"
 )
 
-type mistralModel struct {
+type mistral3Model struct {
 	ModelParameters
+	// ImageTokenIndex  uint32 `json:"image_token_index"`
+	// SpatialMergeSize uint32 `json:"spatial_merge_size"`
+	// VisionFeatureLayer int32  `json:"vision_feature_layer"`
 	TextModel struct {
 		NumHiddenLayers       uint32  `json:"num_hidden_layers"`
 		MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
@@ -23,30 +26,62 @@ type mistralModel struct {
 		RopeTheta             float32 `json:"rope_theta"`
 		RMSNormEPS            float32 `json:"rms_norm_eps"`
 		HeadDim               uint32  `json:"head_dim"`
+		SlidingWindow         *uint32 `json:"sliding_window"`
+		HiddenAct             string  `json:"hidden_act"`
+		VocabSize             uint32  `json:"vocab_size"`
 	} `json:"text_config"`
+	// VisionModel struct {
+	// 	NumAttentionHeads uint32  `json:"num_attention_heads"`
+	// 	NumHiddenLayers   uint32  `json:"num_hidden_layers"`
+	// 	HiddenSize        uint32  `json:"hidden_size"`
+	// 	IntermediateSize  uint32  `json:"intermediate_size"`
+	// 	ImageSize         uint32  `json:"image_size"`
+	// 	NumChannels       uint32  `json:"num_channels"`
+	// 	PatchSize         uint32  `json:"patch_size"`
+	// 	HeadDim           uint32  `json:"head_dim"`
+	// 	HiddenAct         string  `json:"hidden_act"`
+	// 	RopeTheta         float32 `json:"rope_theta"`
+	// } `json:"vision_config"`
+	// MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
+	// ProjectorHiddenAct      string `json:"projector_hidden_act"`
 }
 
-func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
+func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
-	kv["general.architecture"] = "mistral"
-	kv["mistral.vocab_size"] = p.VocabSize
-
-	kv["mistral.block_count"] = p.TextModel.NumHiddenLayers
-	kv["mistral.context_length"] = p.TextModel.MaxPositionEmbeddings
-	kv["mistral.embedding_length"] = p.TextModel.HiddenSize
-	kv["mistral.feed_forward_length"] = p.TextModel.IntermediateSize
-	kv["mistral.attention.head_count"] = p.TextModel.NumAttentionHeads
-	kv["mistral.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
-	kv["mistral.rope.freq_base"] = p.TextModel.RopeTheta
-	kv["mistral.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
-	kv["mistral.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
-	kv["mistral.attention.key_length"] = p.TextModel.HeadDim
-	kv["mistral.attention.value_length"] = p.TextModel.HeadDim
+	kv["general.architecture"] = "mistral3"
+	kv["mistral3.vocab_size"] = p.TextModel.VocabSize
+
+	// Text configuration
+	kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
+	kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
+	kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
+	kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
+	kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
+	kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
+	kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
+	kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
+	kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
+	kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
+	kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
+
+	// Multimodal configuration
+	// kv["mistral3.image_token_index"] = p.ImageTokenIndex
+	// kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
+
+	// if p.VisionFeatureLayer != 0 {
+	// 	kv["mistral3.vision_feature_layer"] = p.VisionFeatureLayer
+	// }
+
+	// kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
+
+	// if p.ProjectorHiddenAct != "" {
+	// 	kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
+	// }
 
 	return kv
 }
 
-func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
+func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
 	var out []ggml.Tensor
 
 	for _, t := range ts {
@@ -55,10 +90,8 @@ func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
 			t.SetRepacker(p.repack)
 		}
 
-		if strings.HasPrefix(t.Name(), "patch_merger.") ||
-			strings.HasPrefix(t.Name(), "pre_mm_projector_output_norm.") ||
-			strings.HasPrefix(t.Name(), "vision_encoder.") ||
-			strings.HasPrefix(t.Name(), "vision_language_adapter.") {
+		// Skip certain vision model tensors that might need special handling
+		if strings.HasPrefix(t.Name(), "patch_merger.") || strings.HasPrefix(t.Name(), "pre_mm_projector_output_norm.") {
 			continue
 		}
 
@@ -73,8 +106,9 @@ func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
 	return out
 }
 
-func (p *mistralModel) Replacements() []string {
+func (p *mistral3Model) Replacements() []string {
 	return []string{
+		// Text model replacements
 		"model.layers", "blk",
 		"input_layernorm", "attn_norm",
 		"post_attention_layernorm", "ffn_norm",
@@ -121,14 +155,21 @@ func (p *mistralModel) Replacements() []string {
 		"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
 		"vision_tower.ln_pre", "v.encoder_norm",
 		"vision_tower.patch_conv", "v.patch_conv",
+		"vision_tower.embeddings", "v.embeddings",
+
+		// Alternative vision model paths
+		"vision_model.vision_model.embeddings", "v.embeddings",
+		"vision_model.vision_model", "v",
+		"vision_model.layers", "v.blk",
 
 		// Multimodal projector components
 		"multi_modal_projector.patch_merger", "mm.patch_merger",
 		"multi_modal_projector.norm", "mm.norm",
+		"multi_modal_projector.linear", "mm.projection",
 	}
 }
 
-func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
+func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
 	var dims []int
 	for _, dim := range shape {
 		dims = append(dims, int(dim))

+ 11 - 11
model/models/gemma3/model_text.go

@@ -10,7 +10,7 @@ import (
 	"github.com/ollama/ollama/model/input"
 )
 
-type TextOptions struct {
+type TextConfig struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
 	eps, ropeScale                   float32
@@ -27,7 +27,7 @@ type TextModel struct {
 	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
 	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
 
-	*TextOptions
+	*TextConfig
 }
 
 const (
@@ -55,7 +55,7 @@ func newTextModel(c ml.Config) *TextModel {
 			},
 		),
 		Layers: make([]TextLayer, numBlocks),
-		TextOptions: &TextOptions{
+		TextConfig: &TextConfig{
 			hiddenSize:     int(c.Uint("embedding_length")),
 			numHeads:       int(c.Uint("attention.head_count")),
 			numKVHeads:     int(c.Uint("attention.head_count_kv")),
@@ -84,7 +84,7 @@ type TextSelfAttention struct {
 	Output    *nn.Linear  `gguf:"attn_output"`
 }
 
-func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	ropeType := uint32(2)
 
@@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextOptions.ropeLocalBase
+	ropeBase := m.TextConfig.ropeLocalBase
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextOptions.ropeGlobalBase
+		ropeBase = m.TextConfig.ropeGlobalBase
 	}
 
-	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
 }
 
 type TextMLP struct {
@@ -134,7 +134,7 @@ type TextMLP struct {
 	Gate *nn.Linear `gguf:"ffn_gate"`
 }
 
-func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
+func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
 	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
@@ -148,7 +148,7 @@ type TextLayer struct {
 	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
 
 	// set image embeddings
 	var except []int
@@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 			lastLayerOutputs = outputs
 		}
 
-		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
+		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

+ 11 - 9
model/models/mistral/model.go

@@ -1,4 +1,4 @@
-package llama
+package mistral3
 
 import (
 	"fmt"
@@ -12,7 +12,7 @@ import (
 	"github.com/ollama/ollama/model/input"
 )
 
-type Options struct {
+type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads, headDim int
 	eps, ropeBase, ropeScale                  float32
 	ropeDim                                   uint32
@@ -27,7 +27,7 @@ type Model struct {
 	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
 	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
 
-	*Options
+	*TextOptions
 }
 
 func New(c ml.Config) (model.Model, error) {
@@ -49,7 +49,7 @@ func New(c ml.Config) (model.Model, error) {
 			},
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
-		Options: &Options{
+		TextOptions: &TextOptions{
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
@@ -74,7 +74,7 @@ type SelfAttention struct {
 	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
 }
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	ropeType := uint32(0)
 	// Get head dimension - use explicit value if available, otherwise calculate
@@ -119,7 +119,7 @@ type MLP struct {
 	Gate *nn.Linear `gguf:"ffn_gate"`
 }
 
-func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
 	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }
@@ -131,7 +131,7 @@ type Layer struct {
 	MLP           *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -168,8 +168,10 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
+	// Process text inputs
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 
+	// Process through text transformer layers
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
 
@@ -178,7 +180,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 			lastLayerOutputs = outputs
 		}
 
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.TextOptions)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -186,5 +188,5 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 }
 
 func init() {
-	model.Register("mistral", New)
+	model.Register("mistral3", New)
 }