
split text model to its own file

Bruce MacDonald 1 month ago
parent
commit
ecc0ef468f
2 changed files with 213 additions and 144 deletions
  1. model/models/mistral3/model.go (+42, -144)
  2. model/models/mistral3/model_text.go (+171, -0)

+ 42 - 144
model/models/mistral3/model.go

@@ -1,157 +1,68 @@
 package mistral3
 
 import (
-	"fmt"
-	"math"
-	"strings"
-
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 )
 
-type TextOptions struct {
-	hiddenSize, numHeads, numKVHeads, headDim int
-	eps, ropeBase, ropeScale                  float32
-	ropeDim                                   uint32
-}
-
 type Model struct {
 	model.Base
-	model.BytePairEncoding
+	*TextModel
 
-	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
-	Layers         []Layer       `gguf:"blk"`
-	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
-	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
-
-	*TextOptions
-}
+	// TODO: Add VisionModel field
+	// *VisionModel `gguf:"v,vision"`
 
-func New(c ml.Config) (model.Model, error) {
-	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
-		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
-	}
-
-	m := Model{
-		BytePairEncoding: model.NewBytePairEncoding(
-			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
-			&model.Vocabulary{
-				Values: c.Strings("tokenizer.ggml.tokens"),
-				Types:  c.Uints("tokenizer.ggml.token_type"),
-				Merges: c.Strings("tokenizer.ggml.merges"),
-				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
-				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
-				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
-				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
-			},
-		),
-		Layers: make([]Layer, c.Uint("block_count")),
-		TextOptions: &TextOptions{
-			hiddenSize: int(c.Uint("embedding_length")),
-			numHeads:   int(c.Uint("attention.head_count")),
-			numKVHeads: int(c.Uint("attention.head_count_kv")),
-			headDim:    int(c.Uint("attention.key_length")),
-			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
-		},
-	}
+	// TODO: Add MultiModalProjector field for combining vision and text features
+	// *MultiModalProjector `gguf:"mm"`
 
-	m.Cache = kvcache.NewCausalCache(m.Shift)
-
-	return &m, nil
+	// TODO: Add ImageProcessor field
+	// ImageProcessor
 }
 
-type SelfAttention struct {
-	Query       *nn.Linear `gguf:"attn_q"`
-	Key         *nn.Linear `gguf:"attn_k"`
-	Value       *nn.Linear `gguf:"attn_v"`
-	Output      *nn.Linear `gguf:"attn_output"`
-	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
-}
+// TODO: Implement MultimodalProcessor interface
+// var _ model.MultimodalProcessor = (*Model)(nil)
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
-	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(0)
-	// Get head dimension - use explicit value if available, otherwise calculate
-	headDim := opts.headDim
-	if headDim == 0 {
-		headDim = opts.hiddenSize / opts.numHeads
+func New(c ml.Config) (model.Model, error) {
+	textModel, err := NewTextModel(c)
+	if err != nil {
+		return nil, err
 	}
 
-	// Query projection and reshape
-	q := sa.Query.Forward(ctx, hiddenState)
-	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
-
-	// Key projection and reshape
-	k := sa.Key.Forward(ctx, hiddenState)
-	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
-
-	// Value projection and reshape
-	v := sa.Value.Forward(ctx, hiddenState)
-	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-
-	// Attention computation
-	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+	m := &Model{
+		TextModel: textModel,
+		// TODO: Initialize VisionModel if present
+		// VisionModel: newVisionModel(c),
 
-	// Reshape attention output for final projection
-	outputDim := headDim * opts.numHeads
-	kqv = kqv.Reshape(ctx, outputDim, batchSize)
+		// TODO: Initialize ImageProcessor
+		// ImageProcessor: newImageProcessor(c),
 
-	// Apply output projection
-	return sa.Output.Forward(ctx, kqv)
-}
-
-func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
-}
-
-type MLP struct {
-	Up   *nn.Linear `gguf:"ffn_up"`
-	Down *nn.Linear `gguf:"ffn_down"`
-	Gate *nn.Linear `gguf:"ffn_gate"`
-}
-
-func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
-	return mlp.Down.Forward(ctx, hiddenState)
-}
-
-type Layer struct {
-	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
-	SelfAttention *SelfAttention
-	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP           *MLP
-}
-
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
-	residual := hiddenState
-
-	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
-
-	// In the final layer (outputs != nil), optimize by pruning to just the token positions
-	// we need logits for.
-	if outputs != nil {
-		hiddenState = hiddenState.Rows(ctx, outputs)
-		residual = residual.Rows(ctx, outputs)
+		// TODO: Initialize MultiModalProjector
+		// MultiModalProjector: &MultiModalProjector{...},
 	}
 
-	hiddenState = hiddenState.Add(ctx, residual)
-	residual = hiddenState
+	m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
 
-	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
-	return hiddenState.Add(ctx, residual)
+	return m, nil
 }
 
+// TODO: Implement EncodeMultimodal method for processing images
+// func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+//     // Check if vision model is available
+//     // Decode image
+//     // Process the image
+//     // Pass through vision model
+//     // Project vision outputs to text embedding space
+//     // Return vision embeddings
+// }
+
+// TODO: Implement PostTokenize method to handle vision tokens
+// func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+//     // Add special tokens around image data
+//     // Insert placeholders for image tokens
+// }
+
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
@@ -168,23 +79,10 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	// Process text inputs
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-
-	// Process through text transformer layers
-	for i, layer := range m.Layers {
-		m.Cache.SetLayer(i)
-
-		var lastLayerOutputs ml.Tensor
-		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
-		}
-
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.TextOptions)
-	}
+	// TODO: Add handling of multimodal inputs
+	// Set image embeddings into hidden state if present in opts.Multimodal
 
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return m.Output.Forward(ctx, hiddenState), nil
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
 }
 
 func init() {
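
The TODO comments above mark where the vision path will eventually plug into this file. Purely as a rough, non-authoritative sketch of that outline — VisionModel, ImageProcessor, MultiModalProjector, ProcessImage, and imageTokenID below are hypothetical placeholders taken from the TODO names, not APIs that exist in this commit — the two stubbed methods could take a shape like this:

package mistral3

import (
	"bytes"
	"errors"
	"image"
	_ "image/jpeg" // register decoders for image.Decode
	_ "image/png"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// imageTokenID is a stand-in; the real placeholder token would come from
// the tokenizer metadata.
const imageTokenID = 10

// EncodeMultimodal decodes the raw image bytes, runs them through the
// (assumed) vision tower, and projects the features into the text
// embedding space so they can be merged with token embeddings.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	if m.VisionModel == nil {
		return nil, errors.New("no vision model to encode image")
	}

	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}

	// ProcessImage is an assumed helper that resizes/normalizes the image
	// and uploads the pixel values as a tensor.
	pixelValues, err := m.ImageProcessor.ProcessImage(ctx, img)
	if err != nil {
		return nil, err
	}

	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
	return m.MultiModalProjector.Forward(ctx, visionOutputs), nil
}

// PostTokenize splices one placeholder token per vision-embedding column
// into the prompt, so Forward can later scatter the image features into
// the matching hidden-state positions.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
	var result []input.Input
	for _, inp := range inputs {
		if inp.Multimodal == nil {
			result = append(result, inp)
			continue
		}
		visionOutputs := inp.Multimodal.(ml.Tensor)
		for i := 0; i < visionOutputs.Dim(1); i++ {
			result = append(result, input.Input{Token: imageTokenID})
		}
	}
	return result, nil
}

Model.Forward's remaining TODO ("set image embeddings into hidden state if present") would then consume the positions PostTokenize reserves.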

+ 171 - 0
model/models/mistral3/model_text.go

@@ -0,0 +1,171 @@
+package mistral3
+
+import (
+	"fmt"
+	"math"
+	"strings"
+
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type TextOptions struct {
+	hiddenSize, numHeads, numKVHeads, headDim int
+	eps, ropeBase, ropeScale                  float32
+	ropeDim                                   uint32
+}
+
+type TextModel struct {
+	model.Base
+	model.BytePairEncoding
+
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	Layers         []Layer       `gguf:"blk"`
+	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
+	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
+
+	*TextOptions
+}
+
+type SelfAttention struct {
+	Query       *nn.Linear `gguf:"attn_q"`
+	Key         *nn.Linear `gguf:"attn_k"`
+	Value       *nn.Linear `gguf:"attn_v"`
+	Output      *nn.Linear `gguf:"attn_output"`
+	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
+}
+
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(0)
+	// Get head dimension - use explicit value if available, otherwise calculate
+	headDim := opts.headDim
+	if headDim == 0 {
+		headDim = opts.hiddenSize / opts.numHeads
+	}
+
+	// Query projection and reshape
+	q := sa.Query.Forward(ctx, hiddenState)
+	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+
+	// Key projection and reshape
+	k := sa.Key.Forward(ctx, hiddenState)
+	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+
+	// Value projection and reshape
+	v := sa.Value.Forward(ctx, hiddenState)
+	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+
+	// Attention computation
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+
+	// Reshape attention output for final projection
+	outputDim := headDim * opts.numHeads
+	kqv = kqv.Reshape(ctx, outputDim, batchSize)
+
+	// Apply output projection
+	return sa.Output.Forward(ctx, kqv)
+}
+
+func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+}
+
+type MLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+	Gate *nn.Linear `gguf:"ffn_gate"`
+}
+
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	return mlp.Down.Forward(ctx, hiddenState)
+}
+
+type Layer struct {
+	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention *SelfAttention
+	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP           *MLP
+}
+
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+	residual := hiddenState
+
+	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+
+	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+	return hiddenState.Add(ctx, residual)
+}
+
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+	// Process text inputs
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+
+	// Process through text transformer layers
+	for i, layer := range m.Layers {
+		cache.SetLayer(i)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	return m.Output.Forward(ctx, hiddenState)
+}
+
+func NewTextModel(c ml.Config) (*TextModel, error) {
+	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
+		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
+	}
+
+	textModel := &TextModel{
+		BytePairEncoding: model.NewBytePairEncoding(
+			c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				Merges: c.Strings("tokenizer.ggml.merges"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+			},
+		),
+		Layers: make([]Layer, c.Uint("block_count")),
+		TextOptions: &TextOptions{
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			headDim:    int(c.Uint("attention.key_length")),
+			eps:        c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:   c.Float("rope.freq_base"),
+			ropeScale:  c.Float("rope.freq_scale", 1),
+			ropeDim:    c.Uint("rope.dimension_count"),
+		},
+	}
+
+	return textModel, nil
+}
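
One note on the text blocks moved here: MLP.Forward implements the standard SwiGLU feed-forward, Down(SILU(Gate(x)) * Up(x)), where silu(x) = x·sigmoid(x). A toy, self-contained rendering of the element-wise core on plain float64 slices (illustrative only; the real code operates on ml.Tensor):

package main

import (
	"fmt"
	"math"
)

// silu is x * sigmoid(x), the activation applied by SILU(ctx) above.
func silu(x float64) float64 {
	return x / (1 + math.Exp(-x))
}

// swiGLU combines the gate and up projections element-wise:
// out[i] = silu(gate[i]) * up[i]. The down projection (a plain
// matmul) is left out to keep the example minimal.
func swiGLU(gate, up []float64) []float64 {
	out := make([]float64, len(gate))
	for i := range gate {
		out[i] = silu(gate[i]) * up[i]
	}
	return out
}

func main() {
	fmt.Println(swiGLU([]float64{1, -2, 0.5}, []float64{2, 3, 4}))
	// silu(1)≈0.7311 → 1.4622, silu(-2)≈-0.2384 → -0.7152, silu(0.5)≈0.3112 → 1.2449
}

Similarly, SelfAttention.Forward falls back to headDim = hiddenSize / numHeads when attention.key_length is absent from the model metadata, e.g. 4096 / 32 = 128.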