Ver Fonte

ollamarunner: Improve multimodal input handling

Various vision models have different requirements for how they
receive their inputs. For example:
 - Mllama wants images together with text and the image embeddings
   don't themselves have positions or get stored in the main KV cache
 - Llava-style models feed in embeddings similar to tokens and
   images correspond to a varying number of tokens in the cache.

In addition, the strategy for providing inputs must support batching
and multiple sequences, which are managed by the runner. At the same
time, we want to keep data handling fully in the model so that new
architectures are not bottlenecked by runner code which does not
understand their particular requirements.

This provides a method for models to edit the input stream so that
it meets their needs while still being in a format that the runner
understands. This allows the runner to avoid special processing
for different models.

In addition, this fixes a regression where non-vision models may
try to incorrectly interpret images.
Jesse Gross há 1 mês atrás
pai
commit
a7e63b82be

+ 63 - 7
model/model.go

@@ -3,7 +3,6 @@ package model
 import (
 import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"image"
 	_ "image/jpeg"
 	_ "image/jpeg"
 	_ "image/png"
 	_ "image/png"
 	"log/slog"
 	"log/slog"
@@ -22,14 +21,40 @@ import (
 	_ "github.com/ollama/ollama/ml/backend"
 	_ "github.com/ollama/ollama/ml/backend"
 )
 )
 
 
+// Input represents one token in the input stream
+type Input struct {
+	// Token is a single element of text.
+	Token int32
+
+	// Multimodal is opaque data representing a non-text
+	// element such as an image (or part of one if the image
+	// can be processed in pieces). It may be either together
+	// with Token or on its own.
+	Multimodal any
+
+	// MultimodalHash is a unique representation of the data
+	// stored in Multimodal, used for caching and comparing
+	// equality.
+	MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find that use the index with the
+// Positions slice.
+type MultimodalIndex struct {
+	Index      int
+	Multimodal any
+}
+
 // Options contains the inputs for a model forward pass
 // Options contains the inputs for a model forward pass
 type Options struct {
 type Options struct {
-	Inputs    []int32
-	Positions []int32
-	Sequences []int
-	Outputs   []int32
-
-	Images []image.Image
+	Inputs     []int32
+	Multimodal []MultimodalIndex
+	Positions  []int32
+	Sequences  []int
+	Outputs    []int32
 }
 }
 
 
 type config struct {
 type config struct {
@@ -59,6 +84,37 @@ type Model interface {
 	Config() config
 	Config() config
 }
 }
 
 
+// MultimodalProcessor must be implemented by multimodal models.
+type MultimodalProcessor interface {
+	// EncodeMultimodal processes a single input (such as an image) and
+	// generates an output (typically an embedding) that can be used by the model.
+	//
+	// The return value is most typically an ml.Tensor, however, different
+	// type are possible, such as an object containing a tensor plus
+	// additional metadata, a slice of tensors or even just the original input.
+	//
+	// The result may be cached by the runner.
+	EncodeMultimodal(ml.Context, []byte) (any, error)
+
+	// PostTokenize is called after tokenization to allow the model to edit the
+	// input stream to correctly arrange multimodal elements.
+	//
+	// The input is a slice of tokens with the results of EncodeMultimodal interleaved
+	// in the order that the user provided them. Each element of the slice will be
+	// either a single token or single multimodal object.
+	//
+	// The model must ensure that inputs are stored according to how they will be
+	// processed and stored in the cache. For example, Llava-style models should insert
+	// placeholder tokens equal to the feature size of the corresponding image with
+	// the image itself attached to and split across these tokens. When Forward is called
+	// a partial subset of these tokens may be submitted according to the batch size.
+	//
+	// This function is also responsible for updating MultimodalHash for any Multimodal
+	// that is modified to ensure that there is a unique hash value that accurately
+	// represents the contents.
+	PostTokenize(ml.Context, []Input) ([]Input, error)
+}
+
 var models = make(map[string]func(ml.Config) (Model, error))
 var models = make(map[string]func(ml.Config) (Model, error))
 
 
 // Register registers a model constructor for the given architecture
 // Register registers a model constructor for the given architecture

+ 72 - 29
model/models/mllama/model.go

@@ -1,7 +1,12 @@
 package mllama
 package mllama
 
 
 import (
 import (
+	"bytes"
+	"encoding/binary"
 	"fmt"
 	"fmt"
+	"hash/fnv"
+	"image"
+	"slices"
 
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml"
@@ -56,41 +61,79 @@ func New(c ml.Config) (model.Model, error) {
 	return &m, nil
 	return &m, nil
 }
 }
 
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	var crossAttentionStates ml.Tensor
-	if opts.Images != nil {
-		f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
-		if err != nil {
-			return nil, err
-		}
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
+	image, _, err := image.Decode(bytes.NewReader(multimodalData))
+	if err != nil {
+		return nil, err
+	}
 
 
-		pixelValues, err := ctx.FromFloatSlice(f32s,
-			m.ImageProcessor.imageSize,
-			m.ImageProcessor.imageSize,
-			m.ImageProcessor.numChannels,
-			m.ImageProcessor.maxNumTiles,
-		)
-		if err != nil {
-			return nil, err
-		}
+	f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
+	if err != nil {
+		return nil, err
+	}
 
 
-		aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
-		if err != nil {
-			return nil, err
-		}
+	pixelValues, err := ctx.FromFloatSlice(f32s,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.imageSize,
+		m.ImageProcessor.numChannels,
+		m.ImageProcessor.maxNumTiles,
+	)
+	if err != nil {
+		return nil, err
+	}
 
 
-		positions := make([]int32, 1601)
-		for i := range positions {
-			positions[i] = int32(i)
-		}
+	aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	positions := make([]int32, 1601)
+	for i := range positions {
+		positions[i] = int32(i)
+	}
+
+	positionIDs, err := ctx.FromIntSlice(positions, len(positions))
+	if err != nil {
+		return nil, err
+	}
 
 
-		positionIDs, err := ctx.FromIntSlice(positions, len(positions))
-		if err != nil {
-			return nil, err
+	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
+	return m.Projector.Forward(ctx, crossAttentionStates), nil
+}
+
+func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
+	var images []model.Input
+	fnvHash := fnv.New64a()
+
+	for i := range inputs {
+		if inputs[i].Multimodal == nil {
+			if len(images) > 0 {
+				inputs[i].Multimodal = images[0].Multimodal
+				inputs[i].MultimodalHash = images[0].MultimodalHash
+				for j := 1; j < len(images); j++ {
+					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+					fnvHash.Reset()
+					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
+					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
+					inputs[i].MultimodalHash = fnvHash.Sum64()
+				}
+				images = nil
+			}
+		} else {
+			images = append(images, inputs[i])
+			inputs[i].Token = -1
 		}
 		}
+	}
+
+	inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
 
 
-		crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
-		crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
+	return inputs, nil
+}
+
+func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+	var crossAttentionStates ml.Tensor
+	if opts.Multimodal != nil {
+		crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
 	}
 	}
 
 
 	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))

+ 8 - 12
runner/ollamarunner/cache.go

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"fmt"
 	"log/slog"
 	"log/slog"
 	"math"
 	"math"
-	"reflect"
 	"time"
 	"time"
 
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/kvcache"
@@ -39,10 +38,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
 	slots := make([]InputCacheSlot, numSlots)
 	slots := make([]InputCacheSlot, numSlots)
 
 
 	for i := range slots {
 	for i := range slots {
-		slots[i] = InputCacheSlot{
-			Id:     i,
-			Inputs: make([]input, 0),
-		}
+		slots[i] = InputCacheSlot{Id: i}
 	}
 	}
 
 
 	cache := model.Config().Cache
 	cache := model.Config().Cache
@@ -83,7 +79,7 @@ type InputCacheSlot struct {
 	Id int
 	Id int
 
 
 	// Inputs that are stored in the KV cache
 	// Inputs that are stored in the KV cache
-	Inputs []input
+	Inputs []model.Input
 
 
 	// is this cache actively being processed as part of a sequence?
 	// is this cache actively being processed as part of a sequence?
 	InUse bool
 	InUse bool
@@ -92,7 +88,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 	lastUsed time.Time
 }
 }
 
 
-func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
 	var slot *InputCacheSlot
 	var slot *InputCacheSlot
 	var numPast int32
 	var numPast int32
 	var err error
 	var err error
@@ -143,7 +139,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
 	return slot, prompt, nil
 	return slot, prompt, nil
 }
 }
 
 
-func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
 	longest := int32(-1)
 	longest := int32(-1)
 	var longestSlot *InputCacheSlot
 	var longestSlot *InputCacheSlot
 
 
@@ -166,7 +162,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
 	return longestSlot, longest, nil
 	return longestSlot, longest, nil
 }
 }
 
 
-func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
 	oldest := time.Now()
 	oldest := time.Now()
 	var oldestSlot *InputCacheSlot
 	var oldestSlot *InputCacheSlot
 
 
@@ -202,7 +198,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
 	if longest > 0 && longestSlot != oldestSlot {
 	if longest > 0 && longestSlot != oldestSlot {
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
 			len(longestSlot.Inputs))
 			len(longestSlot.Inputs))
-		oldestSlot.Inputs = make([]input, longest)
+		oldestSlot.Inputs = make([]model.Input, longest)
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		if c.cache != nil {
 		if c.cache != nil {
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -212,7 +208,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
 	return oldestSlot, longest, nil
 	return oldestSlot, longest, nil
 }
 }
 
 
-func countCommonPrefix(a []input, b []input) int32 {
+func countCommonPrefix(a []model.Input, b []model.Input) int32 {
 	var count int32
 	var count int32
 
 
 	for i := range a {
 	for i := range a {
@@ -220,7 +216,7 @@ func countCommonPrefix(a []input, b []input) int32 {
 			break
 			break
 		}
 		}
 
 
-		if !reflect.DeepEqual(a[i], b[i]) {
+		if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
 			break
 			break
 		}
 		}
 
 

+ 41 - 33
runner/ollamarunner/cache_test.go

@@ -4,6 +4,8 @@ import (
 	"image"
 	"image"
 	"testing"
 	"testing"
 	"time"
 	"time"
+
+	"github.com/ollama/ollama/model"
 )
 )
 
 
 func TestCountCommon(t *testing.T) {
 func TestCountCommon(t *testing.T) {
@@ -13,44 +15,50 @@ func TestCountCommon(t *testing.T) {
 
 
 	tests := []struct {
 	tests := []struct {
 		name     string
 		name     string
-		t1       []input
-		t2       []input
+		t1       []model.Input
+		t2       []model.Input
 		expected int32
 		expected int32
 	}{
 	}{
 		{
 		{
 			name:     "Equal",
 			name:     "Equal",
-			t1:       []input{{token: 1}, {token: 2}, {token: 3}},
-			t2:       []input{{token: 1}, {token: 2}, {token: 3}},
+			t1:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 3,
 			expected: 3,
 		},
 		},
 		{
 		{
 			name:     "Prefix",
 			name:     "Prefix",
-			t1:       []input{{token: 1}},
-			t2:       []input{{token: 1}, {token: 2}, {token: 3}},
+			t1:       []model.Input{{Token: 1}},
+			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 1,
 			expected: 1,
 		},
 		},
 		{
 		{
 			name:     "Image Prefix",
 			name:     "Image Prefix",
-			t1:       []input{{image: imgA}},
-			t2:       []input{{image: imgA}, {image: imgB}, {image: imgC}},
+			t1:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
 			expected: 1,
 			expected: 1,
 		},
 		},
 		{
 		{
 			name:     "Mixed",
 			name:     "Mixed",
-			t1:       []input{{token: 1}, {image: imgA}},
-			t2:       []input{{token: 1}, {image: imgA}, {token: 5}},
+			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 			expected: 2,
 		},
 		},
+		{
+			name:     "Mixed, Same Length",
+			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+			expected: 1,
+		},
 		{
 		{
 			name:     "Empty",
 			name:     "Empty",
-			t1:       []input{},
-			t2:       []input{{token: 1}, {token: 2}, {token: 3}},
+			t1:       []model.Input{},
+			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 0,
 			expected: 0,
 		},
 		},
 		{
 		{
 			name:     "Both Empty",
 			name:     "Both Empty",
-			t1:       []input{},
-			t2:       []input{},
+			t1:       []model.Input{},
+			t2:       []model.Input{},
 			expected: 0,
 			expected: 0,
 		},
 		},
 	}
 	}
@@ -74,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name    string
 		name    string
 		cache   InputCache
 		cache   InputCache
-		prompt  []input
+		prompt  []model.Input
 		longest expected
 		longest expected
 		best    expected
 		best    expected
 	}{
 	}{
@@ -83,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 				{
 					Id:       0,
 					Id:       0,
-					Inputs:   []input{},
+					Inputs:   []model.Input{},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Time{},
 					lastUsed: time.Time{},
 				},
 				},
 				{
 				{
 					Id:       1,
 					Id:       1,
-					Inputs:   []input{},
+					Inputs:   []model.Input{},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Time{},
 					lastUsed: time.Time{},
 				},
 				},
 			}},
 			}},
-			prompt:  []input{{token: 1}},
+			prompt:  []model.Input{{Token: 1}},
 			longest: expected{result: 0, len: 0},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 		},
 		},
@@ -103,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 				{
 					Id:       0,
 					Id:       0,
-					Inputs:   []input{{token: 1}},
+					Inputs:   []model.Input{{Token: 1}},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				},
 				{
 				{
 					Id:       1,
 					Id:       1,
-					Inputs:   []input{{token: 1}, {token: 2}},
+					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 				},
 			}},
 			}},
-			prompt:  []input{{token: 1}, {token: 2}},
+			prompt:  []model.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 2},
 			longest: expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
 		},
@@ -123,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 				{
 					Id:       0,
 					Id:       0,
-					Inputs:   []input{{token: 1}, {token: 2}},
+					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				},
 				{
 				{
 					Id:       1,
 					Id:       1,
-					Inputs:   []input{},
+					Inputs:   []model.Input{},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Time{},
 					lastUsed: time.Time{},
 				},
 				},
 			}},
 			}},
-			prompt:  []input{{token: 2}},
+			prompt:  []model.Input{{Token: 2}},
 			longest: expected{result: 0, len: 0},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
 		},
@@ -144,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 				slots: []InputCacheSlot{
 					{
 					{
 						Id:       0,
 						Id:       0,
-						Inputs:   []input{{token: 1}, {token: 2}},
+						Inputs:   []model.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					},
 					{
 					{
 						Id:       1,
 						Id:       1,
-						Inputs:   []input{},
+						Inputs:   []model.Input{},
 						InUse:    false,
 						InUse:    false,
 						lastUsed: time.Time{},
 						lastUsed: time.Time{},
 					},
 					},
 				},
 				},
 			},
 			},
-			prompt:  []input{{token: 1}},
+			prompt:  []model.Input{{Token: 1}},
 			longest: expected{result: 0, len: 1},
 			longest: expected{result: 0, len: 1},
 			best:    expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 1},
 		},
 		},
@@ -165,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 				{
 					Id:       0,
 					Id:       0,
-					Inputs:   []input{{token: 1}},
+					Inputs:   []model.Input{{Token: 1}},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				},
 				{
 				{
 					Id:       1,
 					Id:       1,
-					Inputs:   []input{{token: 1}, {token: 2}},
+					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 				},
 			}},
 			}},
-			prompt:  []input{{token: 2}, {token: 3}},
+			prompt:  []model.Input{{Token: 2}, {Token: 3}},
 			longest: expected{result: 0, len: 0},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
 		},
@@ -185,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 				{
 					Id:       0,
 					Id:       0,
-					Inputs:   []input{{token: 1}, {token: 2}},
+					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
 					InUse:    true,
 					InUse:    true,
 					lastUsed: time.Now().Add(-time.Second),
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				},
 				{
 				{
 					Id:       1,
 					Id:       1,
-					Inputs:   []input{{token: 1}},
+					Inputs:   []model.Input{{Token: 1}},
 					InUse:    false,
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 				},
 			}},
 			}},
-			prompt:  []input{{token: 1}, {token: 2}},
+			prompt:  []model.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 1},
 			longest: expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
 		},

+ 58 - 44
runner/ollamarunner/runner.go

@@ -1,13 +1,12 @@
 package ollamarunner
 package ollamarunner
 
 
 import (
 import (
-	"bytes"
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
-	"image"
+	"hash/maphash"
 	"log"
 	"log"
 	"log/slog"
 	"log/slog"
 	"net"
 	"net"
@@ -33,22 +32,19 @@ import (
 	_ "github.com/ollama/ollama/model/models"
 	_ "github.com/ollama/ollama/model/models"
 )
 )
 
 
-// input is an element of the prompt to process, either a token or an image
-type input struct {
-	token int32
-
-	image image.Image
-}
-
 type Sequence struct {
 type Sequence struct {
+	// ctx for allocating tensors that last the lifetime of the sequence, such as
+	// multimodal embeddings
+	ctx ml.Context
+
 	// batch index
 	// batch index
 	iBatch int
 	iBatch int
 
 
 	// prompt inputs left to evaluate
 	// prompt inputs left to evaluate
-	inputs []input
+	inputs []model.Input
 
 
 	// inputs that have been added to a batch but not yet submitted to Forward
 	// inputs that have been added to a batch but not yet submitted to Forward
-	pendingInputs []input
+	pendingInputs []model.Input
 
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 	pendingResponses []string
@@ -101,8 +97,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 	s.ready.Wait()
 	s.ready.Wait()
 
 
 	startTime := time.Now()
 	startTime := time.Now()
+	ctx := s.model.Backend().NewContext()
 
 
-	inputs, err := s.inputs(prompt, images)
+	inputs, err := s.inputs(ctx, prompt, images)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
 	} else if len(inputs) == 0 {
@@ -128,6 +125,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 	// TODO(jessegross): Ingest cached history for grammar
 	// TODO(jessegross): Ingest cached history for grammar
 
 
 	return &Sequence{
 	return &Sequence{
+		ctx:                 ctx,
 		inputs:              inputs,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		startProcessingTime: startTime,
@@ -146,19 +144,22 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
 // decoding images
-func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
-	var inputs []input
+func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
+	var inputs []model.Input
 	var parts []string
 	var parts []string
 	var matches [][]string
 	var matches [][]string
 
 
-	// TODO(jessegross): This can sometimes trigger for matching text in the
-	// user's prompt. We previously tried to avoid it by only looking for images
-	// on image models. We don't have a clear indication now but it would be better
-	// to properly escape it in any case.
-	re := regexp.MustCompile(`\[img-(\d+)\]`)
-	parts = re.Split(prompt, -1)
-	matches = re.FindAllStringSubmatch(prompt, -1)
+	multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
 
 
+	if visionModel {
+		re := regexp.MustCompile(`\[img-(\d+)\]`)
+		parts = re.Split(prompt, -1)
+		matches = re.FindAllStringSubmatch(prompt, -1)
+	} else {
+		parts = []string{prompt}
+	}
+
+	postTokenize := false
 	for i, part := range parts {
 	for i, part := range parts {
 		// text - tokenize
 		// text - tokenize
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
@@ -167,7 +168,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 		}
 		}
 
 
 		for _, t := range tokens {
 		for _, t := range tokens {
-			inputs = append(inputs, input{token: t})
+			inputs = append(inputs, model.Input{Token: t})
 		}
 		}
 
 
 		// image - decode and store
 		// image - decode and store
@@ -186,12 +187,25 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 				return nil, fmt.Errorf("invalid image index: %d", n)
 				return nil, fmt.Errorf("invalid image index: %d", n)
 			}
 			}
 
 
-			image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
+			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
 
 
-			inputs = append(inputs, input{image: image})
+			s.multimodalHash.Reset()
+			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
+			imageHash := s.multimodalHash.Sum64()
+
+			inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+			postTokenize = true
+		}
+	}
+
+	if visionModel && postTokenize {
+		var err error
+		inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+		if err != nil {
+			return nil, err
 		}
 		}
 	}
 	}
 
 
@@ -238,6 +252,10 @@ type Server struct {
 
 
 	// next sequence for prompt processing to avoid starvation
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
 	nextSeq int
+
+	// multimodalHash generates hashes for comparing equality
+	// of non-text data
+	multimodalHash maphash.Hash
 }
 }
 
 
 func (s *Server) allNil() bool {
 func (s *Server) allNil() bool {
@@ -283,6 +301,7 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 	close(seq.responses)
 	close(seq.responses)
 	close(seq.embedding)
 	close(seq.embedding)
 	seq.cache.InUse = false
 	seq.cache.InUse = false
+	seq.ctx.Close()
 	s.seqs[seqIndex] = nil
 	s.seqs[seqIndex] = nil
 	s.seqsSem.Release(1)
 	s.seqsSem.Release(1)
 }
 }
@@ -311,7 +330,6 @@ func (s *Server) processBatch() error {
 	defer s.mu.Unlock()
 	defer s.mu.Unlock()
 
 
 	var options model.Options
 	var options model.Options
-	imgSeq := -1
 
 
 	seqIdx := s.nextSeq - 1
 	seqIdx := s.nextSeq - 1
 	for range s.seqs {
 	for range s.seqs {
@@ -330,7 +348,7 @@ func (s *Server) processBatch() error {
 
 
 		if !s.cache.enabled {
 		if !s.cache.enabled {
 			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
 			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-			seq.cache.Inputs = []input{}
+			seq.cache.Inputs = []model.Input{}
 		}
 		}
 
 
 		for i, input := range seq.inputs {
 		for i, input := range seq.inputs {
@@ -349,25 +367,21 @@ func (s *Server) processBatch() error {
 				break
 				break
 			}
 			}
 
 
-			// TODO(jessegross): Image inputs need to be rethought - it's
-			// it doesn't work well for different types of models or multiple sequences
-			if input.image != nil {
-				if len(seq.pendingInputs) != len(options.Images) {
-					break
-				}
-
-				if imgSeq != seqIdx && imgSeq != -1 {
-					s.nextSeq = seqIdx
-					break
-				}
+			// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
+			// to the encoder cache.
+			//
+			// Break the batch when switching from text to images so that images are always at the beginning.
+			if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
+				(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
+				s.nextSeq = seqIdx
+				break
+			}
 
 
-				imgSeq = seqIdx
-				options.Images = append(options.Images, input.image)
-				seq.pendingInputs = append(seq.pendingInputs, input)
-				continue
+			options.Inputs = append(options.Inputs, input.Token)
+			if input.Multimodal != nil {
+				options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
 			}
 			}
 
 
-			options.Inputs = append(options.Inputs, input.token)
 			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			options.Sequences = append(options.Sequences, seq.cache.Id)
 			options.Sequences = append(options.Sequences, seq.cache.Id)
 
 
@@ -403,7 +417,7 @@ func (s *Server) processBatch() error {
 		// After calling Forward, pending inputs are now in the cache
 		// After calling Forward, pending inputs are now in the cache
 		if len(seq.pendingInputs) > 0 {
 		if len(seq.pendingInputs) > 0 {
 			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
 			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-			seq.pendingInputs = []input{}
+			seq.pendingInputs = []model.Input{}
 		}
 		}
 
 
 		// don't sample prompt processing
 		// don't sample prompt processing
@@ -449,7 +463,7 @@ func (s *Server) processBatch() error {
 			return err
 			return err
 		}
 		}
 
 
-		seq.inputs = []input{{token: token}}
+		seq.inputs = []model.Input{{Token: token}}
 
 
 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		sequence := strings.Join(seq.pendingResponses, "")
 		sequence := strings.Join(seq.pendingResponses, "")