Browse Source

use 2d pooling

Michael Yang 1 month ago
parent
commit
63a394068c
4 changed files with 36 additions and 25 deletions
  1. 14 9
      convert/convert_gemma3.go
  2. 1 1
      ml/backend.go
  3. 3 3
      ml/backend/ggml/ggml.go
  4. 18 12
      model/models/gemma3/model.go

+ 14 - 9
convert/convert_gemma3.go

@@ -26,15 +26,16 @@ type gemma3Model struct {
 		NumChannels       uint32  `json:"num_channels"`        // num_channels 3
 		PatchSize         uint32  `json:"patch_size"`          // patch_size 14
 	} `json:"vision_config"`
-	MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
-	NumAttentionHeads     uint32  `json:"num_attention_heads"`
-	NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
-	RMSNormEPS            float32 `json:"rms_norm_eps"`
-	HeadDim               uint32  `json:"head_dim"`
-	FinalLogitSoftcap     float32 `json:"final_logit_softcapping"`
-	RopeLocalTheta        float32 `json:"rope_local_base_freq"`
-	RopeGlobalTheta       float32 `json:"rope_global_base_freq"`
-	SlidingWindow         uint32  `json:"sliding_window"`
+	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
+	NumAttentionHeads        uint32  `json:"num_attention_heads"`
+	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
+	RMSNormEPS               float32 `json:"rms_norm_eps"`
+	HeadDim                  uint32  `json:"head_dim"`
+	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
+	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
+	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
+	SlidingWindow            uint32  `json:"sliding_window"`
+	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
 }

 const (
@@ -102,6 +103,10 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 		kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
 	}

+	if p.MultiModalTokensPerImage > 0 {
+		kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
+	}
+
 	return kv
 }


+ 1 - 1
ml/backend.go

@@ -135,7 +135,7 @@ type Tensor interface {
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor

-	AvgPool1D(ctx Context, k, s, p int) Tensor
+	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

 	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor

+ 3 - 3
ml/backend/ggml/ggml.go

@@ -247,7 +247,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			createTensor(tensor{source: t}, output.bts)
 		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
 			// TODO: assign vision tensors to the gpu if possible
-			createTensor(tensor{source: t}, input.bts)
+			createTensor(tensor{source: t}, output.bts)
 		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
 			// these tensors should be repeated per layer
 			for i, layer := range layers {
@@ -952,10 +952,10 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }

-func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
-		t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
+		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
 	}
 }


+ 18 - 12
model/models/gemma3/model.go

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"hash/fnv"
 	"image"
+	"math"
 
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml"
@@ -30,9 +31,21 @@ var _ model.MultimodalProcessor = (*Model)(nil)
 type MultiModalProjector struct {
 	SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
 	InputProjection *nn.Linear  `gguf:"mm_input_projection"`
+
+	tokensPerImage int
 }

-func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
+func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
+	l := visionOutputs.Dim(0)
+
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	patchesPerImage := imageSize / patchSize
+	visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
+
+	kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
+	visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
+	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)

 	// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
@@ -59,6 +72,9 @@ func New(c ml.Config) (model.Model, error) {
 		ImageProcessor: newImageProcessor(c),
 		VisionModel:    newVisionModel(c),
 		TextModel:      newTextModel(c),
+		MultiModalProjector: &MultiModalProjector{
+			tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
+		},
 	}

 	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
@@ -88,17 +104,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	}

 	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
-	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
-
-	// TODO (jmorganca): read this from the model config
-	// it should instead be math.Sqrt(tokens per image)
-	tokensPerSide := 8
-	kernelSize := patchesPerImage / tokensPerSide
-	visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
-
-	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
+	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
 	return visionOutputs, nil
 }