Browse Source

fix vision encoder

Michael Yang 1 month ago
parent
commit
f888912870
2 changed files with 11 additions and 5 deletions
  1. 1 1
      model/models/gemma3/model_text.go
  2. 10 4
      model/models/gemma3/process_image.go

+ 1 - 1
model/models/gemma3/model_text.go

@@ -180,7 +180,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	if multimodal != nil {
 		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
 		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
-		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
+		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
 	}
 
 	for i, layer := range m.Layers {

+ 10 - 4
model/models/gemma3/process_image.go

@@ -20,11 +20,11 @@ func newImageProcessor(c ml.Config) ImageProcessor {
 }
 
 func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
-	var pixelVals []float32
+	var pixelVals, rVals, gVals, bVals []float32
 
 	bounds := img.Bounds()
-	for x := bounds.Min.X; x < bounds.Max.X; x++ {
-		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
+	for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
+		for x := bounds.Min.X; x < bounds.Max.X; x++ {
 			c := img.At(x, y)
 			r, g, b, _ := c.RGBA()
 			rVal := float32(r>>8) / 255.0
@@ -35,10 +35,16 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
 			gVal = (gVal - mean[1]) / std[1]
 			bVal = (bVal - mean[2]) / std[2]
 
-			pixelVals = append(pixelVals, rVal, gVal, bVal)
+			rVals = append(rVals, rVal)
+			gVals = append(gVals, gVal)
+			bVals = append(bVals, bVal)
 		}
 	}
 
+	pixelVals = append(pixelVals, rVals...)
+	pixelVals = append(pixelVals, gVals...)
+	pixelVals = append(pixelVals, bVals...)
+
 	return pixelVals
 }