Browse Source

Merge pull request #9661 from ollama/gemma

engine: add gemma support
Michael Yang 1 month ago
parent
commit
aee28501b5

+ 16 - 2
convert/convert.go

@@ -13,8 +13,13 @@ import (
 )
 
 type ModelParameters struct {
-	Architectures []string `json:"architectures"`
-	VocabSize     uint32   `json:"vocab_size"`
+	Architectures []string       `json:"architectures"`
+	VocabSize     uint32         `json:"vocab_size"`
+	TextModel     TextParameters `json:"text_config"`
+}
+
+type TextParameters struct {
+	VocabSize uint32 `json:"vocab_size"`
 }
 
 type AdapterParameters struct {
@@ -185,6 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 		conv = &gemmaModel{}
 	case "Gemma2ForCausalLM":
 		conv = &gemma2Model{}
+	case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
+		conv = &gemma3Model{Architecture: p.Architectures[0]}
 	case "Phi3ForCausalLM":
 		conv = &phi3Model{}
 	case "Qwen2ForCausalLM":
@@ -213,7 +220,14 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	}
 
 	vocabSize := int(p.VocabSize)
+	if vocabSize == 0 {
+		tVocabSize := int(p.TextModel.VocabSize)
+		vocabSize = tVocabSize
+	}
+
 	switch {
+	case vocabSize == 0:
+		slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
 	case vocabSize > len(t.Vocabulary.Tokens):
 		slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
 		for i := range vocabSize - len(t.Vocabulary.Tokens) {

+ 1 - 1
convert/convert_gemma.go

@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
 func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
 	var out []ggml.Tensor
 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "_norm.weight") {
+		if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
 			t.SetRepacker(p.addOne)
 		}
 

+ 142 - 0
convert/convert_gemma3.go

@@ -0,0 +1,142 @@
+package convert
+
+import (
+	"cmp"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
+
+type gemma3Model struct {
+	gemmaModel
+	Architecture string
+	TextModel    struct {
+		HeadDim          uint32 `json:"head_dim"`
+		HiddenSize       uint32 `json:"hidden_size"`
+		HiddenLayers     uint32 `json:"num_hidden_layers"`
+		IntermediateSize uint32 `json:"intermediate_size"`
+		SlidingWindow    uint32 `json:"sliding_window"`
+	} `json:"text_config"`
+	VisionModel struct {
+		NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
+		LayerNormEpsilon  float32 `json:"layer_norm_eps"`      // attention.layer_norm_epsilon 1e-05
+		NumHiddenLayers   uint32  `json:"num_hidden_layers"`   // block_count 32
+		HiddenSize        uint32  `json:"hidden_size"`         // embedding_length 1280
+		IntermediateSize  uint32  `json:"intermediate_size"`   // feed_forward_length 5120
+		ImageSize         uint32  `json:"image_size"`          // image_size 560
+		NumChannels       uint32  `json:"num_channels"`        // num_channels 3
+		PatchSize         uint32  `json:"patch_size"`          // patch_size 14
+	} `json:"vision_config"`
+	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
+	NumAttentionHeads        uint32  `json:"num_attention_heads"`
+	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
+	RMSNormEPS               float32 `json:"rms_norm_eps"`
+	HeadDim                  uint32  `json:"head_dim"`
+	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
+	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
+	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
+	SlidingWindow            uint32  `json:"sliding_window"`
+	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
+}
+
+const (
+	gemma4BLayerCount  = 34
+	gemma12BLayerCount = 48
+	gemma27BLayerCount = 62
+)
+
+func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
+	kv := p.ModelParameters.KV(t)
+	kv["general.architecture"] = "gemma3"
+
+	numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
+	kv["gemma3.block_count"] = numBlocks
+
+	var (
+		numHeads   uint32
+		numKVHeads uint32
+	)
+
+	switch numBlocks {
+	case gemma4BLayerCount:
+		numHeads = 8
+		numKVHeads = 4
+	case gemma12BLayerCount:
+		numHeads = 16
+		numKVHeads = 8
+	case gemma27BLayerCount:
+		numHeads = 32
+		numKVHeads = 16
+	default:
+		numHeads = p.NumAttentionHeads
+		numKVHeads = p.NumKeyValueHeads
+	}
+
+	kv["gemma3.attention.head_count"] = numHeads
+	kv["gemma3.attention.head_count_kv"] = numKVHeads
+
+	switch p.Architecture {
+	case "Gemma3ForCausalLM":
+		kv["gemma3.context_length"] = p.MaxPositionEmbeddings
+		kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+		kv["gemma3.attention.key_length"] = p.HeadDim
+		kv["gemma3.attention.value_length"] = p.HeadDim
+		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
+		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
+		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+		kv["gemma3.embedding_length"] = p.HiddenSize
+		kv["gemma3.feed_forward_length"] = p.IntermediateSize
+	default:
+		kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
+		kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
+		kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
+		kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
+		kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
+		kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
+		kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
+		kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
+		kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
+		kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
+		kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
+		kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
+		kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
+		kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
+	}
+
+	if p.MultiModalTokensPerImage > 0 {
+		kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
+	}
+
+	return kv
+}
+
+func (p *gemma3Model) Replacements() []string {
+	return []string{
+		"lm_head", "output",
+		"model.embed_tokens", "token_embd",
+		"model.norm", "output_norm",
+		"vision_tower.vision_model.embeddings", "v",
+		"vision_tower.vision_model", "v",
+		"vision_model.vision_model.embeddings", "v",
+		"vision_model.vision_model", "v",
+		"language_model.", "",
+		"model.layers", "blk",
+		"encoder.layers", "blk",
+		"input_layernorm", "attn_norm",
+		"self_attn.q_proj", "attn_q",
+		"self_attn.q_norm", "attn_q_norm",
+		"self_attn.k_proj", "attn_k",
+		"self_attn.k_norm", "attn_k_norm",
+		"self_attn.v_proj", "attn_v",
+		"self_attn.o_proj", "attn_output",
+		"self_attn.out_proj", "attn_output",
+		"mlp.gate_proj", "ffn_gate",
+		"mlp.down_proj", "ffn_down",
+		"mlp.up_proj", "ffn_up",
+		"post_attention_layernorm", "post_attention_norm",
+		"pre_feedforward_layernorm", "ffn_norm",
+		"post_feedforward_layernorm", "post_ffw_norm",
+		"input_projection_weight", "input_projection.weight",
+		"multi_modal_projector", "mm",
+	}
+}

+ 66 - 8
convert/tokenizer_spm.go

@@ -6,7 +6,9 @@ import (
 	"errors"
 	"fmt"
 	"io/fs"
+	"log/slog"
 	"os"
+	"reflect"
 	"slices"
 
 	"google.golang.org/protobuf/proto"
@@ -15,6 +17,8 @@ import (
 )
 
 func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
+	slog.Debug("using spm vocabulary")
+
 	ast, err := parseAdditionalSpecialTokens(fsys)
 	if err != nil {
 		return nil, err
@@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 			v.Types = append(v.Types, int32(t))
 		default:
 			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
-			if slices.Contains(ast, piece.GetPiece()) {
+
+			// temporary fix to handle gemma3 broken configs
+			if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
 				tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
 			}
 
+			for _, t := range ast {
+				if t.Content == piece.GetPiece() {
+					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
+					break
+				}
+			}
+
 			v.Types = append(v.Types, tt)
 		}
 	}
@@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 		return cmp.Compare(i.id, j.id)
 	})
 
-	n := len(v.Tokens)
-	for i, t := range ts {
-		if t.id != i+n {
-			return nil, fmt.Errorf("invalid token id: %d", t.id)
+	for _, t := range ts {
+		if t.id < len(v.Tokens) {
+			if v.Tokens[t.id] == t.content {
+				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
+				continue
+			}
+			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
+		}
+		if t.id != len(v.Tokens) {
+			return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
 		}
 
 		v.Tokens = append(v.Tokens, t.content)
@@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 	return &v, nil
 }
 
-func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
+type specialToken struct {
+	Content    string `json:"content"`
+	Lstrip     bool   `json:"lstrip"`
+	Normalized bool   `json:"normalized"`
+	Rstrip     bool   `json:"rstrip"`
+	SingleWord bool   `json:"single_word"`
+}
+
+func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
 	f, err := fsys.Open("special_tokens_map.json")
 	if errors.Is(err, os.ErrNotExist) {
 		return nil, nil
@@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
 	defer f.Close()
 
 	var m struct {
-		AdditionalSpecialTokens []string `json:"additional_special_tokens"`
+		AdditionalSpecialTokens any `json:"additional_special_tokens"`
 	}
 
 	if err := json.NewDecoder(f).Decode(&m); err != nil {
 		return nil, err
 	}
 
-	return m.AdditionalSpecialTokens, nil
+	var ast []specialToken
+
+	switch st := m.AdditionalSpecialTokens.(type) {
+	case []string:
+		for _, s := range st {
+			ast = append(ast, specialToken{Content: s})
+		}
+	case []any:
+		for _, s := range st {
+			// marshal and unmarshal the object to get the special token
+			tMap := s.(map[string]any)
+			data, err := json.Marshal(tMap)
+			if err != nil {
+				return nil, err
+			}
+
+			var token specialToken
+			err = json.Unmarshal(data, &token)
+			if err != nil {
+				return nil, err
+			}
+
+			ast = append(ast, token)
+		}
+
+	default:
+		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
+	}
+
+	slog.Debug("spm tokenizer", "additional tokens", ast)
+
+	return ast, nil
 }

+ 14 - 1
fs/ggml/ggml.go

@@ -124,6 +124,19 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }
 
+func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
+	r := keyValue(kv, key, &array{})
+	s := make([]float32, r.size)
+	for i := range r.size {
+		s[i] = float32(r.values[i].(float32))
+	}
+	return s
+}
+
+func (kv KV) OllamaEngineRequired() bool {
+	return kv.Architecture() == "gemma3"
+}
+
 func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
@@ -476,7 +489,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 			// vocab graph
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
-	case "gemma", "gemma2":
+	case "gemma", "gemma2", "gemma3":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
 			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),

+ 14 - 12
kvcache/causal.go

@@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 type Causal struct {
 	DType      ml.DType
 	Capacity   int32
-	causal     bool
 	windowSize int32
 
+	opts CausalOptions
+
 	// config controls mostly backend-specific optimizations
 	config *ml.CacheConfig
 
@@ -79,7 +80,6 @@ type cellRange struct {
 
 func NewCausalCache(shift shiftFn) *Causal {
 	return &Causal{
-		causal:     true,
 		windowSize: math.MaxInt32,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
 
 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	return &Causal{
-		causal:     true,
 		windowSize: windowSize,
 		shiftFn:    shift,
 		ctxs:       make(map[int]ml.Context),
@@ -145,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	c.curBatchSize = len(opts.Positions)
 	c.curSequences = opts.Sequences
 	c.curPositions = opts.Positions
+	c.opts.Except = nil
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -235,9 +235,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
+		enabled := !slices.Contains(c.opts.Except, i)
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-				(c.causal && c.cells[j].pos > c.curPositions[i]) ||
+				(enabled && c.cells[j].pos > c.curPositions[i]) ||
 				c.cells[j].pos < c.curPositions[i]-c.windowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
@@ -404,15 +405,16 @@ func (c *Causal) SetLayer(layer int) {
 	c.curLayer = layer
 }
 
-// SetCausal enables or disables causal mask generation for subsequent calls to Get.
-// This state carries over to future forward passes. The default value is true.
-//
-// ctx may be set to nil if this is called from outside of a forward pass, for
-// example, when initializing the cache.
-func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
-	if c.causal != causal {
-		c.causal = causal
+type CausalOptions struct {
+	// Enabled controls whether the causal mask is generated for a particular index in a batch
+	Except []int
+}
 
+// SetCausal disables causal mask generation for a particular range of indicies in
+// the current batch for subsequent calls to Get. The state resets for the next forward pass.
+func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
+	if !slices.Equal(c.opts.Except, opts.Except) {
+		c.opts = opts
 		if ctx != nil {
 			var err error
 			c.curMask, err = c.buildMask(ctx)

+ 13 - 1
kvcache/causal_test.go

@@ -441,11 +441,19 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
 	panic("not implemented")
 }
 
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+	panic("not implemented")
+}
+
+func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	panic("not implemented")
 }
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
 	panic("not implemented")
 }
 
@@ -495,6 +503,10 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
 	panic("not implemented")
 }
 
+func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
 	panic("not implemented")
 }

+ 33 - 0
llama/patches/0020-ollama-debug-tensor.patch

@@ -0,0 +1,33 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Michael Yang <mxyng@pm.me>
+Date: Sun, 9 Mar 2025 14:44:16 -0700
+Subject: [PATCH] ollama debug tensor
+
+---
+ ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++
+ 1 file changed, 6 insertions(+)
+
+diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
+index 2f606d82..ec60e8fc 100644
+--- a/ggml/src/ggml-cpu/ggml-cpu.c
++++ b/ggml/src/ggml-cpu/ggml-cpu.c
+@@ -11,6 +11,8 @@
+ #include "ggml-threading.h"
+ #include "ggml.h"
+ 
++#include "ollama-debug.h"
++
+ #if defined(_MSC_VER) || defined(__MINGW32__)
+ #include <malloc.h> // using malloc.h with MSC/MINGW
+ #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
+@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
+ 
+         ggml_compute_forward(&params, node);
+ 
++#ifdef OLLAMA_DEBUG
++        ollama_debug(node, true);
++#endif
++
+         if (state->ith == 0 && cplan->abort_callback &&
+                 cplan->abort_callback(cplan->abort_callback_data)) {
+             atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

+ 1 - 1
llm/server.go

@@ -271,7 +271,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 
 	var llamaModel *llama.Model
 	var textProcessor model.TextProcessor
-	if envconfig.NewEngine() {
+	if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
 		textProcessor, err = model.NewTextProcessor(modelPath)
 		if err != nil {
 			// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner

+ 5 - 1
ml/backend.go

@@ -19,6 +19,7 @@ type Config interface {
 
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32
+	Floats(string, ...[]float32) []float32
 }
 
 type Backend interface {
@@ -134,8 +135,10 @@ type Tensor interface {
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
 
+	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
+
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
 
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
@@ -145,6 +148,7 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context) Tensor
+	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
 
 	Pad(ctx Context, shape ...int) Tensor
 	Unpad(ctx Context, shape ...int) Tensor

+ 43 - 14
ml/backend/ggml/ggml.go

@@ -240,11 +240,22 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		switch {
 		case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
 			createTensor(tensor{source: t}, input.bts)
+			if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
+				createTensor(tensor{source: t, target: "output.weight"}, output.bts)
+			}
 		case contains(t.Name, "cls", "output", "output_norm"):
 			createTensor(tensor{source: t}, output.bts)
 		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
 			// TODO: assign vision tensors to the gpu if possible
-			createTensor(tensor{source: t}, input.bts)
+			createTensor(tensor{source: t}, output.bts)
+		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
+			// these tensors should be repeated per layer
+			for i, layer := range layers {
+				createTensor(tensor{
+					source: t,
+					target: "blk." + strconv.Itoa(i) + "." + t.Name,
+				}, layer.bts)
+			}
 		default:
 			layerIndex := -1
 			if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
@@ -256,14 +267,8 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			if layerIndex >= 0 {
 				createTensor(tensor{source: t}, layers[layerIndex].bts)
 			} else {
-				// this is a repeating tensor that doesn't explicitly associated with a layer so
-				// duplicate it for each layer
-				for i, layer := range layers {
-					createTensor(tensor{
-						source: t,
-						target: "blk." + strconv.Itoa(i) + "." + t.Name,
-					}, layer.bts)
-				}
+				// load all other tensors on the cpu
+				createTensor(tensor{source: t}, input.bts)
 			}
 		}
 	}
@@ -352,7 +357,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 
 		if C.ggml_backend_is_cpu(b) {
 			// set number of threads for cpu backend
-			C.ggml_backend_cpu_set_n_threads(b, C.int(params.NumThreads))
+			C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
 		}
 	}
 
@@ -893,10 +898,13 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 }
 
 const (
-	ropeTypeNorm C.int = iota
+	ropeTypeNorm   C.int = 0
+	ropeTypeNeox   C.int = 2
+	ropeTypeMrope  C.int = 8
+	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -911,8 +919,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		t: C.ggml_rope_ext(
 			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
-			131072,       // YaRN n_ctx_train
-			ropeTypeNorm, // ROPE_TYPE_NORM
+			C.int(ropeType),
+			131072, // YaRN n_ctx_train
 			C.float(ropeBase),
 			C.float(ropeScale),
 			0.,  // YaRN ext_factor
@@ -944,6 +952,27 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
+func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
+	}
+}
+
+func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
+	var tt *C.struct_ggml_tensor
+	switch len(strides) {
+	case 0:
+		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
+	case 1:
+		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
+	default:
+		panic("unsupported number of dimensions")
+	}
+
+	return &Tensor{b: t.b, t: tt}
+}
+
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {

+ 11 - 0
ml/backend/ggml/ggml/include/ollama-debug.h

@@ -0,0 +1,11 @@
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void ollama_debug(const struct ggml_tensor *tensor, bool verbose);
+
+#ifdef __cplusplus
+}
+#endif

+ 6 - 0
ml/backend/ggml/ggml/src/ggml-cpu/cpu_debug.go

@@ -0,0 +1,6 @@
+//go:build debug
+
+package cpu
+
+// #cgo CPPFLAGS: -DOLLAMA_DEBUG
+import "C"

+ 6 - 0
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c

@@ -11,6 +11,8 @@
 #include "ggml-threading.h"
 #include "ggml.h"
 
+#include "ollama-debug.h"
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         ggml_compute_forward(&params, node);
 
+#ifdef OLLAMA_DEBUG
+        ollama_debug(node, true);
+#endif
+
         if (state->ith == 0 && cplan->abort_callback &&
                 cplan->abort_callback(cplan->abort_callback_data)) {
             atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

+ 115 - 0
ml/backend/ggml/ggml/src/ollama-debug.c

@@ -0,0 +1,115 @@
+#include <string.h>
+
+#include "ollama-debug.h"
+
+static int mul(int64_t *dims, int ndims) {
+    int result = 1;
+    for (int i = 0; i < ndims; i++) {
+        result *= dims[i];
+    }
+
+    return result;
+}
+
+static void repeat(char c, int n) {
+    for (int i = 0; i < n; i++) {
+        fprintf(stderr, "%c", c);
+    }
+}
+
+static void print_tensor(const void *tensor, void (*cb)(const void *, int),
+                         int shape,
+                         int64_t *dims, int ndims, int stride,
+                         int nitems, int pad) {
+    fprintf(stderr, "[");
+    for (int i = 0; i < dims[0]; i++) {
+        if (i >= nitems && i < dims[0] - nitems) {
+            fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
+            int skip = dims[0] - 2 * nitems;
+            if (ndims > 1) {
+                stride += mul(dims + 1, ndims - 1) * skip;
+                repeat('\n', ndims - 1);
+                repeat(' ', shape - ndims + 1 + pad);
+            }
+            i += skip - 1;
+        } else if (ndims > 1) {
+            print_tensor(tensor, cb, shape, dims + 1, ndims - 1, stride,
+                         nitems, pad);
+            stride += mul(dims + 1, ndims - 1);
+            if (i < dims[0] - 1) {
+                fprintf(stderr, ", ");
+                repeat('\n', ndims - 1);
+                repeat(' ', shape - ndims + 1 + pad);
+            }
+        } else {
+            cb(tensor, stride + i);
+            if (i < dims[0] - 1) {
+                fprintf(stderr, ", ");
+            }
+        }
+    }
+    fprintf(stderr, "]");
+}
+
+static void print_tensor_f16(const void *tensor, int i) {
+    float value = ggml_fp16_to_fp32(((const ggml_fp16_t *)tensor)[i]);
+    fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
+}
+
+static void print_tensor_f32(const void *tensor, int i) {
+    float value = ((const float *)tensor)[i];
+    fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
+}
+
+static void print_tensor_i32(const void *tensor, int i) {
+    int32_t value = ((const int32_t *)tensor)[i];
+    fprintf(stderr, "%s%d", value < 0 ? "" : " ", value);
+}
+
+static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
+    fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
+            ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
+            tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+
+    if (!verbose) {
+        return;
+    }
+
+    for (int i = 0; i < indent; i++) {
+        fprintf(stderr, " ");
+    }
+
+    switch (tensor->type) {
+    case GGML_TYPE_F16:
+        print_tensor(ggml_get_data(tensor), print_tensor_f16, ggml_n_dims(tensor),
+                     (int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
+        break;
+    case GGML_TYPE_F32:
+        print_tensor(ggml_get_data(tensor), print_tensor_f32, ggml_n_dims(tensor),
+                     (int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
+        break;
+    case GGML_TYPE_I32:
+        print_tensor(ggml_get_data(tensor), print_tensor_i32, ggml_n_dims(tensor),
+                     (int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
+        break;
+    default:
+        fprintf(stderr, "<unsupported type>\n");
+        return;
+    }
+
+    fprintf(stderr, "\n");
+}
+
+void ollama_debug(const struct ggml_tensor *tensor, bool verbose) {
+    ollama_debug_tensor(tensor, verbose, ">>> ", 4);
+
+    for (int i = 0; i < GGML_MAX_SRC && tensor->src[i] != NULL; ++i) {
+        char src[8];
+        const int n = snprintf(src, sizeof(src), " src%d ", i);
+        if (n >= sizeof(src)) {
+            src[sizeof(src) - 1] = '\0';
+        }
+
+        ollama_debug_tensor(tensor->src[i], verbose, src, 4);
+    }
+}

+ 7 - 0
ml/backend/ggml/threads.go

@@ -0,0 +1,7 @@
+//go:build !debug
+
+package ggml
+
+func Threads(n int) int {
+	return n
+}

+ 7 - 0
ml/backend/ggml/threads_debug.go

@@ -0,0 +1,7 @@
+//go:build debug
+
+package ggml
+
+func Threads(_ int) int {
+	return 1
+}

+ 220 - 0
model/models/gemma2/model.go

@@ -0,0 +1,220 @@
+package gemma2
+
+import (
+	"math"
+
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Options struct {
+	hiddenSize, numHeads, numKVHeads int
+	attnKeyLen, attnValLen           int
+	eps, ropeBase, ropeScale         float32
+	attnLogitSoftcap                 float32
+	finalLogitSoftcap                float32
+	largeModelScaling                bool
+}
+
+type Model struct {
+	model.Base
+	model.SentencePieceModel
+
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	Layers         []Layer       `gguf:"blk"`
+	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
+	Output         *nn.Linear    `gguf:"output,alt:token_embd"` // just set to token_embd?
+
+	*Options
+}
+
+const (
+	gemma27BLayerCount = 46
+)
+
+func New(c ml.Config) (model.Model, error) {
+	m := Model{
+		SentencePieceModel: model.NewSentencePieceModel(
+			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+			},
+		),
+		Layers: make([]Layer, c.Uint("block_count")),
+		Options: &Options{
+			hiddenSize:        int(c.Uint("embedding_length")),
+			numHeads:          int(c.Uint("attention.head_count")),
+			numKVHeads:        int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:        int(c.Uint("attention.key_length")),
+			attnValLen:        int(c.Uint("attention.value_length")),
+			eps:               c.Float("attention.layer_norm_rms_epsilon"),
+			ropeBase:          c.Float("rope.freq_base", 10000.0),
+			ropeScale:         c.Float("rope.freq_scale", 1.0),
+			attnLogitSoftcap:  c.Float("attn_logit_softcapping"),
+			finalLogitSoftcap: c.Float("final_logit_softcapping"),
+		},
+	}
+
+	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+	m.Cache.SetConfig(ml.CacheConfig{})
+
+	return &m, nil
+}
+
+type SelfAttention struct {
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
+
+	q := sa.Query.Forward(ctx, hiddenState)
+	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+
+	if opts.largeModelScaling {
+		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
+	} else {
+		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
+	}
+
+	k := sa.Key.Forward(ctx, hiddenState)
+	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+
+	v := sa.Value.Forward(ctx, hiddenState)
+	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
+
+	cache.Put(ctx, k, v)
+	k, v, mask := cache.Get(ctx)
+
+	q = q.Permute(ctx, 0, 2, 1, 3)
+	k = k.Permute(ctx, 0, 2, 1, 3)
+	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	kq := k.Mulmat(ctx, q)
+
+	// logit softcap
+	kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
+	kq = kq.Tanh(ctx)
+	kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
+
+	kq = kq.Add(ctx, mask)
+	kq = kq.Softmax(ctx)
+
+	kqv := v.Mulmat(ctx, kq)
+	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
+
+	return sa.Output.Forward(ctx, kqv)
+}
+
+func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+}
+
+type MLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+	Gate *nn.Linear `gguf:"ffn_gate"`
+}
+
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	return mlp.Down.Forward(ctx, hiddenState)
+}
+
+type Layer struct {
+	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention     *SelfAttention
+	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
+	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP               *MLP
+	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
+}
+
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+	residual := hiddenState
+
+	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
+	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+
+	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
+	return hiddenState.Add(ctx, residual)
+}
+
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+	if err != nil {
+		return nil, err
+	}
+
+	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	if err != nil {
+		return nil, err
+	}
+
+	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
+
+	if len(m.Layers) == gemma27BLayerCount {
+		m.Options.largeModelScaling = true
+	}
+
+	for i, layer := range m.Layers {
+		cacheType := i % 2
+		m.Cache.SetLayer(i)
+		wc := m.Cache.(*kvcache.WrapperCache)
+		wc.SetLayerType(cacheType)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	hiddenState = m.Output.Forward(ctx, hiddenState)
+
+	// final logit softcap
+	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
+	hiddenState = hiddenState.Tanh(ctx)
+	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
+	return hiddenState.Rows(ctx, outputs), nil
+}
+
+func init() {
+	model.Register("gemma2", New)
+}

+ 173 - 0
model/models/gemma3/model.go

@@ -0,0 +1,173 @@
+package gemma3
+
+import (
+	"bytes"
+	"encoding/binary"
+	"hash/fnv"
+	"image"
+	"math"
+
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Model struct {
+	model.Base
+	model.SentencePieceModel
+
+	*VisionModel `gguf:"v,vision"`
+	*TextModel
+
+	*MultiModalProjector `gguf:"mm"`
+
+	ImageProcessor
+}
+
+var _ model.MultimodalProcessor = (*Model)(nil)
+
+type MultiModalProjector struct {
+	SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
+	InputProjection *nn.Linear  `gguf:"mm_input_projection"`
+
+	tokensPerImage int
+}
+
+func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
+	l := visionOutputs.Dim(0)
+
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	patchesPerImage := imageSize / patchSize
+	visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
+
+	kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
+	visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
+	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
+
+	// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
+	visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
+	return visionOutputs
+}
+
+func New(c ml.Config) (model.Model, error) {
+	m := Model{
+		SentencePieceModel: model.NewSentencePieceModel(
+			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				EOS:    int32(1),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+				EOT:    int32(106),
+				AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
+			},
+		),
+		ImageProcessor: newImageProcessor(c),
+		VisionModel:    newVisionModel(c),
+		TextModel:      newTextModel(c),
+		MultiModalProjector: &MultiModalProjector{
+			tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
+		},
+	}
+
+	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+
+	return &m, nil
+}
+
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+	image, _, err := image.Decode(bytes.NewReader(multimodalData))
+	if err != nil {
+		return nil, err
+	}
+
+	f32s, err := m.ImageProcessor.ProcessImage(image)
+	if err != nil {
+		return nil, err
+	}
+
+	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.numChannels,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
+	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
+	return visionOutputs, nil
+}
+
+type imageToken struct {
+	embedding ml.Tensor
+	index     int
+}
+
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+	var result []input.Input
+	fnvHash := fnv.New64a()
+
+	for _, inp := range inputs {
+		if inp.Multimodal == nil {
+			result = append(result, inp)
+		} else {
+			imageInputs := []input.Input{
+				{Token: 108},    // "\n\n"
+				{Token: 255999}, // "<start_of_image>""
+			}
+			result = append(result, imageInputs...)
+
+			// add image embeddings
+			inputMultimodal := inp.Multimodal.(ml.Tensor)
+
+			for i := range inputMultimodal.Dim(1) {
+				fnvHash.Reset()
+				binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
+				fnvHash.Write([]byte{byte(i)})
+
+				imageToken := imageToken{embedding: inputMultimodal, index: i}
+				result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
+			}
+
+			result = append(result,
+				input.Input{Token: 256000}, // <end_of_image>
+				input.Input{Token: 108},    // "\n\n"
+			)
+		}
+	}
+
+	return result, nil
+}
+
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+	if err != nil {
+		return nil, err
+	}
+
+	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	if err != nil {
+		return nil, err
+	}
+
+	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+}
+
+func init() {
+	model.Register("gemma3", New)
+}

+ 254 - 0
model/models/gemma3/model_text.go

@@ -0,0 +1,254 @@
+package gemma3
+
+import (
+	"math"
+
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type TextOptions struct {
+	hiddenSize, numHeads, numKVHeads int
+	attnKeyLen, attnValLen           int
+	eps, ropeScale                   float32
+	ropeLocalBase, ropeGlobalBase    float32
+	finalLogitSoftcap                float32
+	largeModelScaling                bool
+}
+
+type TextModel struct {
+	model.Base
+	model.SentencePieceModel
+
+	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+	Layers         []TextLayer   `gguf:"blk"`
+	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
+	Output         *nn.Linear    `gguf:"output,alt:token_embd"`
+
+	*TextOptions
+}
+
+const (
+	gemmaGlobalCacheCount = 6
+	gemma27BLayerCount    = 62
+)
+
+const (
+	cacheTypeSWA = iota
+	cacheTypeCausal
+)
+
+func newTextModel(c ml.Config) *TextModel {
+	numBlocks := int(c.Uint("block_count"))
+
+	m := TextModel{
+		SentencePieceModel: model.NewSentencePieceModel(
+			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
+				Types:  c.Uints("tokenizer.ggml.token_type"),
+				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+			},
+		),
+		Layers: make([]TextLayer, numBlocks),
+		TextOptions: &TextOptions{
+			hiddenSize:        int(c.Uint("embedding_length")),
+			numHeads:          int(c.Uint("attention.head_count")),
+			numKVHeads:        int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:        int(c.Uint("attention.key_length", 256)),
+			attnValLen:        int(c.Uint("attention.value_length", 256)),
+			eps:               c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+			ropeLocalBase:     c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase:    c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale:         c.Float("rope.freq_scale", 1.0),
+			finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
+		},
+	}
+
+	if numBlocks == gemma27BLayerCount {
+		m.largeModelScaling = true
+	}
+
+	return &m
+}
+
+type TextSelfAttention struct {
+	Query     *nn.Linear  `gguf:"attn_q"`
+	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
+	Key       *nn.Linear  `gguf:"attn_k"`
+	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
+	Value     *nn.Linear  `gguf:"attn_v"`
+	Output    *nn.Linear  `gguf:"attn_output"`
+}
+
+func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+	batchSize := hiddenState.Dim(1)
+	ropeType := uint32(2)
+
+	ropeBase := opts.ropeLocalBase
+	if (layer+1)%gemmaGlobalCacheCount == 0 {
+		ropeBase = opts.ropeGlobalBase
+	}
+
+	q := sa.Query.Forward(ctx, hiddenState)
+	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
+	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
+	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+
+	if opts.largeModelScaling {
+		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
+	} else {
+		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
+	}
+
+	k := sa.Key.Forward(ctx, hiddenState)
+	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
+	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
+	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+
+	v := sa.Value.Forward(ctx, hiddenState)
+	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
+
+	scaleFactor := 1.0
+	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
+
+	return sa.Output.Forward(ctx, kqv)
+}
+
+func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	ropeBase := m.TextOptions.ropeLocalBase
+	if (layer+1)%gemmaGlobalCacheCount == 0 {
+		ropeBase = m.TextOptions.ropeGlobalBase
+	}
+
+	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
+}
+
+type TextMLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+	Gate *nn.Linear `gguf:"ffn_gate"`
+}
+
+func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	return mlp.Down.Forward(ctx, hiddenState)
+}
+
+type TextLayer struct {
+	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
+	SelfAttention     *TextSelfAttention
+	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
+	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
+	MLP               *TextMLP
+	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
+}
+
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+	residual := hiddenState
+
+	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
+	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+
+	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
+	return hiddenState.Add(ctx, residual)
+}
+
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
+	var embedding ml.Tensor
+	var src, dst, length int
+	var except []int
+
+	for _, image := range multimodal {
+		imageToken := image.Multimodal.(imageToken)
+		imageSrc := imageToken.index
+		imageDst := image.Index
+
+		if embedding == nil {
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
+			src = imageSrc
+			dst = imageDst
+			length++
+		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
+			length++
+		} else {
+			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
+
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		}
+
+		except = append(except, imageDst)
+	}
+
+	if embedding != nil {
+		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
+	}
+
+	return except
+}
+
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
+
+	for i, layer := range m.Layers {
+		// gemma alternates between the sliding window (local) and causal (global)
+		// kv cache every 6 layers
+		cacheType := cacheTypeSWA
+		if (i+1)%gemmaGlobalCacheCount == 0 {
+			cacheType = cacheTypeCausal
+		}
+		cache.SetLayer(i)
+		wc := cache.(*kvcache.WrapperCache)
+		wc.SetLayerType(cacheType)
+
+		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
+		}
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	hiddenState = m.Output.Forward(ctx, hiddenState)
+
+	// final logit softcap
+	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
+	hiddenState = hiddenState.Tanh(ctx)
+	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
+}

+ 127 - 0
model/models/gemma3/model_vision.go

@@ -0,0 +1,127 @@
+package gemma3
+
+import (
+	"math"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+)
+
+var batchSize int = 1
+
+type VisionSelfAttention struct {
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	headDim := opts.hiddenSize / opts.numHeads
+
+	query := sa.Query.Forward(ctx, hiddenState)
+	key := sa.Key.Forward(ctx, hiddenState)
+	value := sa.Value.Forward(ctx, hiddenState)
+
+	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
+	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
+
+	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
+	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
+
+	hiddenState = sa.Output.Forward(ctx, attention)
+	return hiddenState
+}
+
+type VisionMLP struct {
+	FC1 *nn.Linear `gguf:"fc1"`
+	FC2 *nn.Linear `gguf:"fc2"`
+}
+
+func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
+	hiddenState = mlp.FC2.Forward(ctx, hiddenState)
+	return hiddenState
+}
+
+type VisionEncoderLayer struct {
+	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
+	SelfAttention *VisionSelfAttention
+
+	LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
+	MLP        *VisionMLP    `gguf:"mlp"`
+}
+
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	residual := hiddenState
+
+	// self attention
+	hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+
+	// feed forward
+	hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
+	return hiddenState.Add(ctx, residual)
+}
+
+type VisionModelOptions struct {
+	hiddenSize, numHeads int
+	imageSize, patchSize int
+	eps                  float32
+}
+
+type VisionModel struct {
+	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embedding"`
+	PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
+	PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`
+
+	Layers []VisionEncoderLayer `gguf:"blk"`
+
+	*VisionModelOptions
+}
+
+func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
+	numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
+
+	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
+	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
+	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+	positions := make([]int32, numPatches)
+	for i := range positions {
+		positions[i] = int32(i)
+	}
+
+	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
+	if err != nil {
+		panic(err)
+	}
+
+	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
+
+	for _, layer := range m.Layers {
+		hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
+	}
+
+	hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
+	return hiddenState
+}
+
+func newVisionModel(c ml.Config) *VisionModel {
+	return &VisionModel{
+		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
+		VisionModelOptions: &VisionModelOptions{
+			hiddenSize: int(c.Uint("vision.embedding_length")),
+			numHeads:   int(c.Uint("vision.attention.head_count")),
+
+			imageSize: int(c.Uint("vision.image_size")),
+			patchSize: int(c.Uint("vision.patch_size")),
+
+			eps: c.Float("vision.attention.layer_norm_epsilon"),
+		},
+	}
+}

+ 58 - 0
model/models/gemma3/process_image.go

@@ -0,0 +1,58 @@
+package gemma3
+
+import (
+	"image"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/imageproc"
+)
+
+type ImageProcessor struct {
+	imageSize, patchSize, numChannels int
+}
+
+func newImageProcessor(c ml.Config) ImageProcessor {
+	return ImageProcessor{
+		imageSize:   int(c.Uint("vision.image_size")),
+		patchSize:   int(c.Uint("vision.patch_size")),
+		numChannels: int(c.Uint("vision.num_channels")),
+	}
+}
+
+func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
+	var pixelVals, rVals, gVals, bVals []float32
+
+	bounds := img.Bounds()
+	for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
+		for x := bounds.Min.X; x < bounds.Max.X; x++ {
+			c := img.At(x, y)
+			r, g, b, _ := c.RGBA()
+			rVal := float32(r>>8) / 255.0
+			gVal := float32(g>>8) / 255.0
+			bVal := float32(b>>8) / 255.0
+
+			rVal = (rVal - mean[0]) / std[0]
+			gVal = (gVal - mean[1]) / std[1]
+			bVal = (bVal - mean[2]) / std[2]
+
+			rVals = append(rVals, rVal)
+			gVals = append(gVals, gVal)
+			bVals = append(bVals, bVal)
+		}
+	}
+
+	pixelVals = append(pixelVals, rVals...)
+	pixelVals = append(pixelVals, gVals...)
+	pixelVals = append(pixelVals, bVals...)
+
+	return pixelVals
+}
+
+func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
+	outputSize := image.Point{p.imageSize, p.imageSize}
+	newImage := imageproc.Composite(img)
+	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
+
+	data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
+	return data, nil
+}

+ 4 - 3
model/models/llama/model.go

@@ -76,14 +76,15 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -96,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
 type MLP struct {

+ 5 - 3
model/models/mllama/model_text.go

@@ -20,14 +20,15 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
+	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
 	}
 
 	return key, nil

+ 0 - 2
model/models/mllama/process_image.go

@@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
 	return images
 }
 
-// remove the "alpha" channel by drawing over a prefilled image
-//
 // remove the "alpha" channel by drawing over a prefilled image
 //
 //nolint:unused

+ 2 - 0
model/models/models.go

@@ -1,6 +1,8 @@
 package models
 
 import (
+	_ "github.com/ollama/ollama/model/models/gemma2"
+	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/llama"
 	_ "github.com/ollama/ollama/model/models/mllama"
 )

+ 17 - 5
model/process_text.go

@@ -4,6 +4,7 @@ import (
 	"cmp"
 	"iter"
 	"log/slog"
+	"slices"
 	"strings"
 	"sync"
 
@@ -18,6 +19,15 @@ const (
 	SpecialEOS
 )
 
+const (
+	TOKEN_TYPE_NORMAL = iota + 1
+	TOKEN_TYPE_UNKNOWN
+	TOKEN_TYPE_CONTROL
+	TOKEN_TYPE_USER_DEFINED
+	TOKEN_TYPE_UNUSED
+	TOKEN_TYPE_BYTE
+)
+
 type TextProcessor interface {
 	Encode(s string, addSpecial bool) ([]int32, error)
 	Decode([]int32) (string, error)
@@ -27,11 +37,11 @@ type TextProcessor interface {
 type Vocabulary struct {
 	Values []string
 	Types  []uint32
-	Scores []uint32
+	Scores []float32
 	Merges []string
 
-	BOS, EOS       int32
-	AddBOS, AddEOS bool
+	BOS, EOS, EOT          int32
+	AddBOS, AddEOS, AddEOT bool
 
 	specialOnce sync.Once
 	special     []string
@@ -48,7 +58,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
 	case SpecialBOS:
 		return id == v.BOS
 	case SpecialEOS:
-		return id == v.EOS
+		return id == v.EOS || id == v.EOT
 	default:
 		return false
 	}
@@ -76,7 +86,9 @@ func (v *Vocabulary) Decode(id int32) string {
 func (v *Vocabulary) SpecialVocabulary() []string {
 	v.specialOnce.Do(func() {
 		for i := range v.Values {
-			if v.Types[i] == 3 {
+			if slices.Contains([]int{105, 106}, i) {
+				v.special = append(v.special, v.Values[i])
+			} else if v.Types[i] == TOKEN_TYPE_CONTROL {
 				v.special = append(v.special, v.Values[i])
 			}
 		}

+ 246 - 0
model/process_text_spm.go

@@ -0,0 +1,246 @@
+package model
+
+import (
+	"iter"
+	"log/slog"
+	"strings"
+
+	"github.com/dlclark/regexp2"
+	queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
+)
+
+const spmWhitespaceSep = "▁"
+
+func replaceWhitespaceBySeperator(s string) string {
+	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
+}
+
+type SentencePieceModel struct {
+	maxTokenLen int
+	pre         *regexp2.Regexp
+	vocab       *Vocabulary
+}
+
+var _ TextProcessor = (*SentencePieceModel)(nil)
+
+func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
+	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
+
+	counter := map[int]int{}
+	var maxTokenLen int
+	for cnt := range vocab.Types {
+		switch vocab.Types[cnt] {
+		case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
+			maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
+			fallthrough
+		default:
+			counter[int(vocab.Types[cnt])] += 1
+		}
+	}
+
+	slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
+		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
+		"max token len", maxTokenLen)
+
+	return SentencePieceModel{
+		maxTokenLen: maxTokenLen,
+		pre:         regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
+		vocab:       vocab,
+	}
+}
+
+func (spm SentencePieceModel) Is(id int32, special Special) bool {
+	return spm.vocab.Is(id, special)
+}
+
+func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
+	return func(yield func(string) bool) {
+		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
+			if !yield(m.String()) {
+				break
+			}
+		}
+	}
+}
+
+func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
+	fragments := []fragment{{value: s}}
+	for _, special := range spm.vocab.SpecialVocabulary() {
+		// TODO: process special tokens concurrently
+		id := spm.vocab.Encode(special)
+		for i := 0; i < len(fragments); i++ {
+			frag := fragments[i]
+			if len(frag.ids) > 0 {
+				continue
+			}
+
+			var middle []fragment
+			switch i := strings.Index(frag.value, special); {
+			case i < 0:
+				middle = append(middle, frag)
+			case i > 0:
+				middle = append(middle, fragment{value: frag.value[:i]})
+				fallthrough
+			default:
+				middle = append(middle, fragment{value: special, ids: []int32{id}})
+				if rest := frag.value[i+len(special):]; rest != "" {
+					middle = append(middle, fragment{value: rest})
+				}
+			}
+
+			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
+		}
+	}
+	slog.Debug("fragments", "frags", fragments)
+
+	var ids []int32
+	for _, frag := range fragments {
+		if len(frag.ids) > 0 {
+			ids = append(ids, frag.ids...)
+			continue
+		}
+
+		for split := range spm.split(frag.value) {
+			split = replaceWhitespaceBySeperator(split)
+
+			var sb strings.Builder
+			sb.Write([]byte(split))
+			if id := spm.vocab.Encode(sb.String()); id >= 0 {
+				ids = append(ids, id)
+				continue
+			}
+
+			runes := []rune(sb.String())
+			pq := queue.NewWith(func(a, b any) int {
+				priA := a.(*candidate)
+				priB := b.(*candidate)
+				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
+					return -1
+				}
+				return 1
+			})
+
+			merges := make([]merge, len(runes))
+			for r := range runes {
+				merges[r] = merge{
+					p:     r - 1,
+					n:     r + 1,
+					runes: []rune{runes[r]},
+				}
+			}
+
+			slog.Debug("tokenizer", "merges", merges)
+
+			pairwise := func(a, b int) *candidate {
+				if a < 0 || b >= len(runes) {
+					return nil
+				}
+
+				left, right := string(merges[a].runes), string(merges[b].runes)
+				if id := spm.vocab.Encode(left + right); id >= 0 {
+					return &candidate{
+						a:     a,
+						b:     b,
+						score: spm.vocab.Scores[id],
+					}
+				}
+				return nil
+			}
+
+			for i := range len(runes) - 1 {
+				if pair := pairwise(i, i+1); pair != nil {
+					pq.Enqueue(pair)
+				}
+			}
+
+			pqv := pq.Values()
+			for _, v := range pqv {
+				e := v.(*candidate)
+				slog.Debug("candidate", "candidate", e)
+			}
+
+			for !pq.Empty() {
+				v, _ := pq.Dequeue()
+				pair := v.(*candidate)
+				left, right := merges[pair.a], merges[pair.b]
+
+				slog.Debug("pair", "left", left, "right", right)
+				if len(left.runes) == 0 || len(right.runes) == 0 {
+					continue
+				}
+
+				if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
+					continue
+				}
+
+				merges[pair.a].runes = append(left.runes, right.runes...)
+				merges[pair.b].runes = nil
+				merges[pair.a].n = right.n
+				if right.n < len(merges) {
+					merges[right.n].p = pair.a
+				}
+
+				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
+					pq.Enqueue(pair)
+				}
+
+				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
+					pq.Enqueue(pair)
+				}
+			}
+
+			slog.Debug("merges", "merges", merges)
+
+			for _, merge := range merges {
+				if len(merge.runes) > 0 {
+					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
+						ids = append(ids, id)
+					} else {
+						slog.Debug("missing token", "token", string(merge.runes))
+					}
+				}
+			}
+		}
+	}
+
+	if addSpecial && len(ids) > 0 {
+		if spm.vocab.AddBOS {
+			if ids[0] == spm.vocab.BOS {
+				slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
+			}
+
+			slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
+			ids = append([]int32{spm.vocab.BOS}, ids...)
+		}
+
+		if spm.vocab.AddEOS {
+			if ids[len(ids)-1] == spm.vocab.EOS {
+				slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
+			}
+
+			slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
+			ids = append(ids, spm.vocab.EOS)
+		}
+	}
+
+	return ids, nil
+}
+
+type candidate struct {
+	a, b  int
+	score float32
+}
+
+func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+	var sb strings.Builder
+	for _, id := range ids {
+		data := spm.vocab.Decode(id)
+		data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
+		if _, err := sb.WriteString(data); err != nil {
+			return "", err
+		}
+	}
+
+	slog.Debug("decoded", "ids", ids, "text", sb.String())
+	return sb.String(), nil
+}

+ 118 - 0
model/process_text_spm_test.go

@@ -0,0 +1,118 @@
+package model
+
+import (
+	"log/slog"
+	"os"
+	"path/filepath"
+	"slices"
+	"testing"
+
+	"google.golang.org/protobuf/proto"
+
+	"github.com/ollama/ollama/convert/sentencepiece"
+)
+
+func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
+	t.Helper()
+
+	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var spm sentencepiece.ModelProto
+	if err := proto.Unmarshal(bts, &spm); err != nil {
+		t.Fatal(err)
+	}
+
+	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
+
+	var v Vocabulary
+
+	for _, piece := range spm.GetPieces() {
+		v.Values = append(v.Values, piece.GetPiece())
+		v.Scores = append(v.Scores, piece.GetScore())
+		switch t := piece.GetType(); t {
+		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
+			sentencepiece.ModelProto_SentencePiece_CONTROL,
+			sentencepiece.ModelProto_SentencePiece_UNUSED,
+			sentencepiece.ModelProto_SentencePiece_BYTE:
+			v.Types = append(v.Types, uint32(t))
+		default:
+			tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
+			// todo parse the special tokens file
+			//   - this will roundtrip correctly but the <start_of_turn> and
+			//     <end_of_turn> tokens aren't processed
+			v.Types = append(v.Types, tt)
+		}
+	}
+
+	return NewSentencePieceModel(preTokenizer, &v)
+}
+
+func TestSentencePieceEncode(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
+	slog.SetDefault(logger)
+
+	tokenizer := loadSentencePieceVocab(t)
+
+	t.Run("basic roundtrip", func(t *testing.T) {
+		t.Parallel()
+
+		cases := []string{
+			"hello",
+			"hello ",
+			"hello  ",
+			" hello",
+			" hello ",
+			" hello  ",
+			"hello world",
+			"请考试我的软件!12345",
+			"你好",
+			"Hello 你好 world!",
+			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
+			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
+			"Numbers and symbols: 123456789 +- */",
+			"Special tokens: <bos> text <eos>",
+			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
+			"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
+				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
+				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
+		}
+
+		for _, want := range cases {
+			ids, err := tokenizer.Encode(want, true)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if got, err := tokenizer.Decode(ids); err != nil {
+				t.Fatal(err)
+			} else if got != want {
+				t.Errorf("got %q, want %q [%#v]", got, want, ids)
+			}
+		}
+	})
+
+	t.Run("special tokens", func(t *testing.T) {
+		type candidate struct {
+			token string
+			ids   []int32
+		}
+
+		cases := []candidate{
+			{"<bos>", []int32{2}},
+			{"<eos>", []int32{1}},
+		}
+
+		for _, want := range cases {
+			ids, err := tokenizer.Encode(want.token, true)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !slices.Equal(ids, want.ids) {
+				t.Errorf("got %#v, want %#v", ids, want.ids)
+			}
+		}
+	})
+}

BIN
model/testdata/gemma2/tokenizer.model


+ 11 - 1
server/prompt.go

@@ -26,6 +26,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	var system []api.Message
 
 	isMllama := checkMllamaModelFamily(m)
+	isGemma3 := checkGemma3ModelFamily(m)
 
 	var imageNumTokens int
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -40,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n; i >= 0; i-- {
-		if isMllama && len(msgs[i].Images) > 1 {
+		if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
 			return "", nil, errTooManyImages
 		}
 
@@ -157,3 +158,12 @@ func checkMllamaModelFamily(m *Model) bool {
 	}
 	return false
 }
+
+func checkGemma3ModelFamily(m *Model) bool {
+	for _, arch := range m.Config.ModelFamilies {
+		if arch == "gemma3" {
+			return true
+		}
+	}
+	return false
+}