jmorganca 1 tháng trước cách đây
mục cha
commit
cfeca27133
2 tập tin đã thay đổi với 23 bổ sung58 xóa
  1. 1 4
      model/models/mistral3/model.go
  2. 22 54
      model/models/mistral3/model_vision.go

+ 1 - 4
model/models/mistral3/model.go

@@ -59,10 +59,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	// Create tensor from image data
 	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
 		m.ImageProcessor.imageSize,
-
-		// TODO (jmorganca): this should be returned from the
-		// image processor instead of hardcoded
-		1036,
+		1036, // TODO (jmorganca): this should be returned from ProcessImage
 		m.ImageProcessor.numChannels,
 	)
 	if err != nil {

+ 22 - 54
model/models/mistral3/model_vision.go

@@ -1,6 +1,7 @@
 package mistral3
 
 import (
+	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/ml"
@@ -55,11 +56,9 @@ type MultiModalProjector struct {
 
 func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
 	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
-	// fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
 	visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
-	// fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
-	visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
-	// fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+	visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
+	visionOutputs = visionOutputs.GELU(ctx)
 	return p.Linear2.Forward(ctx, visionOutputs)
 }
 
@@ -79,40 +78,20 @@ type VisionSelfAttention struct {
 }
 
 func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	headDim := opts.headDim
+	q := sa.Query.Forward(ctx, hiddenState)
+	k := sa.Key.Forward(ctx, hiddenState)
+	v := sa.Value.Forward(ctx, hiddenState)
 
-	// fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight))
+	q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
+	k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
+	v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
 
-	query := sa.Query.Forward(ctx, hiddenState)
-	key := sa.Key.Forward(ctx, hiddenState)
-	value := sa.Value.Forward(ctx, hiddenState)
+	ropeType := uint32(24) // 2d vision rope
+	q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
 
-	// fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query))
-	// fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key))
-	// fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value))
-
-	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
-	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
-	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
-
-	// fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query))
-	// fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key))
-	// fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value))
-	// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
-
-	// Multimodal rope
-	ropeType := uint32(24)
-	query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
-	key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
-
-	// fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query))
-	// fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key))
-
-	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
-	// fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
+	attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
-	// fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
-
 	return sa.Output.Forward(ctx, attention)
 }
 
@@ -130,22 +109,19 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
 type VisionEncoderLayer struct {
 	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
 	SelfAttention *VisionSelfAttention
-
-	FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
-	MLP     *VisionMLP
+	FFNNorm       *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP           *VisionMLP
 }
 
 func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
 	residual := hiddenState
 
-	// self attention
 	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	// fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
+	fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
 	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
-	// feed forward
 	hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
 	return hiddenState.Add(ctx, residual)
@@ -177,24 +153,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	numPatchesW := pixelValues.Dim(0) / m.patchSize
 	numPatches := numPatchesH * numPatchesW
 	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
-	// fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
 	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
-	// fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
 	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	// fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
-
-	// TODO: this seems to have incorrect output?
 	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
-	// fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
 
-	// Generate 4D position IDs (time, height, width, extra) for MROPE
-	var positions []int32
+	totalPositions := numPatchesH * numPatchesW
+	positions := make([]int32, totalPositions*4)
+
 	for h := 0; h < numPatchesH; h++ {
 		for w := 0; w < numPatchesW; w++ {
-			positions = append(positions, 0)        // unused
-			positions = append(positions, int32(h)) // height
-			positions = append(positions, int32(w)) // width
-			positions = append(positions, 0)        // unused
+			index := h*numPatchesW + w
+			positions[totalPositions+index] = int32(h)
+			positions[totalPositions*2+index] = int32(w)
 		}
 	}
 
@@ -203,8 +173,6 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 		panic(err)
 	}
 
-	// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
-
 	for _, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
 	}