
ml: Allow models to constrain inputs to a single batch

Models may require that a set of inputs all be processed as part
of the same batch. For example, if an image has multiple patches
with fully connected attention between them, we should not split
the batch in the middle of an image.

Fixes #9697
Jesse Gross 1 month ago
Commit
9679f40146
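
In outline: a model can now tag an input with SameBatch = n, which pins the n inputs that follow into the same batch as that input. For example (illustrative numbers), if an image occupies 256 embedding positions and the batch would otherwise fill up 200 positions into the image, the runner either extends the batch to cover all 256 positions or breaks early so the whole image lands in the next batch. Either way, the fully connected attention between patches sees the complete image.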

+ 29 - 0
integration/llm_image_test.go

@@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) {
 	DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
 }
 
+func TestIntegrationSplitBatch(t *testing.T) {
+	image, err := base64.StdEncoding.DecodeString(imageEncoding)
+	require.NoError(t, err)
+	req := api.GenerateRequest{
+		Model: "gemma3:4b",
+		// Fill up a chunk of the batch so the image will partially spill over into the next one
+		System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.",
+		Prompt: "what does the text in this image say?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
+		Images: []api.ImageData{
+			image,
+		},
+	}
+
+	// Note: sometimes it returns "the ollamas", sometimes "the ollams"
+	resp := "the ollam"
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	// vision models on CPU can be quite slow to start, so allow a generous timeout
+	DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
+}
+
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
 AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
 AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
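
The test above reproduces the conditions of #9697: the long lorem-ipsum system prompt fills most of a batch, so the image tokens would straddle a batch boundary unless SameBatch forces a break or an extension. The expected string is the common prefix of the two answers the model is known to give ("the ollamas" / "the ollams").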

+ 6 - 0
model/input/input.go

@@ -15,6 +15,12 @@ type Input struct {
 	// stored in Multimodal, used for caching and comparing
 	// equality.
 	MultimodalHash uint64
+
+	// SameBatch forces the next SameBatch inputs to be processed in the
+	// same batch as this one, breaking and extending batches as needed.
+	// Useful for things like images that must be processed in one shot.
+	SameBatch int
 }
 
 // MultimodalIndex is a multimodal element (such as an image)
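
A minimal, runnable sketch of the semantics (an illustration, not code from this change; the pared-down Input and the splitBatches helper are hypothetical stand-ins for model/input and the runner loop): the effective batch size grows just enough to satisfy SameBatch, and a batch breaks early when the constrained run cannot fit behind inputs that are already pending.

package main

import "fmt"

// Input is a pared-down stand-in for model/input.Input.
type Input struct {
	Token     int32
	SameBatch int // number of following inputs that must share this input's batch
}

// splitBatches mimics the runner's per-batch loop: extend the batch size by
// the minimum amount needed, and break early if pending inputs would keep the
// constrained run from fitting.
func splitBatches(inputs []Input, defaultBatchSize int) [][]Input {
	var batches [][]Input
	for i := 0; i < len(inputs); {
		batchSize := defaultBatchSize
		var pending []Input
		for i < len(inputs) {
			minBatch := 1 + inputs[i].SameBatch
			if minBatch > batchSize {
				batchSize = minBatch // extend for an oversized constrained run
			}
			if len(pending)+minBatch > batchSize {
				break // flush; the constrained run starts a fresh batch
			}
			pending = append(pending, inputs[i])
			i++
		}
		batches = append(batches, pending)
	}
	return batches
}

func main() {
	// Ten inputs; input 5 pins the four inputs after it into its batch.
	inputs := make([]Input, 10)
	inputs[5].SameBatch = 4
	for _, b := range splitBatches(inputs, 8) {
		fmt.Println(len(b)) // prints 5 then 5: the batch breaks before input 5
	}
}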

+ 8 - 22
model/models/gemma3/model.go

@@ -2,10 +2,9 @@ package gemma3
 
 import (
 	"bytes"
-	"encoding/binary"
-	"hash/fnv"
 	"image"
 	"math"
+	"slices"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
-type imageToken struct {
-	embedding ml.Tensor
-	index     int
-}
-
 func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
-	fnvHash := fnv.New64a()
 
 	for _, inp := range inputs {
 		if inp.Multimodal == nil {
 			result = append(result, inp)
 		} else {
-			imageInputs := []input.Input{
-				{Token: 108},    // "\n\n"
-				{Token: 255999}, // "<start_of_image>""
-			}
-			result = append(result, imageInputs...)
-
-			// add image embeddings
 			inputMultimodal := inp.Multimodal.(ml.Tensor)
 
-			for i := range inputMultimodal.Dim(1) {
-				fnvHash.Reset()
-				binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
-				fnvHash.Write([]byte{byte(i)})
+			result = append(result,
+				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3},               // "\n\n"
+				input.Input{Token: 255999},                                                   // "<start_of_image>"
+				input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+			)
 
-				imageToken := imageToken{embedding: inputMultimodal, index: i}
-				result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
-			}
+			// add image token placeholders
+			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
 
 			result = append(result,
 				input.Input{Token: 256000}, // <end_of_image>
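
A note on the count: the inputs that follow the "\n\n" token are <start_of_image>, the Dim(1) image positions (the embedding carrier plus Dim(1)-1 zero placeholders), <end_of_image>, and, since the hunk is truncated here, what appears to be one trailing token, for Dim(1) + 3 in total. The image embedding rides on the first placeholder only; TextModel.Forward (next file) copies all Dim(1) columns starting at that index.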

+ 10 - 43
model/models/gemma3/model_text.go

@@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
-	var embedding ml.Tensor
-	var src, dst, length int
-	var except []int
-
-	for _, image := range multimodal {
-		imageToken := image.Multimodal.(imageToken)
-		imageSrc := imageToken.index
-		imageDst := image.Index
-
-		if embedding == nil {
-			embedding = imageToken.embedding
-			src = imageSrc
-			dst = imageDst
-			length = 1
-		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
-			src = imageSrc
-			dst = imageDst
-			length++
-		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
-			length++
-		} else {
-			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
-			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
-
-			embedding = imageToken.embedding
-			src = imageSrc
-			dst = imageDst
-			length = 1
-		}
-
-		except = append(except, imageDst)
-	}
-
-	if embedding != nil {
-		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
-		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
-	}
-
-	return except
-}
-
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
-	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
+	// set image embeddings
+	var except []int
+	for _, image := range opts.Multimodal {
+		visionOutputs := image.Multimodal.(ml.Tensor)
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
+
+		for i := range visionOutputs.Dim(1) {
+			except = append(except, image.Index+i)
+		}
+	}
 
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
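
With SameBatch guaranteeing that every position of an image lands in a single batch, the per-run bookkeeping of setImageEmbeddings becomes unnecessary: the Dim(1) embedding columns sit contiguously in the batch starting at image.Index, so one View/Copy places the whole image. The except slice still collects the image positions; judging from the surrounding code (not shown in this hunk), they are exempted from causal masking so that patches can attend to each other.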

+ 11 - 1
runner/ollamarunner/runner.go

@@ -352,6 +352,8 @@ func (s *Server) processBatch() error {
 			seq.cache.Inputs = []input.Input{}
 		}
 
+		batchSize := s.batchSize
+
 		for j, inp := range seq.inputs {
 			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
@@ -364,7 +366,15 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if j >= s.batchSize {
+			// If we are required to put the following inputs into a single batch then extend
+			// the batch size. Since we only extend the size by the minimum amount possible,
+			// this will cause a break if we have pending inputs.
+			minBatch := 1 + inp.SameBatch
+			if minBatch > batchSize {
+				batchSize = minBatch
+			}
+
+			if len(seq.pendingInputs)+minBatch > batchSize {
 				break
 			}
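
Worked example of the check: with s.batchSize = 512, 300 inputs already pending, and an image input carrying SameBatch = 259, minBatch is 260; that does not exceed 512, so batchSize is unchanged, but 300 + 260 > 512 breaks the loop and the image starts the next batch. If instead the image alone needs more than the configured size (say SameBatch = 600 with nothing pending), batchSize is extended to 601 and the whole image is processed in one oversized batch.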