|
@@ -1,6 +1,7 @@
|
|
|
package mistral3
|
|
|
|
|
|
import (
|
|
|
+ "fmt"
|
|
|
"math"
|
|
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
@@ -55,11 +56,9 @@ type MultiModalProjector struct {
|
|
|
|
|
|
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
|
|
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
|
|
|
- // fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
|
|
|
- // fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
- visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
|
|
|
- // fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
+ visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
|
|
|
+ visionOutputs = visionOutputs.GELU(ctx)
|
|
|
return p.Linear2.Forward(ctx, visionOutputs)
|
|
|
}
|
|
|
|
|
@@ -79,40 +78,20 @@ type VisionSelfAttention struct {
|
|
|
}
|
|
|
|
|
|
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
|
|
- headDim := opts.headDim
|
|
|
+ q := sa.Query.Forward(ctx, hiddenState)
|
|
|
+ k := sa.Key.Forward(ctx, hiddenState)
|
|
|
+ v := sa.Value.Forward(ctx, hiddenState)
|
|
|
|
|
|
- // fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight))
|
|
|
+ q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
|
|
|
+ k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
|
|
|
+ v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
|
|
|
|
|
|
- query := sa.Query.Forward(ctx, hiddenState)
|
|
|
- key := sa.Key.Forward(ctx, hiddenState)
|
|
|
- value := sa.Value.Forward(ctx, hiddenState)
|
|
|
+ ropeType := uint32(24) // 2d vision rope
|
|
|
+ q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
+ k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
|
|
|
- // fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query))
|
|
|
- // fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key))
|
|
|
- // fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value))
|
|
|
-
|
|
|
- query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
|
|
- key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
|
|
- value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
|
|
-
|
|
|
- // fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query))
|
|
|
- // fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key))
|
|
|
- // fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value))
|
|
|
- // fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
|
|
|
-
|
|
|
- // Multimodal rope
|
|
|
- ropeType := uint32(24)
|
|
|
- query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
- key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
-
|
|
|
- // fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query))
|
|
|
- // fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key))
|
|
|
-
|
|
|
- attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
|
|
- // fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
|
|
|
+ attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
|
|
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
|
|
- // fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
|
|
|
-
|
|
|
return sa.Output.Forward(ctx, attention)
|
|
|
}
|
|
|
|
|
@@ -130,22 +109,19 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
|
|
|
type VisionEncoderLayer struct {
|
|
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
|
SelfAttention *VisionSelfAttention
|
|
|
-
|
|
|
- FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
|
- MLP *VisionMLP
|
|
|
+ FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
|
+ MLP *VisionMLP
|
|
|
}
|
|
|
|
|
|
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
|
|
residual := hiddenState
|
|
|
|
|
|
- // self attention
|
|
|
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
|
- // fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
|
|
+ fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
|
|
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
|
|
|
hiddenState = hiddenState.Add(ctx, residual)
|
|
|
residual = hiddenState
|
|
|
|
|
|
- // feed forward
|
|
|
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
|
|
|
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
|
|
return hiddenState.Add(ctx, residual)
|
|
@@ -177,24 +153,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
|
|
numPatchesW := pixelValues.Dim(0) / m.patchSize
|
|
|
numPatches := numPatchesH * numPatchesW
|
|
|
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
|
|
- // fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
|
|
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
|
|
- // fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
|
|
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
|
- // fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
|
|
-
|
|
|
- // TODO: this seems to have incorrect output?
|
|
|
hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
|
|
|
- // fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
|
|
|
|
|
- // Generate 4D position IDs (time, height, width, extra) for MROPE
|
|
|
- var positions []int32
|
|
|
+ totalPositions := numPatchesH * numPatchesW
|
|
|
+ positions := make([]int32, totalPositions*4)
|
|
|
+
|
|
|
for h := 0; h < numPatchesH; h++ {
|
|
|
for w := 0; w < numPatchesW; w++ {
|
|
|
- positions = append(positions, 0) // unused
|
|
|
- positions = append(positions, int32(h)) // height
|
|
|
- positions = append(positions, int32(w)) // width
|
|
|
- positions = append(positions, 0) // unused
|
|
|
+ index := h*numPatchesW + w
|
|
|
+ positions[totalPositions+index] = int32(h)
|
|
|
+ positions[totalPositions*2+index] = int32(w)
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -203,8 +173,6 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
|
|
panic(err)
|
|
|
}
|
|
|
|
|
|
- // fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
|
|
|
-
|
|
|
for _, layer := range m.Layers {
|
|
|
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
|
|
|
}
|