jmorganca 1 month ago
parent
commit
4530661799
3 changed files with 13 additions and 17 deletions
  1. 1 1
      convert/convert_mistral.go
  2. 1 4
      model/models/mistral3/model.go
  3. 11 12
      model/models/mistral3/model_vision.go

+ 1 - 1
convert/convert_mistral.go

@@ -138,10 +138,10 @@ func (p *mistral3Model) Replacements() []string {
 		"attention.v_proj", "attn_v",
 		"attention.o_proj", "attn_output",
 		"attention_norm", "attn_norm",
-		"feed_forward", "mlp",
 		"feed_forward.gate_proj", "ffn_gate",
 		"feed_forward.down_proj", "ffn_down",
 		"feed_forward.up_proj", "ffn_up",
+		"patch_merger.merging_layer", "merger",
 		"multi_modal_projector", "mm",
 		"ffn_norm", "ffn_norm",
 		"lm_head", "output",

+ 1 - 4
model/models/mistral3/model.go

@@ -2,7 +2,6 @@ package mistral3
 
 import (
 	"bytes"
-	"fmt"
 	"image"
 	"slices"
 
@@ -70,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues))
+	// fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues))
 
 	// Forward pass through vision model
 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
@@ -102,8 +101,6 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 		}
 	}
 
-	fmt.Println("post tokenize", "result", result)
-
 	return result, nil
 }
 

+ 11 - 12
model/models/mistral3/model_vision.go

@@ -1,7 +1,6 @@
 package mistral3
 
 import (
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/ml"
@@ -22,23 +21,23 @@ func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tenso
 	d := visionOutputs.Dim(0)
 
 	// TODO: handle multiple images, this currently assumes one
-	fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+	// fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
 
 	// Reshape to [h, w, hidden_size]
 	imageGrid := visionOutputs.Reshape(ctx, h, w, d)
-	fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
+	// fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
 
-	// TODO: load from ml.Config
+	// TODO: load from config
 	spatialMergeSize := 2
-	kernel := ctx.Output().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
-	fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
+	kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
+	// fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
 
 	patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
-	fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
+	// fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
 
-	fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
+	// fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
 	reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
-	fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
+	// fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
 
 	return pm.MergingLayer.Forward(ctx, reshaped)
 }
@@ -56,11 +55,11 @@ type MultiModalProjector struct {
 
 func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
 	visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
-	fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+	// fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
 	visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
-	fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+	// fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
 	visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
-	fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
+	// fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
 	return p.Linear2.Forward(ctx, visionOutputs)
 }