|
@@ -1,7 +1,6 @@
|
|
|
package mistral3
|
|
|
|
|
|
import (
|
|
|
- "fmt"
|
|
|
"math"
|
|
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
@@ -22,23 +21,23 @@ func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tenso
|
|
|
d := visionOutputs.Dim(0)
|
|
|
|
|
|
// TODO: handle multiple images, this currently assumes one
|
|
|
- fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
+ // fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
|
|
|
// Reshape to [h, w, hidden_size]
|
|
|
imageGrid := visionOutputs.Reshape(ctx, h, w, d)
|
|
|
- fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
|
|
|
+ // fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
|
|
|
|
|
|
- // TODO: load from ml.Config
|
|
|
+ // TODO: load from config
|
|
|
spatialMergeSize := 2
|
|
|
- kernel := ctx.Output().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
|
|
|
- fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
|
|
|
+ kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
|
|
|
+ // fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
|
|
|
|
|
|
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
|
|
|
- fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
|
|
|
+ // fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
|
|
|
|
|
|
- fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
|
|
|
+ // fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
|
|
|
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
|
|
|
- fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
|
|
|
+ // fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
|
|
|
|
|
|
return pm.MergingLayer.Forward(ctx, reshaped)
|
|
|
}
|
|
@@ -56,11 +55,11 @@ type MultiModalProjector struct {
|
|
|
|
|
|
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
|
|
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
|
|
|
- fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
+ // fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
|
|
|
- fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
+ // fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
|
|
|
- fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
+ // fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
|
|
return p.Linear2.Forward(ctx, visionOutputs)
|
|
|
}
|
|
|
|