Pārlūkot izejas kodu

model: Update encoder cache to use multimodal input processing handler

The encoder cache needs to know the position of images in the input
stream so that it knows when to delete them. Previously images didn't
have a position, so we implied one by breaking batches before an
image and then assuming the image was in the first position. However,
multimodal objects are now given explicit positions in the input
stream, so we can use that instead.

Breaking batches was also a way to simulate a cross attention mask
for mllama. However, given that it only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
Jesse Gross 1 mēnesi atpakaļ
vecāks
revīzija
a1cda80bcb

+ 2 - 1
kvcache/cache.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 var (
@@ -51,7 +52,7 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
 	// entry in positions and seqs.
-	StartForward(ctx ml.Context, positions []int32, seqs []int) error
+	StartForward(ctx ml.Context, opts input.Options) error
 
 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)

+ 7 - 6
kvcache/causal.go

@@ -8,6 +8,7 @@ import (
 	"slices"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -140,10 +141,10 @@ func (c *Causal) Close() {
 	}
 }
 
-func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-	c.curBatchSize = len(positions)
-	c.curSequences = seqs
-	c.curPositions = positions
+func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
+	c.curBatchSize = len(opts.Positions)
+	c.curSequences = opts.Sequences
+	c.curPositions = opts.Positions
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -156,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
 	}
 
 	c.curCellRange = newRange()
-	for i, pos := range positions {
-		seq := seqs[i]
+	for i, pos := range opts.Positions {
+		seq := opts.Sequences[i]
 
 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
 

+ 2 - 1
kvcache/causal_test.go

@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 type testCase struct {
@@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			context := backend.NewContext()
 			defer context.Close()
 
-			err := cache.StartForward(context, test.pos, test.seqs)
+			err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
 			if err != nil {
 				panic(err)
 			}

+ 6 - 3
kvcache/encoder.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 // Encoder cache stores K and V tensors that are position independent
@@ -78,9 +79,11 @@ func (c *EncoderCache) Close() {
 	}
 }
 
-func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-	// The image is always in the first position
-	c.curPos = positions[0]
+func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+	// We work with the most recent image
+	if len(opts.Multimodal) > 0 {
+		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+	}
 
 	return nil
 }

+ 5 - 4
kvcache/wrapper.go

@@ -4,6 +4,7 @@ import (
 	"math"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 // Wrapper cache is a container for multiple types of caches,
@@ -40,14 +41,14 @@ func (c *WrapperCache) Close() {
 	}
 }
 
-func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, positions, seqs)
+		err := cache.StartForward(ctx, opts)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range positions {
-					_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
+				for k := range opts.Positions {
+					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
 				}
 			}
 			return err

+ 37 - 0
model/input/input.go

@@ -0,0 +1,37 @@
+package input
+
+// Input represents one token in the input stream
+type Input struct {
+	// Token is a single element of text.
+	Token int32
+
+	// Multimodal is opaque data representing a non-text
+	// element such as an image (or part of one if the image
+	// can be processed in pieces). It may be either together
+	// with Token or on its own.
+	Multimodal any
+
+	// MultimodalHash is a unique representation of the data
+	// stored in Multimodal, used for caching and comparing
+	// equality.
+	MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find that use the index with the
+// Positions slice.
+type MultimodalIndex struct {
+	Index      int
+	Multimodal any
+}
+
+// Options contains the inputs for a model forward pass
+type Options struct {
+	Inputs     []int32
+	Multimodal []MultimodalIndex
+	Positions  []int32
+	Sequences  []int
+	Outputs    []int32
+}

+ 24 - 59
model/model.go

@@ -19,66 +19,12 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
+	"github.com/ollama/ollama/model/input"
 )
 
-// Input represents one token in the input stream
-type Input struct {
-	// Token is a single element of text.
-	Token int32
-
-	// Multimodal is opaque data representing a non-text
-	// element such as an image (or part of one if the image
-	// can be processed in pieces). It may be either together
-	// with Token or on its own.
-	Multimodal any
-
-	// MultimodalHash is a unique representation of the data
-	// stored in Multimodal, used for caching and comparing
-	// equality.
-	MultimodalHash uint64
-}
-
-// MultimodalIndex is a multimodal element (such as an image)
-// together with an index into the slice of Inputs with the
-// corresponding token. Note that the index is not the same
-// as the position - to find that use the index with the
-// Positions slice.
-type MultimodalIndex struct {
-	Index      int
-	Multimodal any
-}
-
-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs     []int32
-	Multimodal []MultimodalIndex
-	Positions  []int32
-	Sequences  []int
-	Outputs    []int32
-}
-
-type config struct {
-	Cache kvcache.Cache
-}
-
-// Base implements the common fields and methods for all models
-type Base struct {
-	b ml.Backend
-	config
-}
-
-// Backend returns the underlying backend that will run the model
-func (m *Base) Backend() ml.Backend {
-	return m.b
-}
-
-func (m *Base) Config() config {
-	return m.config
-}
-
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Options) (ml.Tensor, error)
 
 	Backend() ml.Backend
 	Config() config
@@ -112,7 +58,26 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []Input) ([]Input, error)
+	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+}
+
+// Base implements the common fields and methods for all models
+type Base struct {
+	b ml.Backend
+	config
+}
+
+type config struct {
+	Cache kvcache.Cache
+}
+
+// Backend returns the underlying backend that will run the model
+func (m *Base) Backend() ml.Backend {
+	return m.b
+}
+
+func (m *Base) Config() config {
+	return m.config
 }
 
 var models = make(map[string]func(ml.Config) (Model, error))
@@ -313,7 +278,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
 	if len(opts.Positions) != len(opts.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
 	}
@@ -324,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+		err := cache.StartForward(ctx, opts)
 		if err != nil {
 			return nil, err
 		}

+ 2 - 1
model/model_test.go

@@ -11,6 +11,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestParseTags(t *testing.T) {
@@ -162,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
 	panic("unimplemented")
 }
 

+ 2 - 1
model/models/llama/model.go

@@ -9,6 +9,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type Options struct {
@@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err

+ 7 - 6
model/models/mllama/model.go

@@ -12,6 +12,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type Model struct {
@@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
-	var images []model.Input
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+	var images []input.Input
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
@@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Inpu
 		}
 	}
 
-	inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
+	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
 
 	return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if opts.Multimodal != nil {
-		crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
+	if len(opts.Multimodal) > 0 {
+		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
 	}
 
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))

+ 7 - 6
runner/ollamarunner/cache.go

@@ -10,6 +10,7 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type InputCache struct {
@@ -79,7 +80,7 @@ type InputCacheSlot struct {
 	Id int
 
 	// Inputs that are stored in the KV cache
-	Inputs []model.Input
+	Inputs []input.Input
 
 	// is this cache actively being processed as part of a sequence?
 	InUse bool
@@ -88,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp
 	return slot, prompt, nil
 }
 
-func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	longest := int32(-1)
 	var longestSlot *InputCacheSlot
 
@@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot
 	return longestSlot, longest, nil
 }
 
-func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	oldest := time.Now()
 	var oldestSlot *InputCacheSlot
 
@@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
 	if longest > 0 && longestSlot != oldestSlot {
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
 			len(longestSlot.Inputs))
-		oldestSlot.Inputs = make([]model.Input, longest)
+		oldestSlot.Inputs = make([]input.Input, longest)
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		if c.cache != nil {
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
 	return oldestSlot, longest, nil
 }
 
-func countCommonPrefix(a []model.Input, b []model.Input) int32 {
+func countCommonPrefix(a []input.Input, b []input.Input) int32 {
 	var count int32
 
 	for i := range a {

+ 36 - 36
runner/ollamarunner/cache_test.go

@@ -5,7 +5,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestCountCommon(t *testing.T) {
@@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) {
 
 	tests := []struct {
 		name     string
-		t1       []model.Input
-		t2       []model.Input
+		t1       []input.Input
+		t2       []input.Input
 		expected int32
 	}{
 		{
 			name:     "Equal",
-			t1:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 3,
 		},
 		{
 			name:     "Prefix",
-			t1:       []model.Input{{Token: 1}},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Image Prefix",
-			t1:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+			t1:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Mixed",
-			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name:     "Mixed, Same Length",
-			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
 			name:     "Empty",
-			t1:       []model.Input{},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 0,
 		},
 		{
 			name:     "Both Empty",
-			t1:       []model.Input{},
-			t2:       []model.Input{},
+			t1:       []input.Input{},
+			t2:       []input.Input{},
 			expected: 0,
 		},
 	}
@@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
 	tests := []struct {
 		name    string
 		cache   InputCache
-		prompt  []model.Input
+		prompt  []input.Input
 		longest expected
 		best    expected
 	}{
@@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 		},
@@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
@@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []model.Input{{Token: 2}},
+			prompt:  []input.Input{{Token: 2}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []model.Input{},
+						Inputs:   []input.Input{},
 						InUse:    false,
 						lastUsed: time.Time{},
 					},
 				},
 			},
-			prompt:  []model.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 1},
 			best:    expected{result: 1, len: 1},
 		},
@@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 2}, {Token: 3}},
+			prompt:  []input.Input{{Token: 2}, {Token: 3}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    true,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 2},
 		},

+ 20 - 36
runner/ollamarunner/runner.go

@@ -26,6 +26,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 	"github.com/ollama/ollama/runner/common"
 	"github.com/ollama/ollama/sample"
 
@@ -41,10 +42,10 @@ type Sequence struct {
 	iBatch int
 
 	// prompt inputs left to evaluate
-	inputs []model.Input
+	inputs []input.Input
 
 	// inputs that have been added to a batch but not yet submitted to Forward
-	pendingInputs []model.Input
+	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
@@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
-	var inputs []model.Input
+func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+	var inputs []input.Input
 	var parts []string
 	var matches [][]string
 
@@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
 		}
 
 		for _, t := range tokens {
-			inputs = append(inputs, model.Input{Token: t})
+			inputs = append(inputs, input.Input{Token: t})
 		}
 
 		// image - decode and store
@@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
 			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
 			imageHash := s.multimodalHash.Sum64()
 
-			inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 			postTokenize = true
 		}
 	}
@@ -250,9 +251,6 @@ type Server struct {
 	// KV cache
 	cache *InputCache
 
-	// next sequence for prompt processing to avoid starvation
-	nextSeq int
-
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
@@ -329,29 +327,25 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
-	var options model.Options
-
-	seqIdx := s.nextSeq - 1
-	for range s.seqs {
-		seqIdx = (seqIdx + 1) % len(s.seqs)
-		seq := s.seqs[seqIdx]
+	var options input.Options
 
+	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
 		}
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(i, "limit")
 			continue
 		}
 
 		if !s.cache.enabled {
 			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-			seq.cache.Inputs = []model.Input{}
+			seq.cache.Inputs = []input.Input{}
 		}
 
-		for i, input := range seq.inputs {
+		for j, inp := range seq.inputs {
 			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -363,33 +357,23 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if i >= s.batchSize {
-				break
-			}
-
-			// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
-			// to the encoder cache.
-			//
-			// Break the batch when switching from text to images so that images are always at the beginning.
-			if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
-				(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
-				s.nextSeq = seqIdx
+			if j >= s.batchSize {
 				break
 			}
 
-			options.Inputs = append(options.Inputs, input.Token)
-			if input.Multimodal != nil {
-				options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
+			options.Inputs = append(options.Inputs, inp.Token)
+			if inp.Multimodal != nil {
+				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
 			}
 
 			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			options.Sequences = append(options.Sequences, seq.cache.Id)
 
 			seq.iBatch = len(options.Outputs)
-			if i+1 == len(seq.inputs) {
+			if j+1 == len(seq.inputs) {
 				options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
 			}
-			seq.pendingInputs = append(seq.pendingInputs, input)
+			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
 
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -417,7 +401,7 @@ func (s *Server) processBatch() error {
 		// After calling Forward, pending inputs are now in the cache
 		if len(seq.pendingInputs) > 0 {
 			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-			seq.pendingInputs = []model.Input{}
+			seq.pendingInputs = []input.Input{}
 		}
 
 		// don't sample prompt processing
@@ -464,7 +448,7 @@ func (s *Server) processBatch() error {
 			return err
 		}
 
-		seq.inputs = []model.Input{{Token: token}}
+		seq.inputs = []input.Input{{Token: token}}
 
 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		sequence := strings.Join(seq.pendingResponses, "")