
Fix follow-up images and images split across batches

Jesse Gross · 1 month ago
commit 2c40c4d35e
2 changed files with 73 additions and 47 deletions
  1. model/models/gemma3/model.go (+25 -32)
  2. model/models/gemma3/model_text.go (+48 -15)

+ 25 - 32
model/models/gemma3/model.go

@@ -5,7 +5,6 @@ import (
 	"encoding/binary"
 	"hash/fnv"
 	"image"
-	"slices"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -99,49 +98,43 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
+type imageToken struct {
+	embedding ml.Tensor
+	index     int
+}
+
 func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
-	var images []input.Input
+	var result []input.Input
 	fnvHash := fnv.New64a()
 
-	for i := range inputs {
-		if inputs[i].Multimodal == nil {
-			for j := range images {
-				if j == 0 {
-					inputs[i].Multimodal = images[j].Multimodal
-					inputs[i].MultimodalHash = images[j].MultimodalHash
-				} else {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
-					fnvHash.Reset()
-					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
-					binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
-					inputs[i].MultimodalHash = fnvHash.Sum64()
-				}
-			}
-
-			images = nil
+	for _, inp := range inputs {
+		if inp.Multimodal == nil {
+			result = append(result, inp)
 		} else {
-			images = append(images, inputs[i])
-			inputs[i].Token = -1
-		}
-	}
-
-	for i := range inputs {
-		if inputs[i].Token == -1 {
 			imageInputs := []input.Input{
 				{Token: 108},    // "\n\n"
 	{Token: 255999}, // "<start_of_image>"
 			}
+			result = append(result, imageInputs...)
 
-			// pad inputs with placeholders for image embeddings
-			imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...)
-			// <end_of_image>
-			imageInputs = append(imageInputs, input.Input{Token: 256000})
+			// add image embeddings
+			inputMultimodal := inp.Multimodal.(ml.Tensor)
+
+			for i := range inputMultimodal.Dim(1) {
+				fnvHash.Reset()
+				binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
+				fnvHash.Write([]byte{byte(i)})
 
-			inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
+				imageToken := imageToken{embedding: inputMultimodal, index: i}
+				result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
+			}
+
+			// <end_of_image>
+			result = append(result, input.Input{Token: 256000})
 		}
 	}
 
-	return inputs, nil
+	return result, nil
 }
 
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
@@ -160,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
 }
 
 func init() {

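With this change, PostTokenize no longer buffers whole images and splices their embeddings into the next text token. Instead, every column of an image's embedding tensor becomes its own placeholder input carrying an imageToken{embedding, index} value and a per-column hash, so the columns can be scheduled independently, including across batch boundaries. Below is a minimal, self-contained sketch of the per-column expansion and hashing; placeholderInput and expandImage are hypothetical simplifications standing in for input.Input and the loop inside PostTokenize, with the embedding tensor reduced to a plain column count.

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

// placeholderInput is a hypothetical stand-in for input.Input: it keeps only
// the fields needed to show the per-column expansion.
type placeholderInput struct {
	imageIndex int    // column of the image embedding this input refers to
	hash       uint64 // per-column hash derived from the image hash
}

// expandImage mirrors the new PostTokenize loop: one placeholder per
// embedding column, each with a distinct hash so the columns can be told
// apart even when they end up in different batches.
func expandImage(imageHash uint64, numColumns int) []placeholderInput {
	fnvHash := fnv.New64a()
	out := make([]placeholderInput, 0, numColumns)
	for i := 0; i < numColumns; i++ {
		fnvHash.Reset()
		binary.Write(fnvHash, binary.NativeEndian, imageHash)
		fnvHash.Write([]byte{byte(i)})
		out = append(out, placeholderInput{imageIndex: i, hash: fnvHash.Sum64()})
	}
	return out
}

func main() {
	for _, in := range expandImage(0xfeedbeef, 4) {
		fmt.Printf("column %d -> hash %#x\n", in.imageIndex, in.hash)
	}
}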
+ 48 - 15
model/models/gemma3/model_text.go

@@ -173,25 +173,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
-
-	if multimodal != nil {
-		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
-		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
-		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
+	var embedding ml.Tensor
+	var src, dst, length int
+	var except []int32
+
+	for _, image := range multimodal {
+		imageToken := image.Multimodal.(imageToken)
+		imageSrc := imageToken.index
+		imageDst := image.Index
+
+		if embedding == nil {
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
+			src = imageSrc
+			dst = imageDst
+			length++
+		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
+			length++
+		} else {
+			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
+
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		}
 
-		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
-			except := make([]int32, visionOutputs.Dim(1))
-			for i := 0; i < visionOutputs.Dim(1); i++ {
-				except[i] = int32(offset + i)
-			}
+		except = append(except, positions[imageDst])
+	}
 
-			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
-		}
+	if embedding != nil {
+		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
 	}
 
+	return except
+}
+
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
+
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
 		// kv cache every 6 layers
@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
 
+		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
+		}
+
 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
 			lastLayerOutputs = outputs
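
On the text-model side, setImageEmbeddings walks opts.Multimodal and coalesces consecutive image-token columns that land at consecutive token positions, issuing one contiguous view-and-copy per run instead of one copy per token, and it collects the covered positions so SetCausal can exempt them in each layer. A minimal sketch of just the run-coalescing idea follows; pair, run, and coalesce are hypothetical simplifications that use plain (src, dst) index pairs instead of tensor views, ignore which embedding tensor a column comes from, and only extend runs forward (the real code also handles the reverse-adjacent case).

package main

import "fmt"

// pair is a hypothetical stand-in for one multimodal entry: src is the column
// inside the image embedding, dst is the destination token position.
type pair struct{ src, dst int }

// run describes a maximal block of consecutive pairs that can be copied with
// a single contiguous view, as setImageEmbeddings does with tensor views.
type run struct{ src, dst, length int }

// coalesce groups pairs whose src and dst both advance by one into runs.
func coalesce(pairs []pair) []run {
	var runs []run
	for _, p := range pairs {
		n := len(runs)
		if n > 0 && runs[n-1].src+runs[n-1].length == p.src && runs[n-1].dst+runs[n-1].length == p.dst {
			runs[n-1].length++ // extend the current run: one copy will cover it
			continue
		}
		runs = append(runs, run{src: p.src, dst: p.dst, length: 1})
	}
	return runs
}

func main() {
	// The tail of an image split from the previous batch (columns 4-7) lands
	// at positions 0-3; a follow-up image's first two columns land at 6-7.
	pairs := []pair{{4, 0}, {5, 1}, {6, 2}, {7, 3}, {0, 6}, {1, 7}}
	for _, r := range coalesce(pairs) {
		fmt.Printf("copy %d columns from src %d to dst %d\n", r.length, r.src, r.dst)
	}
}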