
model: Update encoder cache to use multimodal input processing handler

The encoder cache needs to know the position of images in the input
stream so that it can tell when to delete them. Previously images didn't
have a position, so we implied one by breaking batches before an
image and then assuming the image was in the first position. However,
multimodal objects are now given explicit positions in the input
stream, so we can use those instead.
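
Concretely, a cache can now recover an image's position straight from the
batch options. A minimal sketch, assuming the types introduced in
model/input below and an opts value built by the runner for the current
forward pass:

	// Each multimodal element carries an index into the batch; the entry
	// at that index in Positions is the position the encoder cache needs.
	// Here we look at the most recent image in the batch.
	if len(opts.Multimodal) > 0 {
		last := opts.Multimodal[len(opts.Multimodal)-1]
		imagePos := opts.Positions[last.Index]
		_ = imagePos // e.g. remembered so the image can later be freed
	}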

Breaking batches was also a way to simulate a cross-attention mask
for mllama. However, since mllama only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
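
As an illustration of the cycle being avoided: kvcache is already imported
by model, so once the cache interface needs to see Input/Options those
types can't live in the model package without forming a
model -> kvcache -> model loop. A rough sketch of the import graph after
this change, as reflected in the diffs below:

	// model/input depends on nothing else in the tree, so every layer can import it:
	//   kvcache             -> ml, model/input
	//   model               -> kvcache, ml, model/input
	//   model/models/*      -> ml, model, model/input
	//   runner/ollamarunner -> kvcache, ml, model, model/input
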
Jesse Gross, 1 month ago
commit a1cda80bcb

+ 2 - 1
kvcache/cache.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 var (
@@ -51,7 +52,7 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
 	// entry in positions and seqs.
-	StartForward(ctx ml.Context, positions []int32, seqs []int) error
+	StartForward(ctx ml.Context, opts input.Options) error
 
 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)

+ 7 - 6
kvcache/causal.go

@@ -8,6 +8,7 @@ import (
 	"slices"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -140,10 +141,10 @@ func (c *Causal) Close() {
 	}
 }
 
-func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-	c.curBatchSize = len(positions)
-	c.curSequences = seqs
-	c.curPositions = positions
+func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
+	c.curBatchSize = len(opts.Positions)
+	c.curSequences = opts.Sequences
+	c.curPositions = opts.Positions
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -156,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
 	}
 
 	c.curCellRange = newRange()
-	for i, pos := range positions {
-		seq := seqs[i]
+	for i, pos := range opts.Positions {
+		seq := opts.Sequences[i]
 
 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
 

+ 2 - 1
kvcache/causal_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 type testCase struct {
@@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			context := backend.NewContext()
 			defer context.Close()
 
-			err := cache.StartForward(context, test.pos, test.seqs)
+			err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
 			if err != nil {
 				panic(err)
 			}

+ 6 - 3
kvcache/encoder.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 // Encoder cache stores K and V tensors that are position independent
@@ -78,9 +79,11 @@ func (c *EncoderCache) Close() {
 	}
 }
 
-func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-	// The image is always in the first position
-	c.curPos = positions[0]
+func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+	// We work with the most recent image
+	if len(opts.Multimodal) > 0 {
+		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+	}
 
 	return nil
 }

+ 5 - 4
kvcache/wrapper.go

@@ -4,6 +4,7 @@ import (
 	"math"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 // Wrapper cache is a container for multiple types of caches,
@@ -40,14 +41,14 @@ func (c *WrapperCache) Close() {
 	}
 }
 
-func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, positions, seqs)
+		err := cache.StartForward(ctx, opts)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range positions {
-					_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
+				for k := range opts.Positions {
+					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
 				}
 			}
 			return err

+ 37 - 0
model/input/input.go

@@ -0,0 +1,37 @@
+package input
+
+// Input represents one token in the input stream
+type Input struct {
+	// Token is a single element of text.
+	Token int32
+
+	// Multimodal is opaque data representing a non-text
+	// element such as an image (or part of one if the image
+	// can be processed in pieces). It may be either together
+	// with Token or on its own.
+	Multimodal any
+
+	// MultimodalHash is a unique representation of the data
+	// stored in Multimodal, used for caching and comparing
+	// equality.
+	MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find that use the index with the
+// Positions slice.
+type MultimodalIndex struct {
+	Index      int
+	Multimodal any
+}
+
+// Options contains the inputs for a model forward pass
+type Options struct {
+	Inputs     []int32
+	Multimodal []MultimodalIndex
+	Positions  []int32
+	Sequences  []int
+	Outputs    []int32
+}

+ 24 - 59
model/model.go

@@ -19,66 +19,12 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
+	"github.com/ollama/ollama/model/input"
 )
 
-// Input represents one token in the input stream
-type Input struct {
-	// Token is a single element of text.
-	Token int32
-
-	// Multimodal is opaque data representing a non-text
-	// element such as an image (or part of one if the image
-	// can be processed in pieces). It may be either together
-	// with Token or on its own.
-	Multimodal any
-
-	// MultimodalHash is a unique representation of the data
-	// stored in Multimodal, used for caching and comparing
-	// equality.
-	MultimodalHash uint64
-}
-
-// MultimodalIndex is a multimodal element (such as an image)
-// together with an index into the slice of Inputs with the
-// corresponding token. Note that the index is not the same
-// as the position - to find that use the index with the
-// Positions slice.
-type MultimodalIndex struct {
-	Index      int
-	Multimodal any
-}
-
-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs     []int32
-	Multimodal []MultimodalIndex
-	Positions  []int32
-	Sequences  []int
-	Outputs    []int32
-}
-
-type config struct {
-	Cache kvcache.Cache
-}
-
-// Base implements the common fields and methods for all models
-type Base struct {
-	b ml.Backend
-	config
-}
-
-// Backend returns the underlying backend that will run the model
-func (m *Base) Backend() ml.Backend {
-	return m.b
-}
-
-func (m *Base) Config() config {
-	return m.config
-}
-
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Options) (ml.Tensor, error)
 
 	Backend() ml.Backend
 	Config() config
@@ -112,7 +58,26 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []Input) ([]Input, error)
+	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+}
+
+// Base implements the common fields and methods for all models
+type Base struct {
+	b ml.Backend
+	config
+}
+
+type config struct {
+	Cache kvcache.Cache
+}
+
+// Backend returns the underlying backend that will run the model
+func (m *Base) Backend() ml.Backend {
+	return m.b
+}
+
+func (m *Base) Config() config {
+	return m.config
 }
 
 var models = make(map[string]func(ml.Config) (Model, error))
@@ -313,7 +278,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
 	if len(opts.Positions) != len(opts.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
 	}
@@ -324,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+		err := cache.StartForward(ctx, opts)
 		if err != nil {
 			return nil, err
 		}

+ 2 - 1
model/model_test.go

@@ -11,6 +11,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestParseTags(t *testing.T) {
@@ -162,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
 	panic("unimplemented")
 }
 

+ 2 - 1
model/models/llama/model.go

@@ -9,6 +9,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type Options struct {
@@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err

+ 7 - 6
model/models/mllama/model.go

@@ -12,6 +12,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type Model struct {
@@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
-	var images []model.Input
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+	var images []input.Input
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
@@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Inpu
 		}
 	}
 
-	inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
+	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
 
 	return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if opts.Multimodal != nil {
-		crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
+	if len(opts.Multimodal) > 0 {
+		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
 	}
 
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))

+ 7 - 6
runner/ollamarunner/cache.go

@@ -10,6 +10,7 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type InputCache struct {
@@ -79,7 +80,7 @@ type InputCacheSlot struct {
 	Id int
 
 	// Inputs that are stored in the KV cache
-	Inputs []model.Input
+	Inputs []input.Input
 
 	// is this cache actively being processed as part of a sequence?
 	InUse bool
@@ -88,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp
 	return slot, prompt, nil
 }
 
-func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	longest := int32(-1)
 	var longestSlot *InputCacheSlot
 
@@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot
 	return longestSlot, longest, nil
 }
 
-func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	oldest := time.Now()
 	var oldestSlot *InputCacheSlot
 
@@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
 	if longest > 0 && longestSlot != oldestSlot {
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
 			len(longestSlot.Inputs))
-		oldestSlot.Inputs = make([]model.Input, longest)
+		oldestSlot.Inputs = make([]input.Input, longest)
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		if c.cache != nil {
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
 	return oldestSlot, longest, nil
 }
 
-func countCommonPrefix(a []model.Input, b []model.Input) int32 {
+func countCommonPrefix(a []input.Input, b []input.Input) int32 {
 	var count int32
 
 	for i := range a {

+ 36 - 36
runner/ollamarunner/cache_test.go

@@ -5,7 +5,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestCountCommon(t *testing.T) {
@@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) {
 
 	tests := []struct {
 		name     string
-		t1       []model.Input
-		t2       []model.Input
+		t1       []input.Input
+		t2       []input.Input
 		expected int32
 	}{
 		{
 			name:     "Equal",
-			t1:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 3,
 		},
 		{
 			name:     "Prefix",
-			t1:       []model.Input{{Token: 1}},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Image Prefix",
-			t1:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+			t1:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Mixed",
-			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name:     "Mixed, Same Length",
-			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
 			name:     "Empty",
-			t1:       []model.Input{},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 0,
 		},
 		{
 			name:     "Both Empty",
-			t1:       []model.Input{},
-			t2:       []model.Input{},
+			t1:       []input.Input{},
+			t2:       []input.Input{},
 			expected: 0,
 		},
 	}
@@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
 	tests := []struct {
 		name    string
 		cache   InputCache
-		prompt  []model.Input
+		prompt  []input.Input
 		longest expected
 		best    expected
 	}{
@@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 		},
@@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
@@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []model.Input{{Token: 2}},
+			prompt:  []input.Input{{Token: 2}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []model.Input{},
+						Inputs:   []input.Input{},
 						InUse:    false,
 						lastUsed: time.Time{},
 					},
 				},
 			},
-			prompt:  []model.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 1},
 			best:    expected{result: 1, len: 1},
 		},
@@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 2}, {Token: 3}},
+			prompt:  []input.Input{{Token: 2}, {Token: 3}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    true,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 2},
 		},

+ 20 - 36
runner/ollamarunner/runner.go

@@ -26,6 +26,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 	"github.com/ollama/ollama/runner/common"
 	"github.com/ollama/ollama/sample"
 
@@ -41,10 +42,10 @@ type Sequence struct {
 	iBatch int
 
 	// prompt inputs left to evaluate
-	inputs []model.Input
+	inputs []input.Input
 
 	// inputs that have been added to a batch but not yet submitted to Forward
-	pendingInputs []model.Input
+	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
@@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
-	var inputs []model.Input
+func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+	var inputs []input.Input
 	var parts []string
 	var matches [][]string
 
@@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
 		}
 
 		for _, t := range tokens {
-			inputs = append(inputs, model.Input{Token: t})
+			inputs = append(inputs, input.Input{Token: t})
 		}
 
 		// image - decode and store
@@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
 			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
 			imageHash := s.multimodalHash.Sum64()
 
-			inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 			postTokenize = true
 		}
 	}
@@ -250,9 +251,6 @@ type Server struct {
 	// KV cache
 	cache *InputCache
 
-	// next sequence for prompt processing to avoid starvation
-	nextSeq int
-
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
@@ -329,29 +327,25 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
-	var options model.Options
-
-	seqIdx := s.nextSeq - 1
-	for range s.seqs {
-		seqIdx = (seqIdx + 1) % len(s.seqs)
-		seq := s.seqs[seqIdx]
+	var options input.Options
 
+	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
 		}
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(i, "limit")
 			continue
 		}
 
 		if !s.cache.enabled {
 			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-			seq.cache.Inputs = []model.Input{}
+			seq.cache.Inputs = []input.Input{}
 		}
 
-		for i, input := range seq.inputs {
+		for j, inp := range seq.inputs {
 			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -363,33 +357,23 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if i >= s.batchSize {
-				break
-			}
-
-			// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
-			// to the encoder cache.
-			//
-			// Break the batch when switching from text to images so that images are always at the beginning.
-			if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
-				(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
-				s.nextSeq = seqIdx
+			if j >= s.batchSize {
 				break
 			}
 
-			options.Inputs = append(options.Inputs, input.Token)
-			if input.Multimodal != nil {
-				options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
+			options.Inputs = append(options.Inputs, inp.Token)
+			if inp.Multimodal != nil {
+				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
 			}
 
 			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			options.Sequences = append(options.Sequences, seq.cache.Id)
 
 			seq.iBatch = len(options.Outputs)
-			if i+1 == len(seq.inputs) {
+			if j+1 == len(seq.inputs) {
 				options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
 			}
-			seq.pendingInputs = append(seq.pendingInputs, input)
+			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
 
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -417,7 +401,7 @@ func (s *Server) processBatch() error {
 		// After calling Forward, pending inputs are now in the cache
 		if len(seq.pendingInputs) > 0 {
 			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-			seq.pendingInputs = []model.Input{}
+			seq.pendingInputs = []input.Input{}
 		}
 
 		// don't sample prompt processing
@@ -464,7 +448,7 @@ func (s *Server) processBatch() error {
 			return err
 		}
 
-		seq.inputs = []model.Input{{Token: token}}
+		seq.inputs = []input.Input{{Token: token}}
 
 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		sequence := strings.Join(seq.pendingResponses, "")