Kaynağa Gözat

add gemma vision encoder

Michael Yang 1 ay önce
ebeveyn
işleme
4b037a97dc

+ 3 - 3
convert/convert.go

@@ -13,13 +13,13 @@ import (
 )
 
 type ModelParameters struct {
-	Architectures []string `json:"architectures"`
-	VocabSize     uint32   `json:"vocab_size"`
+	Architectures []string       `json:"architectures"`
+	VocabSize     uint32         `json:"vocab_size"`
 	TextModel     TextParameters `json:"text_config"`
 }
 
 type TextParameters struct {
-	VocabSize     uint32   `json:"vocab_size"`
+	VocabSize uint32 `json:"vocab_size"`
 }
 
 type AdapterParameters struct {

+ 26 - 13
convert/convert_gemma3.go

@@ -4,8 +4,17 @@ import "github.com/ollama/ollama/fs/ggml"
 
 type gemma3Model struct {
 	gemmaModel
-	TextModel   gemma3TextModel   `json:"text_config"`
-	VisionModel gemma3VisionModel `json:"vision_config"`
+	TextModel   gemma3TextModel `json:"text_config"`
+	VisionModel struct {
+		NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
+		LayerNormEpsilon  float32 `json:"layer_norm_eps"`      // attention.layer_norm_epsilon 1e-05
+		NumHiddenLayers   uint32  `json:"num_hidden_layers"`   // block_count 32
+		HiddenSize        uint32  `json:"hidden_size"`         // embedding_length 1280
+		IntermediateSize  uint32  `json:"intermediate_size"`   // feed_forward_length 5120
+		ImageSize         uint32  `json:"image_size"`          // image_size 560
+		NumChannels       uint32  `json:"num_channels"`        // num_channels 3
+		PatchSize         uint32  `json:"patch_size"`          // patch_size 14
+	} `json:"vision_config"`
 }
 
 type gemma3TextModel struct {
@@ -24,12 +33,6 @@ type gemma3TextModel struct {
 	RopeGlobalTheta       float32 `json:"rope_global_base_freq"`
 }
 
-type gemma3VisionModel struct {
-	ImageSize    uint32 `json:"image_size"`
-	NumChannels  uint32 `json:"num_channels"`
-	HiddenLayers uint32 `json:"num_hidden_layers"`
-}
-
 func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 	kv := p.ModelParameters.KV(t)
 	kv["general.architecture"] = "gemma3"
@@ -46,11 +49,18 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 	kv["gemma3.text.final_logit_softcapping"] = p.TextModel.FinalLogitSoftcap
 	kv["gemma3.text.rope.local.freq_base"] = p.TextModel.RopeLocalTheta
 	kv["gemma3.text.rope.global.freq_base"] = p.TextModel.RopeGlobalTheta
-	kv["tokenizer.ggml.bos_token_id"] = uint32(2)
-	kv["tokenizer.ggml.eot_token_id"] = uint32(1)
+
+	kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
+	kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
+	kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
 	kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
+	kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
 	kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
-	kv["gemma3.vision.block_count"] = p.VisionModel.HiddenLayers
+	kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
+	kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
+
+	kv["tokenizer.ggml.bos_token_id"] = uint32(2)
+	kv["tokenizer.ggml.eot_token_id"] = uint32(1)
 	return kv
 }
 
@@ -59,11 +69,11 @@ func (p *gemma3Model) Replacements() []string {
 		"lm_head", "output",
 		"model.embed_tokens", "token_embd",
 		"model.norm", "output_norm",
-		"vision_model.vision_model", "v",
+		"vision_tower.vision_model.embeddings", "v",
+		"vision_tower.vision_model", "v",
 		"language_model.", "",
 		"model.layers", "blk",
 		"encoder.layers", "blk",
-		"vision_tower.vision_model.embeddings", "v",
 		"input_layernorm", "attn_norm",
 		"self_attn.q_proj", "attn_q",
 		"self_attn.q_norm", "attn_q_norm",
@@ -71,11 +81,14 @@ func (p *gemma3Model) Replacements() []string {
 		"self_attn.k_norm", "attn_k_norm",
 		"self_attn.v_proj", "attn_v",
 		"self_attn.o_proj", "attn_output",
+		"self_attn.out_proj", "attn_output",
 		"mlp.gate_proj", "ffn_gate",
 		"mlp.down_proj", "ffn_down",
 		"mlp.up_proj", "ffn_up",
 		"post_attention_layernorm", "post_attention_norm",
 		"pre_feedforward_layernorm", "ffn_norm",
 		"post_feedforward_layernorm", "post_ffw_norm",
+		"input_projection_weight", "input_projection.weight",
+		"multi_modal_projector", "mm",
 	}
 }

+ 2 - 0
ml/backend.go

@@ -135,7 +135,9 @@ type Tensor interface {
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
 
+	AvgPool1D(ctx Context, k, s, p int) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
+
 	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
 
 	Tanh(ctx Context) Tensor

+ 7 - 0
ml/backend/ggml/ggml.go

@@ -947,6 +947,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
+func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
+	}
+}
+
 func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
 	var kqMask *C.struct_ggml_tensor
 	if mask != nil {

+ 97 - 10
model/models/gemma3/model.go

@@ -1,10 +1,15 @@
 package gemma3
 
 import (
-	"fmt"
+	"bytes"
+	"encoding/binary"
+	"hash/fnv"
+	"image"
+	"slices"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 )
@@ -13,19 +18,30 @@ type Model struct {
 	model.Base
 	model.SentencePieceModel
 
-	//*VisionModel `gguf:"v,vision"`
+	*VisionModel `gguf:"v,vision"`
 	*TextModel
 
-	//Projector *nn.Linear `gguf:"mm.0"`
+	*MultiModalProjector `gguf:"mm"`
 
 	ImageProcessor
 }
 
+var _ model.MultimodalProcessor = (*Model)(nil)
+
+type MultiModalProjector struct {
+	SoftEmbNorm     *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
+	InputProjection *nn.Linear  `gguf:"mm_input_projection"`
+}
+
+func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
+	visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
+
+	// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
+	visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
+	return visionOutputs
+}
+
 func New(c ml.Config) (model.Model, error) {
-	// Verify unified config
-	if c.Uint("vision.block_count") == 0 {
-		return nil, fmt.Errorf("non-unified vision model not supported")
-	}
 	m := Model{
 		SentencePieceModel: model.NewSentencePieceModel(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -40,8 +56,8 @@ func New(c ml.Config) (model.Model, error) {
 			},
 		),
 		ImageProcessor: newImageProcessor(c),
-		//VisionModel:    newVisionModel(c),
-		TextModel: newTextModel(c),
+		VisionModel:    newVisionModel(c),
+		TextModel:      newTextModel(c),
 	}
 
 	slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
@@ -50,7 +66,78 @@ func New(c ml.Config) (model.Model, error) {
 	return &m, nil
 }
 
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+	image, _, err := image.Decode(bytes.NewReader(multimodalData))
+	if err != nil {
+		return nil, err
+	}
+
+	f32s, err := m.ImageProcessor.ProcessImage(image)
+	if err != nil {
+		return nil, err
+	}
+
+	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.numChannels,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	positionIDs, err := ctx.FromIntSlice([]int32{0}, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	visionOutputs := m.VisionModel.Forward(ctx, pixelValues, positionIDs)
+
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
+	kernelSize := patchesPerImage * patchesPerImage / 256
+	visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
+
+	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
+	return visionOutputs, nil
+}
+
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+	var images []input.Input
+	fnvHash := fnv.New64a()
+
+	for i := range inputs {
+		if inputs[i].Multimodal == nil {
+			if len(images) > 0 {
+				inputs[i].Multimodal = images[0].Multimodal
+				inputs[i].MultimodalHash = images[0].MultimodalHash
+				for j := 1; j < len(images); j++ {
+					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+					fnvHash.Reset()
+					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
+					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
+					inputs[i].MultimodalHash = fnvHash.Sum64()
+				}
+				images = nil
+			}
+		} else {
+			images = append(images, inputs[i])
+			inputs[i].Token = -1
+		}
+	}
+
+	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
+
+	return inputs, nil
+}
+
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+	var embeddings ml.Tensor
+	if opts.Multimodal != nil {
+		embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
+	}
+
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
@@ -66,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil
+	return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
 }
 
 func init() {

+ 6 - 3
model/models/gemma3/model_text.go

@@ -160,9 +160,12 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
+	if embeddings == nil {
+		embeddings = m.TokenEmbedding.Forward(ctx, inputs)
+	}
+
+	hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
 	if len(m.Layers) == gemma27BLayerCount {
 		m.TextOptions.largeModelScaling = true

+ 171 - 0
model/models/gemma3/model_vision.go

@@ -0,0 +1,171 @@
+package gemma3
+
+import (
+	"math"
+	"slices"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+)
+
+var batchSize int = 1
+
+type VisionSelfAttention struct {
+	Query  *nn.Linear `gguf:"attn_q"`
+	Key    *nn.Linear `gguf:"attn_k"`
+	Value  *nn.Linear `gguf:"attn_v"`
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	headDim := opts.hiddenSize / opts.numHeads
+
+	query := sa.Query.Forward(ctx, hiddenState)
+	key := sa.Key.Forward(ctx, hiddenState)
+	value := sa.Value.Forward(ctx, hiddenState)
+
+	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
+	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
+	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	scores := key.Mulmat(ctx, query)
+	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+	scores = scores.Softmax(ctx)
+
+	attention := value.Mulmat(ctx, scores)
+	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
+	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
+
+	hiddenState = sa.Output.Forward(ctx, attention)
+	return hiddenState
+}
+
+type VisionMLP struct {
+	FC1 *nn.Linear `gguf:"fc1"`
+	FC2 *nn.Linear `gguf:"fc2"`
+}
+
+func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
+	hiddenState = mlp.FC2.Forward(ctx, hiddenState)
+	return hiddenState
+}
+
+type VisionEncoderLayer struct {
+	LayerNorm1    *nn.LayerNorm        `gguf:"layer_norm1"`
+	SelfAttention *VisionSelfAttention
+
+	LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
+	MLP        *VisionMLP    `gguf:"mlp"`
+}
+
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	residual := hiddenState
+
+	// self attention
+	hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
+	hiddenState = hiddenState.Add(ctx, residual)
+	residual = hiddenState
+
+	// feed forward
+	hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
+	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
+	return hiddenState.Add(ctx, residual)
+}
+
+type VisionEncoder struct {
+	Layers []VisionEncoderLayer
+}
+
+func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
+	var intermediateHiddenStates []ml.Tensor
+	for i, layer := range e.Layers {
+		if slices.Contains(intermediateLayersIndices, uint32(i)) {
+			intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, opts)
+	}
+
+	return hiddenState, intermediateHiddenStates
+}
+
+type PrecomputedAspectRatioEmbedding struct {
+	Embedding *nn.Embedding
+	Gate      ml.Tensor `gguf:"gate"`
+}
+
+func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
+	embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
+	if e.Gate != nil {
+		embeddings = embeddings.Mul(ctx, e.Gate)
+	}
+
+	return hiddenState.Add(ctx, embeddings)
+}
+
+type PrecomputedPositionEmbedding struct {
+	PositionEmbedding     *nn.Embedding `gguf:"position_embd"`
+	PositionEmbeddingGate ml.Tensor     `gguf:"position_embd.gate"`
+}
+
+func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
+	positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
+	if e.PositionEmbeddingGate != nil {
+		positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
+	}
+
+	return hiddenState.Add(ctx, positionEmbedding)
+}
+
+type VisionModelOptions struct {
+	hiddenSize, numHeads, numTiles int
+	imageSize, patchSize           int
+	eps                            float32
+}
+
+type VisionModel struct {
+	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embedding"`
+	PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
+	PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`
+
+	Encoder *VisionEncoder `gguf:"blk"`
+
+	*VisionModelOptions
+}
+
+func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs ml.Tensor) ml.Tensor {
+	numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
+
+	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
+	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
+	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+	positions := m.PositionEmbedding.Forward(ctx, positionIDs)
+	hiddenState = hiddenState.Add(ctx, positions)
+
+	for _, layer := range m.Encoder.Layers {
+		hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
+	}
+
+	hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
+	return hiddenState
+}
+
+func newVisionModel(c ml.Config) *VisionModel {
+	return &VisionModel{
+		Encoder: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
+		VisionModelOptions: &VisionModelOptions{
+			hiddenSize: int(c.Uint("vision.embedding_length")),
+			numHeads:   int(c.Uint("vision.attention.head_count")),
+
+			imageSize: int(c.Uint("vision.image_size")),
+			patchSize: int(c.Uint("vision.patch_size")),
+
+			eps: c.Float("vision.attention.layer_norm_epsilon"),
+		},
+	}
+}

+ 2 - 1
model/models/gemma3/process_image.go

@@ -8,12 +8,13 @@ import (
 )
 
 type ImageProcessor struct {
-	imageSize, numChannels int
+	imageSize, patchSize, numChannels int
 }
 
 func newImageProcessor(c ml.Config) ImageProcessor {
 	return ImageProcessor{
 		imageSize:   int(c.Uint("vision.image_size")),
+		patchSize:   int(c.Uint("vision.patch_size")),
 		numChannels: int(c.Uint("vision.num_channels")),
 	}
 }

+ 0 - 2
model/models/mllama/process_image.go

@@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
 	return images
 }
 
-// remove the "alpha" channel by drawing over a prefilled image
-//
 // remove the "alpha" channel by drawing over a prefilled image
 //
 //nolint:unused

+ 23 - 2
model/process_text_spm.go

@@ -21,6 +21,8 @@ type SentencePieceModel struct {
 	vocab       *Vocabulary
 }
 
+var _ TextProcessor = (*SentencePieceModel)(nil)
+
 func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
 	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
 
@@ -61,7 +63,7 @@ func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
 	}
 }
 
-func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
+func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range spm.vocab.SpecialVocabulary() {
 		// TODO: process special tokens concurrently
@@ -196,7 +198,26 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
 			}
 		}
 	}
-	slog.Debug("encoded", "ids", ids)
+
+	if addSpecial && len(ids) > 0 {
+		if spm.vocab.AddBOS {
+			if ids[0] == spm.vocab.BOS {
+				slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
+			}
+
+			slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
+			ids = append([]int32{spm.vocab.BOS}, ids...)
+		}
+
+		if spm.vocab.AddEOS {
+			if ids[len(ids)-1] == spm.vocab.EOS {
+				slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
+			}
+
+			slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
+			ids = append(ids, spm.vocab.EOS)
+		}
+	}
 
 	return ids, nil
 }