Browse Source

new runner

Jesse Gross 4 months ago
parent
commit
0d22c0ec1a

+ 346 - 25
cache/cache.go

@@ -1,63 +1,384 @@
 package cache
 
 import (
+	"errors"
+	"fmt"
+	"log/slog"
+	"math"
+	"slices"
+
 	"github.com/ollama/ollama/ml"
 )
 
-type Options struct {
-	Position int
-}
+var ErrNotSupported = errors.New("model does not support operation")
 
 type Cache interface {
+	// used by model implementations
 	Sub(i int) Cache
-	Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
+	Put(ctx ml.Context, key, value ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor)
+
+	// cache management
+	Close()
+
+	StartForward(ctx ml.Context, seqs []int) error
+
+	CopyPrefix(srcSeq, dstSeq int, len int)
+	Remove(seq int, beginIndex, endIndex int) error
 }
 
-type Simple struct {
+type Causal struct {
 	DType    ml.DType
 	Capacity int
 
+	// current forward pass
+	curLayer     int
+	curPos       int
+	curBatchSize int
+	curMask      ml.Tensor
+	curCellRange cellRange
+
+	// metadata
+	cells      []cacheCell
+	seqNextPos map[int]int
+	cellRanges map[int]cellRange
+
+	// cache data storage
+	backend      ml.Backend
+	cacheCtx     ml.Context
 	keys, values []ml.Tensor
 }
 
-func (c *Simple) Sub(i int) Cache {
+type seqCell struct {
+	seq int
+	pos int
+}
+
+type cacheCell struct {
+	sequences []seqCell
+}
+
+type cellRange struct {
+	min int
+	max int
+}
+
+func (cell cacheCell) findSeq(seq int) *seqCell {
+	for i := range cell.sequences {
+		if cell.sequences[i].seq == seq {
+			return &cell.sequences[i]
+		}
+	}
+	return nil
+}
+
+func NewCausalCache(backend ml.Backend, capacity int, dtype ml.DType) Cache {
+	return &Causal{
+		Capacity:   capacity,
+		DType:      dtype,
+		cells:      make([]cacheCell, capacity),
+		seqNextPos: make(map[int]int),
+		cellRanges: make(map[int]cellRange),
+		backend:    backend,
+		// TODO(jessegross): This context is not sized appropriately
+		cacheCtx: backend.NewContext(),
+	}
+}
+
+func (c *Causal) Close() {
+	c.cacheCtx.Close()
+}
+
+var ErrKvCacheFull = errors.New("could not find a kv cache slot")
+
+func (c *Causal) StartForward(ctx ml.Context, seqs []int) error {
+	c.curBatchSize = len(seqs)
+
+	var err error
+	c.curPos, err = c.findStartPos()
+	if errors.Is(err, ErrKvCacheFull) {
+		c.defrag()
+		c.curPos, err = c.findStartPos()
+	}
+	if err != nil {
+		return err
+	}
+
+	// TODO(jessegross): There should be a better way to do this
+	origSeq := make(map[int]int)
+	for k, v := range c.seqNextPos {
+		origSeq[k] = v
+	}
+
+	c.curCellRange = newRange()
+	for i, seq := range seqs {
+		c.cells[c.curPos+i] = cacheCell{sequences: []seqCell{{seq: seq, pos: c.seqNextPos[seq]}}}
+		c.seqNextPos[seq]++
+
+		ranges := c.cellRanges[seq]
+		if c.curPos+i > ranges.max {
+			ranges.max = c.curPos + i
+		}
+		if ranges.max > c.curCellRange.max {
+			c.curCellRange.max = ranges.max
+		}
+
+		if c.curPos+i < ranges.min {
+			ranges.min = c.curPos + i
+		}
+		if ranges.min < c.curCellRange.min {
+			c.curCellRange.min = ranges.min
+		}
+		c.cellRanges[seq] = ranges
+	}
+
+	c.curMask, err = c.buildMask(ctx, origSeq, seqs)
+
+	return err
+}
+
+func newRange() cellRange {
+	return cellRange{
+		min: math.MaxInt,
+		max: 0,
+	}
+}
+
+func (c *Causal) findStartPos() (int, error) {
+	var start, count int
+	for i := range c.cells {
+		if len(c.cells[i].sequences) == 0 {
+			count++
+			if count >= c.curBatchSize {
+				return start, nil
+			}
+		} else {
+			start = i + 1
+			count = 0
+		}
+	}
+
+	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
+}
+
+func (c *Causal) buildMask(ctx ml.Context, origSeq map[int]int, seqs []int) (ml.Tensor, error) {
+	// TODO(jessegross): This makes a number of simplifications such as no padding
+	len := c.curCellRange.max - c.curCellRange.min
+	mask := make([]float32, c.curBatchSize*len)
+
+	for i := range c.curBatchSize {
+		for j := c.curCellRange.min; j < c.curCellRange.max; j++ {
+			cellSeq := c.cells[j].findSeq(seqs[i])
+			if cellSeq == nil || cellSeq.pos > origSeq[seqs[i]]+i {
+				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+			}
+		}
+	}
+
+	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+}
+
+func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
+	for _, obj := range objs {
+		srcView := obj.View(ctx, int(obj.Stride(2))*src, int(obj.Dim(0)*obj.Dim(1))*len)
+		dstView := obj.View(ctx, int(obj.Stride(2))*dst, int(obj.Dim(0)*obj.Dim(1))*len)
+
+		ctx.Forward(srcView.Copy(ctx, dstView))
+	}
+}
+
+func (c *Causal) defrag() {
+	slog.Debug("defragmenting kv cache")
+
+	// Defrag strategy:
+	// - Search for empty holes at the beginning of the cache,
+	//   filling them with active data starting at the end
+	// - If there are contiguous elements that need to be moved,
+	//   combine them into a single operation by holding new moves
+	//   until we see the next one is non-contiguous
+	// - Fill up the context with the maximum number of operations it
+	//   can hold then compute that and continue with a new context
+
+	// TODO(jessegross):
+	// - Need to size the context and compute maxMoves correctly
+	// - Just compacts, doesn't optimize placement
+	maxMoves := 8192 / (6 * len(c.keys))
+
+	ctx := c.backend.NewContext()
+	moves := 0
+
+	var pendingSrc, pendingDst, pendingLen int
+
+	for dst := range c.cells {
+		if len(c.cells[dst].sequences) == 0 {
+			for src := len(c.cells) - 1; src > dst; src-- {
+				if len(c.cells[src].sequences) != 0 {
+					c.cells[dst] = c.cells[src]
+					c.cells[src] = cacheCell{}
+
+					if pendingLen > 0 {
+						if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
+							pendingSrc = src
+							pendingLen++
+							break
+						} else {
+							moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
+							moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+							moves++
+						}
+					}
+
+					pendingSrc = src
+					pendingDst = dst
+					pendingLen = 1
+
+					break
+				}
+			}
+		}
+
+		if moves >= maxMoves {
+			ctx.Compute(nil)
+			ctx.Close()
+			ctx = c.backend.NewContext()
+
+			moves = 0
+		}
+	}
+
+	if pendingLen > 0 {
+		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
+		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+		moves++
+	}
+
+	if moves > 0 {
+		ctx.Compute(nil)
+	}
+	ctx.Close()
+
+	for seq := range c.cellRanges {
+		seqRange := newRange()
+
+		for i, cell := range c.cells {
+			if cell.findSeq(seq) != nil {
+				if i < seqRange.min {
+					seqRange.min = i
+				}
+				if i > seqRange.max {
+					seqRange.max = i
+				}
+			}
+		}
+
+		c.cellRanges[seq] = seqRange
+	}
+}
+
+func (c *Causal) Sub(i int) Cache {
 	if i >= len(c.keys) {
 		c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
 		c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
 	}
 
-	return &Simple{
-		keys:     c.keys[i : i+1],
-		values:   c.values[i : i+1],
-		Capacity: c.Capacity,
-		DType:    c.DType,
-	}
+	c.curLayer = i
+
+	return c
 }
 
-func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
-	if c.keys[0] == nil || c.values[0] == nil {
-		c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
-		c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
+func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor) {
+	if c.curBatchSize != int(key.Dim(2)) {
+		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, int(key.Dim(2))))
+	}
+
+	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int64(c.Capacity))
+		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int64(c.Capacity))
 	}
 
-	ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
-	ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
+	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curPos, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
+	ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, int(value.Stride(2))*c.curPos, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
 
-	n := min(c.Capacity, int(key.Dim(2))+opts.Position)
+	len := c.curCellRange.max - c.curCellRange.min
 
-	key = c.keys[0].View(ctx, 0,
+	key = c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curCellRange.min,
 		int(key.Dim(0)), int(key.Stride(1)),
 		int(key.Dim(1)), int(key.Stride(2)),
-		n,
+		len,
 	)
 
-	value = c.values[0].View(ctx, 0,
+	value = c.values[c.curLayer].View(ctx, int(key.Stride(2))*c.curCellRange.min,
 		int(value.Dim(0)), int(value.Stride(1)),
 		int(value.Dim(1)), int(value.Stride(2)),
-		n,
+		len,
 	)
 
-	// TODO shift context if necessary
+	return key, value, c.curMask
+}
+
+func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int) {
+	seqRange := newRange()
+
+	for i := range c.cells {
+		srcCellSeq := c.cells[i].findSeq(srcSeq)
+		dstCellSeq := c.cells[i].findSeq(dstSeq)
+
+		if dstCellSeq != nil {
+			c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == dstSeq })
+		}
+
+		if srcCellSeq != nil && srcCellSeq.pos < len {
+			c.cells[i].sequences = append(c.cells[i].sequences, seqCell{seq: dstSeq, pos: srcCellSeq.pos})
+			if i < seqRange.min {
+				seqRange.min = i
+			}
+			if i > seqRange.max {
+				seqRange.max = i
+			}
+		}
+	}
+
+	c.cellRanges[dstSeq] = seqRange
+	c.seqNextPos[dstSeq] = len
+}
+
+func (c *Causal) shift(seq int, beginIndex, endIndex, offset int) error {
+	panic("Shift not yet implemented")
+}
+
+func (c *Causal) Remove(seq int, beginIndex, endIndex int) error {
+	endIndex = min(endIndex, c.seqNextPos[seq])
+	offset := beginIndex - endIndex
+
+	seqRange := newRange()
+
+	for i := range c.cells {
+		cellSeq := c.cells[i].findSeq(seq)
+		if cellSeq != nil {
+			if cellSeq.pos >= beginIndex && cellSeq.pos < endIndex {
+				c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == seq })
+			} else {
+				if cellSeq.pos >= endIndex {
+					cellSeq.pos += offset
+				}
+				if i < seqRange.min {
+					seqRange.min = i
+				}
+				if i > seqRange.max {
+					seqRange.max = i
+				}
+			}
+		}
+	}
+
+	if endIndex != c.seqNextPos[seq] {
+		err := c.shift(seq, endIndex, c.seqNextPos[seq], offset)
+		if err != nil {
+			return err
+		}
+	}
+
+	c.cellRanges[seq] = seqRange
+	c.seqNextPos[seq] += offset
 
-	return key, value
+	return nil
 }

+ 48 - 0
cache/tensor.go

@@ -0,0 +1,48 @@
+package cache
+
+import (
+	"github.com/ollama/ollama/ml"
+)
+
+type TensorCache struct {
+	curLayer int
+
+	cacheCtx     ml.Context
+	keys, values []ml.Tensor
+}
+
+func NewTensorCache(backend ml.Backend) *TensorCache {
+	return &TensorCache{
+		// TODO(jessegross): This context is not sized appropriately
+		cacheCtx: backend.NewContext(),
+	}
+}
+
+func (c *TensorCache) Close() {
+	c.cacheCtx.Close()
+}
+
+func (c *TensorCache) Sub(i int) *TensorCache {
+	if i >= len(c.keys) {
+		c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
+		c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
+	}
+
+	c.curLayer = i
+
+	return c
+}
+
+func (c *TensorCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
+	return c.keys[c.curLayer], c.values[c.curLayer], nil
+}
+
+func (c *TensorCache) Put(ctx ml.Context, key, value ml.Tensor) {
+	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
+		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
+	}
+
+	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
+	ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
+}

+ 2 - 2
cmd/cmd.go

@@ -35,9 +35,9 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llama"
-	"github.com/ollama/ollama/llama/runner"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/runner"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -338,7 +338,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	opts.MultiModal = len(info.ProjectorInfo) != 0
+	opts.MultiModal = true //len(info.ProjectorInfo) != 0
 	opts.ParentModel = info.Details.ParentModel
 
 	if interactive {

+ 1 - 1
cmd/runner/main.go

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"os"
 
-	"github.com/ollama/ollama/llama/runner"
+	"github.com/ollama/ollama/runner"
 )
 
 func main() {

+ 3 - 0
envconfig/config.go

@@ -165,6 +165,8 @@ var (
 	IntelGPU = Bool("OLLAMA_INTEL_GPU")
 	// MultiUserCache optimizes prompt caching for multi-user scenarios
 	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
+	// Enable the new Ollama engine
+	NewRunners = Bool("OLLAMA_NEW_RUNNERS")
 )
 
 func String(s string) func() string {
@@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_MULTIUSER_CACHE":   {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
+		"OLLAMA_NEW_RUNNERS":       {"OLLAMA_NEW_RUNNERS", NewRunners(), "Enable the new Ollama engine"},
 
 		// Informational
 		"HTTP_PROXY":  {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

+ 3 - 0
llm/server.go

@@ -252,6 +252,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	}
 	finalParams := []string{"runner"}
+	if envconfig.NewRunners() {
+		finalParams = append(finalParams, "--new-runner")
+	}
 	finalParams = append(finalParams, params...)
 	finalParams = append(finalParams, "--port", strconv.Itoa(port))
 

+ 1 - 1
ml/backend.go

@@ -43,7 +43,7 @@ func NewBackend(f *os.File) (Backend, error) {
 }
 
 type Context interface {
-	Zeros(dtype DType, shape ...int) Tensor
+	Zeros(dtype DType, shape ...int64) Tensor
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
 	FromIntSlice(s []int32, shape ...int) (Tensor, error)
 

+ 17 - 9
ml/backend/ggml/ggml.go

@@ -23,7 +23,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"golang.org/x/sync/errgroup"
 
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 
 type device struct {
@@ -198,10 +198,9 @@ func (b *Backend) Get(name string) ml.Tensor {
 
 func (b *Backend) NewContext() ml.Context {
 	nodes := max(8192, len(b.meta.Tensors().Items())*5)
-	bts := make([]byte, C.size_t(nodes)*C.ggml_tensor_overhead()+C.ggml_graph_overhead_custom(C.size_t(nodes), false))
 	c := C.ggml_init(C.struct_ggml_init_params{
-		mem_buffer: unsafe.Pointer(&bts[0]),
-		mem_size:   C.size_t(len(bts)),
+		mem_buffer: nil,
+		mem_size:   C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false),
 		no_alloc:   true,
 	})
 
@@ -244,17 +243,19 @@ func (c *Context) Forward(t ml.Tensor) {
 }
 
 func (c *Context) Compute(t ml.Tensor) ml.Tensor {
-	c.Forward(t)
 	C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
 
-	backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
+	if t != nil && C.ggml_nbytes(t.(*Tensor).t) != 0 {
+		backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
+
+		t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
+		C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
 
-	t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
-	C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
 	return t
 }
 
-func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
+func (c Context) Zeros(dtype ml.DType, shape ...int64) ml.Tensor {
 	if len(shape) < 1 || len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -283,6 +284,13 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 
 func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
 	n := len(s)
+
+	if n == 0 {
+		shape := 0
+		t := C.ggml_new_tensor(ctx.ctx, dtype, 1, (*C.int64_t)(unsafe.Pointer(&shape)))
+		return &Tensor{t: t}, nil
+	}
+
 	for _, v := range shape {
 		n /= v
 	}

+ 0 - 160
model/cmd/main.go

@@ -1,160 +0,0 @@
-package main
-
-import (
-	"errors"
-	"flag"
-	"fmt"
-	"image"
-	"io"
-	"log/slog"
-	"os"
-	"path/filepath"
-	"strings"
-
-	"github.com/ollama/ollama/cache"
-	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model"
-	_ "github.com/ollama/ollama/model/llama"
-	_ "github.com/ollama/ollama/model/mllama"
-	"github.com/ollama/ollama/sample"
-)
-
-var args struct {
-	n     int
-	debug bool
-	image string
-	cache bool
-}
-
-func temp() error {
-	flag.IntVar(&args.n, "n", 10, "number of samples")
-	flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
-	flag.StringVar(&args.image, "image", "", "path to image file")
-	flag.BoolVar(&args.cache, "cache", false, "enable KV cache")
-
-	flag.Parse()
-
-	var prompt string
-	if n := len(flag.Args()); n == 1 {
-		bts, err := io.ReadAll(os.Stdin)
-		if err != nil {
-			return err
-		}
-
-		prompt = string(bts)
-	} else if n > 1 {
-		prompt = strings.Join(flag.Args()[1:], " ")
-	} else {
-		return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
-	}
-
-	level := slog.LevelInfo
-	if args.debug {
-		level = slog.LevelDebug
-	}
-
-	slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
-		Level:     level,
-		AddSource: true,
-		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
-			if attr.Key == slog.SourceKey {
-				source := attr.Value.Any().(*slog.Source)
-				source.File = filepath.Base(source.File)
-			}
-
-			return attr
-		},
-	})))
-
-	m, err := model.New(flag.Arg(0))
-	if err != nil {
-		return err
-	}
-
-	inputIDs, err := m.(model.TextProcessor).Encode(prompt)
-	if err != nil {
-		return err
-	}
-
-	var opts []model.OptionsFunc
-	if args.cache {
-		opts = append(opts, model.WithCache(&cache.Simple{
-			Capacity: 2048,
-			DType:    ml.DTypeF32,
-		}))
-	}
-
-	if args.image != "" {
-		if err := func() error {
-			f, err := os.Open(args.image)
-			if err != nil {
-				return err
-			}
-			defer f.Close()
-
-			img, _, err := image.Decode(f)
-			if err != nil {
-				return err
-			}
-
-			opts = append(opts, model.WithImage(img))
-			return nil
-		}(); err != nil {
-			return err
-		}
-	}
-
-	var offset int
-	for range args.n {
-		logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
-		if err != nil {
-			return err
-		}
-
-		f32s := logit.Floats()
-		f64s := make([]float64, len(f32s))
-		for i, f32 := range f32s {
-			f64s[i] = float64(f32)
-		}
-
-		// do sampling
-		f64s, err = sample.Sample(f64s, sample.Greedy())
-		if err != nil {
-			return err
-		}
-
-		var outputIDs []int32
-		for _, f64 := range f64s {
-			if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
-				outputIDs = append(outputIDs, int32(f64))
-			}
-		}
-
-		if len(outputIDs) == 0 {
-			break
-		}
-
-		s, err := m.(model.TextProcessor).Decode(outputIDs)
-		if errors.Is(err, io.EOF) {
-			break
-		} else if err != nil {
-			return err
-		}
-
-		fmt.Print(s)
-
-		inputIDs = append(inputIDs, outputIDs...)
-		if args.cache {
-			offset = len(inputIDs) - 1
-		}
-	}
-
-	return nil
-}
-
-func main() {
-	if err := temp(); err != nil {
-		fmt.Println("err", err)
-		os.Exit(1)
-	}
-}

+ 6 - 4
model/llama/model.go

@@ -3,6 +3,7 @@ package llama
 import (
 	"math"
 
+	"github.com/ollama/ollama/cache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
@@ -59,7 +60,7 @@ type SelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
 
@@ -74,7 +75,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	k, v = cache.Put(ctx, k, v, cache.Options)
+	k, v, mask := cache.Put(ctx, k, v)
 
 	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -82,6 +83,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	kq := k.Mulmat(ctx, q)
 	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+	kq = kq.Add(ctx, mask)
 	kq = kq.Softmax(ctx)
 
 	kqv := v.Mulmat(ctx, kq)
@@ -109,7 +111,7 @@ type Layer struct {
 	MLP           *MLP
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache cache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -142,7 +144,7 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	hiddenState = m.Output.Forward(ctx, hiddenState)
 
-	outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
+	outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
 	if err != nil {
 		return nil, err
 	}

+ 12 - 2
model/mllama/model.go

@@ -1,6 +1,9 @@
 package mllama
 
 import (
+	"sync"
+
+	"github.com/ollama/ollama/cache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
@@ -16,6 +19,9 @@ type Model struct {
 
 	ImageProcessor
 	TextProcessor
+
+	start  sync.Once
+	tCache *cache.TensorCache
 }
 
 func New(c ml.Config) (model.Model, error) {
@@ -28,6 +34,10 @@ func New(c ml.Config) (model.Model, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+	m.start.Do(func() {
+		m.tCache = cache.NewTensorCache(m.Backend())
+	})
+
 	var crossAttentionStates ml.Tensor
 	if opts.Images != nil {
 		f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
@@ -75,9 +85,9 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	}
 
 	// TODO: attention mask, cross attention mask
-	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
+	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache, m.tCache)
 
-	outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
+	outputs, err := ctx.FromIntSlice(opts.Outputs(), len(opts.Outputs()))
 	if err != nil {
 		return nil, err
 	}

+ 40 - 25
model/mllama/model_text.go

@@ -4,9 +4,9 @@ import (
 	"math"
 	"slices"
 
+	"github.com/ollama/ollama/cache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/model"
 )
 
 type TextSelfAttention struct {
@@ -16,7 +16,7 @@ type TextSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache cache.Cache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
 
@@ -31,7 +31,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	key, value = cache.Put(ctx, key, value, cache.Options)
+	key, value, mask := cache.Put(ctx, key, value)
 
 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -39,11 +39,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 
 	scores := key.Mulmat(ctx, query)
 	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-
-	if mask != nil {
-		scores = scores.Add(ctx, mask)
-	}
-
+	scores = scores.Add(ctx, mask)
 	scores = scores.Softmax(ctx)
 
 	attention := value.Mulmat(ctx, scores)
@@ -72,7 +68,7 @@ type TextSelfAttentionDecoderLayer struct {
 	MLP     *TextMLP
 }
 
-func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache cache.Cache, _ *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -85,6 +81,10 @@ func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, pos
 	return hiddenState.Add(ctx, residual)
 }
 
+func (d *TextSelfAttentionDecoderLayer) Run() bool {
+	return true
+}
+
 type TextCrossAttention struct {
 	QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
 	Query     *nn.Linear  `gguf:"cross_attn_q_proj"`
@@ -94,23 +94,29 @@ type TextCrossAttention struct {
 	Output    *nn.Linear  `gguf:"cross_attn_o_proj"`
 }
 
-func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, _ cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
 	query := ca.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	key := ca.Key.Forward(ctx, crossAttentionStates)
-	key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
-	key = ca.KeyNorm.Forward(ctx, key, opts.eps)
+	var key, value ml.Tensor
+	if crossAttentionStates != nil {
+		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
-	value := ca.Value.Forward(ctx, crossAttentionStates)
-	value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+		key = ca.Key.Forward(ctx, crossAttentionStates)
+		key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+		key = ca.KeyNorm.Forward(ctx, key, opts.eps)
 
-	// TODO cache key, value
+		value = ca.Value.Forward(ctx, crossAttentionStates)
+		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
+
+		tCache.Put(ctx, key, value)
+	} else {
+		key, value, _ = tCache.Get(ctx)
+	}
 
 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
@@ -135,13 +141,17 @@ type TextCrossAttentionDecoderLayer struct {
 	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
 	MLP     *TextMLP
 	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
+
+	run bool
 }
 
-func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
+	d.run = true
+
 	residual := hiddenState
 
 	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
+	hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, tCache, opts)
 	hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
@@ -152,18 +162,23 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
 	return hiddenState.Add(ctx, residual)
 }
 
+func (d *TextCrossAttentionDecoderLayer) Run() bool {
+	return d.run
+}
+
 type TextDecoderLayer interface {
-	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
+	Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor
+	Run() bool
 }
 
 type TextDecoder struct {
 	Layers []TextDecoderLayer
 }
 
-func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
+func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache, opts *TextModelOptions) ml.Tensor {
 	for i, layer := range d.Layers {
-		if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
-			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
+		if layer.Run() || crossAttentionStates != nil {
+			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), tCache.Sub(i), opts)
 		}
 	}
 
@@ -189,9 +204,9 @@ type TextModel struct {
 	*TextModelOptions
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache cache.Cache, tCache *cache.TensorCache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
-	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
+	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, tCache, m.TextModelOptions)
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
 	return m.Output.Forward(ctx, hiddenState)
 }

+ 32 - 45
model/model.go

@@ -20,51 +20,28 @@ import (
 	_ "github.com/ollama/ollama/ml/backend"
 )
 
-type Cache struct {
-	cache.Cache
-	cache.Options
-}
-
-func (c Cache) Sub(i int) Cache {
-	if c.Cache != nil {
-		return Cache{
-			Cache:   c.Cache.Sub(i),
-			Options: c.Options,
-		}
-	}
-
-	return c
-}
-
-func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
-	if c.Cache != nil {
-		return c.Cache.Put(ctx, key, value, opts)
-	}
-
-	return key, value
-}
-
 type Options struct {
-	inputs []int32
+	inputs    []int32
+	positions []int32
+	outputs   []int32
 
-	Offset int
+	sequences []int
 
 	Images []image.Image
 
-	Cache
+	cache.Cache
 }
 
 func (opts Options) Inputs() []int32 {
-	return opts.inputs[opts.Offset:]
+	return opts.inputs
 }
 
 func (opts Options) Positions() []int32 {
-	positions := make([]int32, len(opts.inputs)-opts.Offset)
-	for i := range positions {
-		positions[i] = int32(opts.Offset + i)
-	}
+	return opts.positions
+}
 
-	return positions
+func (opts Options) Outputs() []int32 {
+	return opts.outputs
 }
 
 type OptionsFunc func(Model, *Options)
@@ -75,10 +52,21 @@ func WithInputIDs(ids []int32) OptionsFunc {
 	}
 }
 
-func WithOffset(offset int) OptionsFunc {
+func WithPositions(pos []int32) OptionsFunc {
+	return func(m Model, opts *Options) {
+		opts.positions = pos
+	}
+}
+
+func WithOutputs(outputs []int32) OptionsFunc {
+	return func(m Model, opts *Options) {
+		opts.outputs = outputs
+	}
+}
+
+func WithSequences(seqs []int) OptionsFunc {
 	return func(m Model, opts *Options) {
-		opts.Offset = offset
-		opts.Cache.Position = offset
+		opts.sequences = seqs
 	}
 }
 
@@ -90,12 +78,7 @@ func WithImage(img image.Image) OptionsFunc {
 
 func WithCache(c cache.Cache) OptionsFunc {
 	return func(m Model, opts *Options) {
-		opts.Cache = Cache{
-			Cache: c,
-			Options: cache.Options{
-				Position: opts.Offset,
-			},
-		}
+		opts.Cache = c
 	}
 }
 
@@ -272,18 +255,22 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
 	var opts Options
 	for _, optsFunc := range optsFuncs {
 		optsFunc(m, &opts)
 	}
 
-	ctx := m.Backend().NewContext()
+	err := opts.Cache.StartForward(ctx, opts.sequences)
+	if err != nil {
+		return nil, err
+	}
+
 	t, err := m.Forward(ctx, opts)
 	if err != nil {
 		return nil, err
 	}
-	defer ctx.Close()
 
+	ctx.Forward(t)
 	return ctx.Compute(t), nil
 }

+ 0 - 0
llama/runner/README.md → runner/README.md


+ 5 - 5
llama/runner/stop.go → runner/common/stop.go

@@ -1,10 +1,10 @@
-package runner
+package common
 
 import (
 	"strings"
 )
 
-func findStop(sequence string, stops []string) (bool, string) {
+func FindStop(sequence string, stops []string) (bool, string) {
 	for _, stop := range stops {
 		if strings.Contains(sequence, stop) {
 			return true, stop
@@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) {
 	return false, ""
 }
 
-func containsStopSuffix(sequence string, stops []string) bool {
+func ContainsStopSuffix(sequence string, stops []string) bool {
 	for _, stop := range stops {
 		for i := 1; i <= len(stop); i++ {
 			if strings.HasSuffix(sequence, stop[:i]) {
@@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []string, stop string) ([]string, bool) {
+func TruncateStop(pieces []string, stop string) ([]string, bool) {
 	joined := strings.Join(pieces, "")
 
 	index := strings.Index(joined, stop)
@@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) {
 	return result, tokenTruncated
 }
 
-func incompleteUnicode(token string) bool {
+func IncompleteUnicode(token string) bool {
 	incomplete := false
 
 	// check if there is incomplete UTF-8 character at the end

+ 3 - 3
llama/runner/stop_test.go → runner/common/stop_test.go

@@ -1,4 +1,4 @@
-package runner
+package common
 
 import (
 	"reflect"
@@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			result, resultTrunc := truncateStop(tt.pieces, tt.stop)
+			result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
 			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
 				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
 			}
@@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			result := incompleteUnicode(tt.input)
+			result := IncompleteUnicode(tt.input)
 			if result != tt.expected {
 				t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
 			}

+ 248 - 0
runner/newrunner/cache.go

@@ -0,0 +1,248 @@
+package newrunner
+
+import (
+	"errors"
+	"fmt"
+	"log/slog"
+	"math"
+	"reflect"
+	"time"
+
+	"github.com/ollama/ollama/cache"
+	"github.com/ollama/ollama/ml"
+)
+
+type InputCache struct {
+	// context window size (per slot)
+	numCtx int
+
+	// individual KV caches
+	slots []InputCacheSlot
+
+	// optimize cache eviction for multiple users
+	multiUserCache bool
+
+	cache cache.Cache
+}
+
+func NewInputCache(backend ml.Backend, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
+	if kvSize/numSlots < 1 {
+		return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
+	}
+
+	slots := make([]InputCacheSlot, numSlots)
+
+	for i := range slots {
+		slots[i] = InputCacheSlot{
+			Id:     i,
+			Inputs: make([]input, 0),
+		}
+	}
+
+	return &InputCache{
+		numCtx:         kvSize / numSlots,
+		slots:          slots,
+		multiUserCache: multiUserCache,
+		cache:          cache.NewCausalCache(backend, kvSize, ml.DTypeF32),
+	}, nil
+}
+
+// Locking: Operations on InputCacheSlot (including finding one
+// through LoadCacheSlot) require a lock to be be held that serializes
+// these operations with each other and llama.Decode
+
+type InputCacheSlot struct {
+	// Index in the KV cache
+	Id int
+
+	// Inputs that are stored in the KV cache
+	Inputs []input
+
+	// is this cache actively being processed as part of a sequence?
+	InUse bool
+
+	// last time this cache was used (as of start of processing)
+	lastUsed time.Time
+}
+
+func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
+	var slot *InputCacheSlot
+	var numPast int
+	var err error
+
+	// In single-user scenarios, the longest cache slot works fine for getting good input
+	// cache hit rates and it reuses the same VRAM over and over again, which is good for
+	// GPU performance in situations where we miss the input cache.
+	// For multiple users, the "best" cache slot produces better input cache hit rates
+	// at the cost of worse performance when we miss the input cache (because it causes
+	// GPU L2 cache misses due to spreading out accesses across VRAM).
+	if !c.multiUserCache {
+		slot, numPast, err = c.findLongestCacheSlot(prompt)
+	} else {
+		slot, numPast, err = c.findBestCacheSlot(prompt)
+	}
+	if err != nil {
+		return nil, nil, err
+	}
+
+	if !cachePrompt {
+		numPast = 0
+	}
+
+	slot.InUse = true
+	slot.lastUsed = time.Now()
+
+	if numPast == len(prompt) {
+		// Leave one input to sample so we can get a response
+		numPast--
+	}
+
+	err = c.cache.Remove(slot.Id, numPast, math.MaxInt)
+	if err != nil {
+		// Some models don't support partial erasure
+		c.cache.Remove(slot.Id, 0, math.MaxInt)
+		numPast = 0
+	}
+
+	slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
+		"used", numPast, "remaining", len(prompt)-numPast)
+
+	prompt = prompt[numPast:]
+	slot.Inputs = slot.Inputs[:numPast]
+
+	return slot, prompt, nil
+}
+
+func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
+	longest := -1
+	var longestSlot *InputCacheSlot
+
+	for i, s := range c.slots {
+		if s.InUse {
+			continue
+		}
+
+		count := countCommonPrefix(s.Inputs, prompt)
+		if count > longest {
+			longest = count
+			longestSlot = &c.slots[i]
+		}
+	}
+
+	if longestSlot == nil {
+		return nil, 0, errors.New("no available cache slots")
+	}
+
+	return longestSlot, longest, nil
+}
+
+func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
+	oldest := time.Now()
+	var oldestSlot *InputCacheSlot
+
+	longest := -1
+	var longestSlot *InputCacheSlot
+
+	for i, s := range c.slots {
+		count := countCommonPrefix(s.Inputs, prompt)
+		if count > longest {
+			longest = count
+			longestSlot = &c.slots[i]
+		}
+
+		if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
+			oldest = s.lastUsed
+			oldestSlot = &c.slots[i]
+		}
+	}
+
+	if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
+		return longestSlot, longest, nil
+	}
+
+	if oldestSlot.InUse {
+		return nil, 0, errors.New("no available cache slots")
+	}
+
+	if len(oldestSlot.Inputs) != 0 {
+		slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
+			"used", oldestSlot.lastUsed)
+	}
+
+	if longest > 0 && longestSlot != oldestSlot {
+		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
+			len(longestSlot.Inputs))
+		oldestSlot.Inputs = make([]input, longest)
+		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
+		// This is only nil for unit tests
+		if c.cache != nil {
+			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
+		}
+	}
+
+	return oldestSlot, longest, nil
+}
+
+func countCommonPrefix(a []input, b []input) int {
+	var count int
+
+	for i := range a {
+		if i >= len(b) {
+			break
+		}
+
+		if !reflect.DeepEqual(a[i], b[i]) {
+			break
+		}
+
+		count++
+	}
+
+	return count
+}
+
+func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
+	targetFree := (c.numCtx - numKeep) / 2
+	targetFree = max(targetFree, 1)
+
+	currentFree := c.numCtx - inputLen
+	discard := targetFree - currentFree
+
+	if discard < 0 {
+		discard = 0
+	}
+
+	return discard
+}
+
+// Frees up space in the KV cache by deleting the oldest half of history and shifting
+// the newest half into that space (saving numKeep inputs at the beginning).
+//
+// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
+func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
+	if numKeep >= c.numCtx {
+		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
+	}
+
+	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
+
+	if discard <= 0 {
+		return nil
+	}
+
+	slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
+		"keep", numKeep, "discard", discard)
+
+	// TODO (jessegross): KV cache removal can fail for certain types of models
+	err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
+	if err != nil {
+		return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
+	}
+
+	for i := numKeep + discard; i < len(slot.Inputs); i++ {
+		slot.Inputs[i-discard] = slot.Inputs[i]
+	}
+	slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
+
+	return nil
+}

+ 1 - 1
llama/runner/cache_test.go → runner/newrunner/cache_test.go

@@ -1,4 +1,4 @@
-package runner
+package newrunner
 
 import (
 	"testing"

+ 971 - 0
runner/newrunner/runner.go

@@ -0,0 +1,971 @@
+package newrunner
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"flag"
+	"fmt"
+	"image"
+	"io"
+	"log"
+	"log/slog"
+	"net"
+	"net/http"
+	"os"
+	"path/filepath"
+	"regexp"
+	"runtime"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+	"unicode/utf8"
+
+	"golang.org/x/sync/semaphore"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/runner/common"
+	"github.com/ollama/ollama/sample"
+
+	_ "github.com/ollama/ollama/model/llama"
+	_ "github.com/ollama/ollama/model/mllama"
+)
+
+// input is an element of the prompt to process, either
+// a token or an image embedding (generated from a vision projector)
+type input struct {
+	token int32
+
+	// embed is an image embedding
+	//embed []float32
+
+	image image.Image
+}
+
+type Sequence struct {
+	// batch index
+	iBatch int
+
+	// prompt inputs left to evaluate
+	inputs []input
+
+	// inputs that have been added to a batch but not yet submitted to Decode
+	pendingInputs []input
+
+	// tokens that have been generated but not returned yet (e.g. for stop sequences)
+	pendingResponses []string
+
+	// input cache being used by this sequence
+	cache *InputCacheSlot
+
+	// channel to send responses over
+	responses chan string
+
+	// channel to stop decoding (such as if the remote connection is closed)
+	quit chan bool
+
+	// number of tokens to predict
+	numPredict int
+
+	// set of samplers to run on generated logits
+	samplers []sample.Sampler
+
+	// channel to send back the embedding if embedding only
+	embedding chan []float32
+
+	// stop sequences
+	stop []string
+
+	// number of inputs to keep at the beginning when shifting context window
+	numKeep int
+
+	// true if an embedding are to be returned instead of text generation
+	embeddingOnly bool
+
+	doneReason string
+
+	// Metrics
+	startProcessingTime time.Time
+	startGenerationTime time.Time
+	numPredicted        int
+	numPromptInputs     int
+}
+
+type NewSequenceParams struct {
+	numPredict int
+	stop       []string
+	numKeep    int
+	samplers   []sample.Sampler
+	embedding  bool
+}
+
+func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
+	s.ready.Wait()
+
+	startTime := time.Now()
+
+	inputs, err := s.inputs(prompt, images)
+	if err != nil {
+		return nil, fmt.Errorf("failed to process inputs: %w", err)
+	} else if len(inputs) == 0 {
+		return nil, errors.New("no input provided")
+	}
+
+	if params.numKeep < 0 {
+		params.numKeep = len(inputs)
+	}
+
+	// Ensure that at least 1 input can be discarded during shift
+	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
+
+	if len(inputs) > s.cache.numCtx {
+		discard := len(inputs) - s.cache.numCtx
+		newInputs := inputs[:params.numKeep]
+		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
+
+		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
+		inputs = newInputs
+	}
+
+	// TODO(jessegross): Ingest cached history for grammar
+
+	return &Sequence{
+		inputs:              inputs,
+		numPromptInputs:     len(inputs),
+		startProcessingTime: startTime,
+		numPredict:          params.numPredict,
+		pendingResponses:    make([]string, 0),
+		responses:           make(chan string, 100),
+		quit:                make(chan bool, 1),
+		embedding:           make(chan []float32, 1),
+		samplers:            params.samplers,
+		embeddingOnly:       params.embedding,
+		stop:                params.stop,
+		numKeep:             params.numKeep,
+	}, nil
+}
+
+// inputs processes the prompt and images into a list of inputs
+// by splitting the prompt on [img-<n>] tags, tokenizing text and
+// generating image embeddings for each image
+func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
+	var inputs []input
+	var parts []string
+	var matches [][]string
+
+	//if s.image != nil {
+	re := regexp.MustCompile(`\[img-(\d+)\]`)
+	parts = re.Split(prompt, -1)
+	matches = re.FindAllStringSubmatch(prompt, -1)
+	/*} else {
+		parts = []string{prompt}
+	}*/
+
+	for i, part := range parts {
+		// text - tokenize
+		tokens, err := s.model.(model.TextProcessor).Encode(part)
+		if err != nil {
+			return nil, err
+		}
+
+		for _, t := range tokens {
+			inputs = append(inputs, input{token: t})
+		}
+
+		// image - generate image embedding
+		if i < len(matches) {
+			n, _ := strconv.Atoi(matches[i][1])
+
+			imageIndex := -1
+			for j := range images {
+				if images[j].ID == n {
+					imageIndex = j
+					break
+				}
+			}
+
+			if imageIndex < 0 {
+				return nil, fmt.Errorf("invalid image index: %d", n)
+			}
+
+			image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
+			if err != nil {
+				return nil, err
+			}
+
+			inputs = append(inputs, input{image: image})
+
+			/*embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
+			if err != nil {
+				return nil, err
+			}
+
+			for _, e := range embed {
+				inputs = append(inputs, input{embed: e})
+			}*/
+		}
+	}
+
+	return inputs, nil
+}
+
+type Server struct {
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
+
+	// loaded model
+	model model.Model
+
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
+	// TODO (jmorganca): make this n_batch
+	batchSize int
+
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
+
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// the list of simultaneous sequences being evaluated
+	seqs []*Sequence
+
+	// seqs can have a maximum of parallel entries, which
+	// is enfoced by seqSem
+	seqsSem *semaphore.Weighted
+
+	// KV cache
+	cache *InputCache
+
+	// next sequence for prompt processing to avoid starvation
+	// TODO(jessegross): Currently unused
+	nextSeq int
+}
+
+func (s *Server) allNil() bool {
+	for _, item := range s.seqs {
+		if item != nil {
+			return false
+		}
+	}
+	return true
+}
+
+func flushPending(seq *Sequence) bool {
+	joined := strings.Join(seq.pendingResponses, "")
+	seq.pendingResponses = []string{}
+
+	// Check if there are any partial UTF-8 characters remaining.
+	// We already check and queue as we are generating but some may
+	// still make it here:
+	// - Sequence is ending, e.g. generation limit has been hit
+	// - Invalid characters in the middle of a string
+	// This is a stricter check to ensure we never output invalid Unicode.
+	for !utf8.ValidString(joined) {
+		joined = joined[:len(joined)-1]
+	}
+
+	if len(joined) == 0 {
+		return true
+	}
+
+	select {
+	case seq.responses <- joined:
+		return true
+	case <-seq.quit:
+		return false
+	}
+}
+
+func (s *Server) removeSequence(seqIndex int, reason string) {
+	seq := s.seqs[seqIndex]
+
+	flushPending(seq)
+	seq.doneReason = reason
+	close(seq.responses)
+	close(seq.embedding)
+	seq.cache.InUse = false
+	s.seqs[seqIndex] = nil
+	s.seqsSem.Release(1)
+}
+
+func (s *Server) run(ctx context.Context) {
+	s.ready.Wait()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+			err := s.processBatch()
+			if err != nil {
+				panic(err)
+			}
+		}
+	}
+}
+
+func (s *Server) processBatch() error {
+	s.mu.Lock()
+	for s.allNil() {
+		s.cond.Wait() // Wait until an item is added
+	}
+	defer s.mu.Unlock()
+
+	var inputIDs []int32
+	var pos []int32
+	var outputs []int32
+	var seqs []int
+
+	var image image.Image
+
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
+
+		// if past the num predict limit
+		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
+			s.removeSequence(i, "limit")
+			continue
+		}
+
+		for j, input := range seq.inputs {
+			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+				if len(seq.pendingInputs) == 0 {
+					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+					if err != nil {
+						return err
+					}
+				} else {
+					break
+				}
+			}
+
+			if j >= s.batchSize {
+				break
+			}
+
+			if input.image != nil {
+				if image != nil {
+					break
+				}
+				image = input.image
+				seq.pendingInputs = append(seq.pendingInputs, input)
+				continue
+			}
+
+			inputIDs = append(inputIDs, input.token)
+			pos = append(pos, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
+			seqs = append(seqs, seq.cache.Id)
+
+			seq.iBatch = len(outputs)
+			if j+1 == len(seq.inputs) {
+				outputs = append(outputs, int32(len(inputIDs)-1))
+			}
+			seq.pendingInputs = append(seq.pendingInputs, input)
+		}
+
+		seq.inputs = seq.inputs[len(seq.pendingInputs):]
+	}
+
+	if len(inputIDs) == 0 {
+		return nil
+	}
+
+	var options []model.OptionsFunc
+	if image != nil {
+		options = append(options, model.WithImage(image))
+	}
+
+	ctx := s.model.Backend().NewContext()
+	defer ctx.Close()
+
+	logit, err := model.Forward(ctx, s.model, append(options, model.WithCache(s.cache.cache), model.WithInputIDs(inputIDs), model.WithPositions(pos), model.WithOutputs(outputs), model.WithSequences(seqs))...)
+	if err != nil {
+		return err
+	}
+
+	f32s := logit.Floats()
+
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
+
+		// After calling Forward, pending inputs are now in the cache
+		if len(seq.pendingInputs) > 0 {
+			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
+			seq.pendingInputs = []input{}
+		}
+
+		// don't sample prompt processing
+		if len(seq.inputs) != 0 {
+			continue
+		}
+
+		seq.numPredicted++
+		if seq.numPredicted == 1 {
+			seq.startGenerationTime = time.Now()
+		}
+
+		// if done processing the prompt, generate an embedding and return
+		if seq.embeddingOnly {
+			/*embed := s.lc.GetEmbeddingsSeq(seq.cache.Id)
+			if embed == nil {
+				embed = s.lc.GetEmbeddingsIth(seq.iBatch)
+			}
+
+			seq.embedding <- embed*/
+			s.removeSequence(i, "")
+			continue
+		}
+
+		vocabSize := len(f32s) / len(outputs)
+		seqLogits := f32s[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]
+
+		// TODO(jessegross): The data type and number of outputs for the samplers seem inconsistent
+		f64s := make([]float64, vocabSize)
+		for j, f32 := range seqLogits {
+			f64s[j] = float64(f32)
+		}
+
+		// do sampling
+		f64s, err = sample.Sample(f64s, seq.samplers...)
+		if err != nil {
+			return err
+		}
+
+		var outputIDs []int32
+		for _, f64 := range f64s {
+			if !s.model.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
+				outputIDs = append(outputIDs, int32(f64))
+			} else {
+				s.removeSequence(i, "stop")
+				continue
+			}
+		}
+
+		if len(outputIDs) == 0 {
+			continue
+		}
+
+		piece, err := s.model.(model.TextProcessor).Decode(outputIDs)
+		if errors.Is(err, io.EOF) {
+			continue
+		} else if err != nil {
+			return err
+		}
+
+		for _, id := range outputIDs {
+			seq.inputs = append(seq.inputs, input{token: id})
+		}
+
+		seq.pendingResponses = append(seq.pendingResponses, piece)
+		sequence := strings.Join(seq.pendingResponses, "")
+
+		if ok, stop := common.FindStop(sequence, seq.stop); ok {
+			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
+
+			var tokenTruncated bool
+			origLen := len(seq.pendingResponses)
+			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
+			newLen := len(seq.pendingResponses)
+
+			// Update the cache based on the tokens that will be returned:
+			// - We have more tokens than are currently in the cache because
+			// the last ones generated weren't submitted to Forward
+			// - Remove any stop sequences that we stripped out
+			// - If truncateStop removed a portion of a token, drop that
+			// - As defense-in-depth, if truncatedToken didn't find a stop token
+			// remove the extra ones that we added to the cache len
+			tokenLen := len(seq.cache.Inputs) + len(outputIDs)
+			tokenLen -= origLen - newLen
+			if tokenTruncated {
+				tokenLen--
+			}
+			if origLen == newLen {
+				tokenLen = len(seq.cache.Inputs)
+			}
+			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
+
+			s.removeSequence(i, "stop")
+			continue
+		}
+
+		if common.ContainsStopSuffix(sequence, seq.stop) {
+			continue
+		}
+
+		if common.IncompleteUnicode(sequence) {
+			continue
+		}
+
+		if !flushPending(seq) {
+			s.removeSequence(i, "connection")
+		}
+	}
+
+	return nil
+}
+
+// TODO (jmorganca): use structs from the api package to avoid duplication
+// this way the api acts as a proxy instead of using a different api for the
+// runner
+type Options struct {
+	api.Runner
+
+	NumKeep          int      `json:"n_keep"`
+	Seed             int      `json:"seed"`
+	NumPredict       int      `json:"n_predict"`
+	TopK             int      `json:"top_k"`
+	TopP             float32  `json:"top_p"`
+	MinP             float32  `json:"min_p"`
+	TypicalP         float32  `json:"typical_p"`
+	RepeatLastN      int      `json:"repeat_last_n"`
+	Temperature      float32  `json:"temperature"`
+	RepeatPenalty    float32  `json:"repeat_penalty"`
+	PresencePenalty  float32  `json:"presence_penalty"`
+	FrequencyPenalty float32  `json:"frequency_penalty"`
+	Mirostat         int      `json:"mirostat"`
+	MirostatTau      float32  `json:"mirostat_tau"`
+	MirostatEta      float32  `json:"mirostat_eta"`
+	Stop             []string `json:"stop"`
+}
+
+type ImageData struct {
+	Data          []byte `json:"data"`
+	ID            int    `json:"id"`
+	AspectRatioID int    `json:"aspect_ratio_id"`
+}
+
+type CompletionRequest struct {
+	Prompt      string      `json:"prompt"`
+	Images      []ImageData `json:"image_data"`
+	Grammar     string      `json:"grammar"`
+	CachePrompt bool        `json:"cache_prompt"`
+
+	Options
+}
+
+type Timings struct {
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
+}
+
+type CompletionResponse struct {
+	Content string `json:"content"`
+	Stop    bool   `json:"stop"`
+
+	Model        string  `json:"model,omitempty"`
+	Prompt       string  `json:"prompt,omitempty"`
+	StoppedLimit bool    `json:"stopped_limit,omitempty"`
+	PredictedN   int     `json:"predicted_n,omitempty"`
+	PredictedMS  float64 `json:"predicted_ms,omitempty"`
+	PromptN      int     `json:"prompt_n,omitempty"`
+	PromptMS     float64 `json:"prompt_ms,omitempty"`
+
+	Timings Timings `json:"timings"`
+}
+
+func getSamplers(req CompletionRequest) []sample.Sampler {
+	/*var samplingParams llama.SamplingParams
+	samplingParams.TopK = req.TopK
+	samplingParams.TopP = req.TopP
+	samplingParams.MinP = req.MinP
+	samplingParams.TypicalP = req.TypicalP
+	samplingParams.Temp = req.Temperature
+	samplingParams.RepeatLastN = req.RepeatLastN
+	samplingParams.PenaltyRepeat = req.RepeatPenalty
+	samplingParams.PenaltyFreq = req.FrequencyPenalty
+	samplingParams.PenaltyPresent = req.PresencePenalty
+	samplingParams.Mirostat = req.Mirostat
+	samplingParams.MirostatTau = req.MirostatTau
+	samplingParams.MirostatEta = req.MirostatEta
+	samplingParams.Seed = uint32(req.Seed)
+	samplingParams.Grammar = req.Grammar*/
+
+	return []sample.Sampler{sample.Greedy()}
+}
+
+func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
+	var req CompletionRequest
+	req.Options = Options(api.DefaultOptions())
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "Bad request", http.StatusBadRequest)
+		return
+	}
+
+	// Set the headers to indicate streaming
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Transfer-Encoding", "chunked")
+
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+		return
+	}
+
+	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
+		numPredict: req.NumPredict,
+		stop:       req.Stop,
+		numKeep:    req.NumKeep,
+		samplers:   getSamplers(req),
+		embedding:  false,
+	})
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	// Ensure there is a place to put the sequence, released when removed from s.seqs
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		if errors.Is(err, context.Canceled) {
+			slog.Info("aborting completion request due to client closing the connection")
+		} else {
+			slog.Error("Failed to acquire semaphore", "error", err)
+		}
+		return
+	}
+
+	s.mu.Lock()
+	found := false
+	for i, sq := range s.seqs {
+		if sq == nil {
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			if err != nil {
+				s.mu.Unlock()
+				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
+				return
+			}
+
+			s.seqs[i] = seq
+			s.cond.Signal()
+			found = true
+			break
+		}
+	}
+	s.mu.Unlock()
+
+	if !found {
+		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
+		return
+	}
+
+	for {
+		select {
+		case <-r.Context().Done():
+			close(seq.quit)
+			return
+		case content, ok := <-seq.responses:
+			if ok {
+				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+					Content: content,
+				}); err != nil {
+					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+					close(seq.quit)
+					return
+				}
+
+				flusher.Flush()
+			} else {
+				// Send the final response
+				if err := json.NewEncoder(w).Encode(&CompletionResponse{
+					Stop:         true,
+					StoppedLimit: seq.doneReason == "limit",
+					Timings: Timings{
+						PromptN:     seq.numPromptInputs,
+						PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
+						PredictedN:  seq.numPredicted,
+						PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
+					},
+				}); err != nil {
+					http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
+				}
+
+				return
+			}
+		}
+	}
+}
+
+type EmbeddingRequest struct {
+	Content     string `json:"content"`
+	CachePrompt bool   `json:"cache_prompt"`
+}
+
+type EmbeddingResponse struct {
+	Embedding []float32 `json:"embedding"`
+}
+
+func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
+	var req EmbeddingRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
+		return
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+
+	slog.Debug("embedding request", "content", req.Content)
+
+	seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
+		return
+	}
+
+	// Ensure there is a place to put the sequence, released when removed from s.seqs
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		if errors.Is(err, context.Canceled) {
+			slog.Info("aborting embeddings request due to client closing the connection")
+		} else {
+			slog.Error("Failed to acquire semaphore", "error", err)
+		}
+		return
+	}
+
+	s.mu.Lock()
+	found := false
+	for i, sq := range s.seqs {
+		if sq == nil {
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			if err != nil {
+				s.mu.Unlock()
+				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
+				return
+			}
+			s.seqs[i] = seq
+			s.cond.Signal()
+			found = true
+			break
+		}
+	}
+	s.mu.Unlock()
+
+	if !found {
+		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
+		return
+	}
+
+	embedding := <-seq.embedding
+
+	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
+		Embedding: embedding,
+	}); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
+type HealthResponse struct {
+	Status   string  `json:"status"`
+	Progress float32 `json:"progress"`
+}
+
+type ServerStatus int
+
+const (
+	ServerStatusReady ServerStatus = iota
+	ServerStatusLoadingModel
+	ServerStatusError
+)
+
+func (s ServerStatus) ToString() string {
+	switch s {
+	case ServerStatusReady:
+		return "ok"
+	case ServerStatusLoadingModel:
+		return "loading model"
+	default:
+		return "server error"
+	}
+}
+
+func (s *Server) health(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Content-Type", "application/json")
+	if err := json.NewEncoder(w).Encode(&HealthResponse{
+		Status:   s.status.ToString(),
+		Progress: s.progress,
+	}); err != nil {
+		http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
+	}
+}
+
+type multiLPath []string
+
+func (m *multiLPath) Set(value string) error {
+	*m = append(*m, value)
+	return nil
+}
+
+func (m *multiLPath) String() string {
+	return strings.Join(*m, ", ")
+}
+
+func (s *Server) loadModel(
+	//params llama.ModelParams,
+	mpath string,
+	//lpath multiLPath,
+	kvSize int,
+	/*kvCacheType string,
+	flashAttention bool,*/
+	_ int,
+	multiUserCache bool,
+) {
+	var err error
+	s.model, err = model.New(mpath)
+	if err != nil {
+		panic(err)
+	}
+
+	/*	ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
+		s.lc, err = llama.NewContextWithModel(s.oldModel, ctxParams)
+		if err != nil {
+			panic(err)
+		}
+
+		if lpath.String() != "" {
+			for _, path := range lpath {
+				err := s.oldModel.ApplyLoraFromFile(s.lc, path, 1.0, threads)
+				if err != nil {
+					panic(err)
+				}
+			}
+		}*/
+
+	s.cache, err = NewInputCache(s.model.Backend(), kvSize, s.parallel, multiUserCache)
+	if err != nil {
+		panic(err)
+	}
+
+	s.status = ServerStatusReady
+	s.ready.Done()
+}
+
+func Execute(args []string) error {
+	fs := flag.NewFlagSet("runner", flag.ExitOnError)
+	mpath := fs.String("model", "", "Path to model binary file")
+	parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
+	batchSize := fs.Int("batch-size", 512, "Batch size")
+	_ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
+	_ = fs.Int("main-gpu", 0, "Main GPU")
+	_ = fs.Bool("flash-attn", false, "Enable flash attention")
+	kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
+	_ = fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
+	port := fs.Int("port", 8080, "Port to expose the server on")
+	threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
+	verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
+	_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
+	_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
+	tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
+	multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
+
+	var lpaths multiLPath
+	fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
+
+	fs.Usage = func() {
+		fmt.Fprintf(fs.Output(), "Runner usage\n")
+		fs.PrintDefaults()
+	}
+	if err := fs.Parse(args); err != nil {
+		return err
+	}
+	level := slog.LevelInfo
+	if *verbose {
+		level = slog.LevelDebug
+	}
+	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+		Level:     level,
+		AddSource: true,
+		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
+			if attr.Key == slog.SourceKey {
+				source := attr.Value.Any().(*slog.Source)
+				source.File = filepath.Base(source.File)
+			}
+			return attr
+		},
+	})
+	slog.SetDefault(slog.New(handler))
+	slog.Info("starting ollama engine")
+	//slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
+
+	server := &Server{
+		batchSize: *batchSize,
+		parallel:  *parallel,
+		seqs:      make([]*Sequence, *parallel),
+		seqsSem:   semaphore.NewWeighted(int64(*parallel)),
+		status:    ServerStatusLoadingModel,
+	}
+
+	var tensorSplitFloats []float32
+	if *tensorSplit != "" {
+		stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
+
+		tensorSplitFloats = make([]float32, 0, len(stringFloats))
+		for _, s := range stringFloats {
+			f, _ := strconv.ParseFloat(s, 32)
+			tensorSplitFloats = append(tensorSplitFloats, float32(f))
+		}
+	}
+
+	/*params := llama.ModelParams{
+		NumGpuLayers: *nGpuLayers,
+		MainGpu:      *mainGpu,
+		UseMmap:      !*noMmap && lpaths.String() == "",
+		UseMlock:     *mlock,
+		TensorSplit:  tensorSplitFloats,
+		Progress: func(progress float32) {
+			server.progress = progress
+		},
+	}*/
+
+	server.ready.Add(1)
+	go server.loadModel(*mpath, *kvSize, *threads, *multiUserCache)
+
+	server.cond = sync.NewCond(&server.mu)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	go server.run(ctx)
+
+	addr := "127.0.0.1:" + strconv.Itoa(*port)
+	listener, err := net.Listen("tcp", addr)
+	if err != nil {
+		fmt.Println("Listen error:", err)
+		cancel()
+		return err
+	}
+	defer listener.Close()
+
+	mux := http.NewServeMux()
+	mux.HandleFunc("/embedding", server.embeddings)
+	mux.HandleFunc("/completion", server.completion)
+	mux.HandleFunc("/health", server.health)
+
+	httpServer := http.Server{
+		Handler: mux,
+	}
+
+	log.Println("Server listening on", addr)
+	if err := httpServer.Serve(listener); err != nil {
+		log.Fatal("server error:", err)
+		return err
+	}
+
+	cancel()
+	return nil
+}

+ 1 - 1
llama/runner/cache.go → runner/oldrunner/cache.go

@@ -1,4 +1,4 @@
-package runner
+package oldrunner
 
 import (
 	"errors"

+ 292 - 0
runner/oldrunner/cache_test.go

@@ -0,0 +1,292 @@
+package oldrunner
+
+import (
+	"testing"
+	"time"
+)
+
+func TestCountCommon(t *testing.T) {
+	tests := []struct {
+		name     string
+		t1       []input
+		t2       []input
+		expected int
+	}{
+		{
+			name:     "Equal",
+			t1:       []input{{token: 1}, {token: 2}, {token: 3}},
+			t2:       []input{{token: 1}, {token: 2}, {token: 3}},
+			expected: 3,
+		},
+		{
+			name:     "Prefix",
+			t1:       []input{{token: 1}},
+			t2:       []input{{token: 1}, {token: 2}, {token: 3}},
+			expected: 1,
+		},
+		{
+			name:     "Embeddings Prefix",
+			t1:       []input{{embed: []float32{0.1, 0.2, 0.3}}},
+			t2:       []input{{embed: []float32{0.1, 0.2, 0.3}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
+			expected: 1,
+		},
+		{
+			name:     "Embeddings Prefix Partial",
+			t1:       []input{{embed: []float32{0.1, 0.2, 0.3}}},
+			t2:       []input{{embed: []float32{0.1, 0.2}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
+			expected: 0,
+		},
+		{
+			name:     "Mixed",
+			t1:       []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}},
+			t2:       []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}, {token: 5}},
+			expected: 2,
+		},
+		{
+			name:     "Empty",
+			t1:       []input{},
+			t2:       []input{{token: 1}, {token: 2}, {token: 3}},
+			expected: 0,
+		},
+		{
+			name:     "Both Empty",
+			t1:       []input{},
+			t2:       []input{},
+			expected: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := countCommonPrefix(tt.t1, tt.t2)
+			if result != tt.expected {
+				t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestFindCacheSlot(t *testing.T) {
+	type expected struct {
+		result int
+		len    int
+	}
+
+	tests := []struct {
+		name    string
+		cache   InputCache
+		prompt  []input
+		longest expected
+		best    expected
+	}{
+		{
+			name: "Empty",
+			cache: InputCache{slots: []InputCacheSlot{
+				{
+					Id:       0,
+					Inputs:   []input{},
+					InUse:    false,
+					lastUsed: time.Time{},
+				},
+				{
+					Id:       1,
+					Inputs:   []input{},
+					InUse:    false,
+					lastUsed: time.Time{},
+				},
+			}},
+			prompt:  []input{{token: 1}},
+			longest: expected{result: 0, len: 0},
+			best:    expected{result: 0, len: 0},
+		},
+		{
+			name: "Extend",
+			cache: InputCache{slots: []InputCacheSlot{
+				{
+					Id:       0,
+					Inputs:   []input{{token: 1}},
+					InUse:    false,
+					lastUsed: time.Now().Add(-time.Second),
+				},
+				{
+					Id:       1,
+					Inputs:   []input{{token: 1}, {token: 2}},
+					InUse:    false,
+					lastUsed: time.Now().Add(-2 * time.Second),
+				},
+			}},
+			prompt:  []input{{token: 1}, {token: 2}},
+			longest: expected{result: 1, len: 2},
+			best:    expected{result: 1, len: 2},
+		},
+		{
+			name: "New",
+			cache: InputCache{slots: []InputCacheSlot{
+				{
+					Id:       0,
+					Inputs:   []input{{token: 1}, {token: 2}},
+					InUse:    false,
+					lastUsed: time.Now().Add(-time.Second),
+				},
+				{
+					Id:       1,
+					Inputs:   []input{},
+					InUse:    false,
+					lastUsed: time.Time{},
+				},
+			}},
+			prompt:  []input{{token: 2}},
+			longest: expected{result: 0, len: 0},
+			best:    expected{result: 1, len: 0},
+		},
+		{
+			name: "Fork",
+			cache: InputCache{
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input{{token: 1}, {token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input{},
+						InUse:    false,
+						lastUsed: time.Time{},
+					},
+				},
+			},
+			prompt:  []input{{token: 1}},
+			longest: expected{result: 0, len: 1},
+			best:    expected{result: 1, len: 1},
+		},
+		{
+			name: "Evict",
+			cache: InputCache{slots: []InputCacheSlot{
+				{
+					Id:       0,
+					Inputs:   []input{{token: 1}},
+					InUse:    false,
+					lastUsed: time.Now().Add(-time.Second),
+				},
+				{
+					Id:       1,
+					Inputs:   []input{{token: 1}, {token: 2}},
+					InUse:    false,
+					lastUsed: time.Now().Add(-2 * time.Second),
+				},
+			}},
+			prompt:  []input{{token: 2}, {token: 3}},
+			longest: expected{result: 0, len: 0},
+			best:    expected{result: 1, len: 0},
+		},
+		{
+			name: "In use",
+			cache: InputCache{slots: []InputCacheSlot{
+				{
+					Id:       0,
+					Inputs:   []input{{token: 1}, {token: 2}},
+					InUse:    true,
+					lastUsed: time.Now().Add(-time.Second),
+				},
+				{
+					Id:       1,
+					Inputs:   []input{{token: 1}},
+					InUse:    false,
+					lastUsed: time.Now().Add(-2 * time.Second),
+				},
+			}},
+			prompt:  []input{{token: 1}, {token: 2}},
+			longest: expected{result: 1, len: 1},
+			best:    expected{result: 1, len: 2},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run("Longest-"+tt.name, func(t *testing.T) {
+			result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
+			if err != nil {
+				t.Errorf("findLongestCacheSlot: err %v", err)
+			} else if result.Id != tt.longest.result || resultLen != tt.longest.len {
+				t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
+					result.Id, tt.longest.result, resultLen, tt.longest.len)
+			}
+		})
+	}
+
+	for _, tt := range tests {
+		t.Run("Best-"+tt.name, func(t *testing.T) {
+			result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
+			if err != nil {
+				t.Errorf("findBestCacheSlot: err %v", err)
+			} else if result.Id != tt.best.result || resultLen != tt.best.len {
+				t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
+					result.Id, tt.best.result, resultLen, tt.best.len)
+			}
+		})
+	}
+}
+
+func TestShiftDiscard(t *testing.T) {
+	tests := []struct {
+		name     string
+		numCtx   int
+		numKeep  int
+		inputLen int
+		expected int
+	}{
+		{
+			name:     "Shift",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 2048,
+			expected: 1021,
+		},
+		{
+			name:     "Max Keep",
+			numCtx:   2048,
+			numKeep:  2047,
+			inputLen: 2048,
+			expected: 1,
+		},
+		{
+			name:     "No Keep",
+			numCtx:   2048,
+			numKeep:  0,
+			inputLen: 2048,
+			expected: 1024,
+		},
+		{
+			name:     "Truncate",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 5000,
+			expected: 3973,
+		},
+		{
+			name:     "Truncate Keep",
+			numCtx:   2048,
+			numKeep:  2047,
+			inputLen: 5000,
+			expected: 2953,
+		},
+		{
+			name:     "No Op",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 512,
+			expected: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := InputCache{numCtx: tt.numCtx}
+			result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
+			if result != tt.expected {
+				t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
+			}
+		})
+	}
+}

+ 1 - 1
llama/runner/image.go → runner/oldrunner/image.go

@@ -1,4 +1,4 @@
-package runner
+package oldrunner
 
 import (
 	"errors"

+ 1 - 1
llama/runner/image_test.go → runner/oldrunner/image_test.go

@@ -1,4 +1,4 @@
-package runner
+package oldrunner
 
 import (
 	"reflect"

+ 6 - 8
llama/runner/runner.go → runner/oldrunner/runner.go

@@ -1,4 +1,4 @@
-package runner
+package oldrunner
 
 import (
 	"context"
@@ -24,6 +24,7 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/runner/common"
 )
 
 // input is an element of the prompt to process, either
@@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		sequence := strings.Join(seq.pendingResponses, "")
 
-		if ok, stop := findStop(sequence, seq.stop); ok {
+		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
 
 			var tokenTruncated bool
 			origLen := len(seq.pendingResponses)
-			seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
+			seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
 			newLen := len(seq.pendingResponses)
 
 			// Update the cache based on the tokens that will be returned:
@@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}
 
-		if containsStopSuffix(sequence, seq.stop) {
+		if common.ContainsStopSuffix(sequence, seq.stop) {
 			continue
 		}
 
-		if incompleteUnicode(sequence) {
+		if common.IncompleteUnicode(sequence) {
 			continue
 		}
 
@@ -885,9 +886,6 @@ func (s *Server) loadModel(
 }
 
 func Execute(args []string) error {
-	if args[0] == "runner" {
-		args = args[1:]
-	}
 	fs := flag.NewFlagSet("runner", flag.ExitOnError)
 	mpath := fs.String("model", "", "Path to model binary file")
 	ppath := fs.String("mmproj", "", "Path to projector binary file")

+ 24 - 0
runner/runner.go

@@ -0,0 +1,24 @@
+package runner
+
+import (
+	"github.com/ollama/ollama/runner/newrunner"
+	"github.com/ollama/ollama/runner/oldrunner"
+)
+
+func Execute(args []string) error {
+	if args[0] == "runner" {
+		args = args[1:]
+	}
+
+	var newRunner bool
+	if args[0] == "--new-runner" {
+		args = args[1:]
+		newRunner = true
+	}
+
+	if newRunner {
+		return newrunner.Execute(args)
+	} else {
+		return oldrunner.Execute(args)
+	}
+}

+ 28 - 20
server/prompt.go

@@ -10,6 +10,7 @@ import (
 	"strings"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/model/mllama"
 	"github.com/ollama/ollama/template"
@@ -92,26 +93,33 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			var imgData llm.ImageData
 
 			if isMllama {
-				data, opts, err := mllama.Preprocess(bytes.NewReader(i))
-				if err != nil {
-					return "", nil, err
-				}
-
-				buf := new(bytes.Buffer)
-				err = binary.Write(buf, binary.LittleEndian, data)
-				if err != nil {
-					return "", nil, err
-				}
-
-				ar, ok := opts["aspectRatioIndex"].(int)
-				if !ok {
-					return "", nil, fmt.Errorf("missing aspect ratio for image")
-				}
-
-				imgData = llm.ImageData{
-					ID:            len(images),
-					Data:          buf.Bytes(),
-					AspectRatioID: ar,
+				if envconfig.NewRunners() {
+					imgData = llm.ImageData{
+						ID:   len(images),
+						Data: i,
+					}
+				} else {
+					data, opts, err := mllama.Preprocess(bytes.NewReader(i))
+					if err != nil {
+						return "", nil, err
+					}
+
+					buf := new(bytes.Buffer)
+					err = binary.Write(buf, binary.LittleEndian, data)
+					if err != nil {
+						return "", nil, err
+					}
+
+					ar, ok := opts["aspectRatioIndex"].(int)
+					if !ok {
+						return "", nil, fmt.Errorf("missing aspect ratio for image")
+					}
+
+					imgData = llm.ImageData{
+						ID:            len(images),
+						Data:          buf.Bytes(),
+						AspectRatioID: ar,
+					}
 				}
 				imgPrompt = "<|image|>"
 			} else {