
kvcache: create cache ctx per layer

each cache layer creates and maintains its own context instead of using
a large context for all layers
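
The idea, reduced to a runnable sketch: contexts live in a map keyed by layer index, are created lazily the first time a layer stores data, and are all closed together. The `Backend` and `Context` types below are hypothetical stand-ins for the real `ml` interfaces, not the actual ollama API.

```go
package main

import "fmt"

// Context and Backend are hypothetical stand-ins for ml.Context and
// ml.Backend, reduced to the lifecycle this commit changes.
type Context struct{ layer int }

func (c *Context) Close() { fmt.Printf("closed context for layer %d\n", c.layer) }

type Backend struct{}

func (b *Backend) NewContext(layer int) *Context { return &Context{layer: layer} }

// Cache mirrors the new shape of kvcache.Causal: one context per layer.
type Cache struct {
	backend *Backend
	ctxs    map[int]*Context
}

// Put lazily allocates the current layer's context on first use,
// as the patched Causal.Put does with c.ctxs[c.curLayer].
func (c *Cache) Put(layer int) {
	if _, ok := c.ctxs[layer]; !ok {
		c.ctxs[layer] = c.backend.NewContext(layer)
	}
}

// Close releases every per-layer context, matching the new Causal.Close.
func (c *Cache) Close() {
	for _, ctx := range c.ctxs {
		ctx.Close()
	}
}

func main() {
	cache := &Cache{backend: &Backend{}, ctxs: make(map[int]*Context)}
	for _, layer := range []int{0, 1, 1, 2} {
		cache.Put(layer) // layer 1 repeated: only one context is created for it
	}
	cache.Close() // closes three contexts
}
```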
Michael Yang, 2 months ago
Commit 764e199d67

4 changed files with 68 additions and 46 deletions
  1. kvcache/causal.go (+32 −20)
  2. kvcache/encoder.go (+22 −14)
  3. ml/backend.go (+1 −1)
  4. ml/backend/ggml/ggml.go (+13 −11)

kvcache/causal.go (+32 −20)

@@ -55,8 +55,8 @@ type Causal struct {

 	shiftFn      shiftFn
 	backend      ml.Backend
-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }

 type cacheCell struct {
@@ -70,11 +70,23 @@ type cellRange struct {
 }

 func NewCausalCache(shift shiftFn) *Causal {
-	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
+	return &Causal{
+		windowSize: math.MaxInt32,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }

 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-	return &Causal{windowSize: windowSize, shiftFn: shift}
+	return &Causal{
+		windowSize: windowSize,
+		shiftFn:    shift,
+		ctxs:       make(map[int]ml.Context),
+		keys:       make(map[int]ml.Tensor),
+		values:     make(map[int]ml.Tensor),
+	}
 }

 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -103,7 +115,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
-	c.cacheCtx = backend.NewContext()
 }

 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -115,7 +126,9 @@ func (c *Causal) SetConfig(config ml.CacheConfig) {
 }

 func (c *Causal) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }

 func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -239,13 +252,11 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 }

 func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
-	for i := range c.keys {
-		if c.keys[i] == nil {
+	for i, key := range c.keys {
+		if key == nil {
 			continue
 		}

-		key := c.keys[i]
-
 		kHeadDim := key.Dim(0)
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)
@@ -305,7 +316,7 @@ func (c *Causal) defrag() {
 		layers++
 	}

-	maxMoves := ctx.MaxTensors() / (6 * layers)
+	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
 	moves := 0

 	var pendingSrc, pendingDst, pendingLen int
@@ -377,11 +388,6 @@ func (c *Causal) defrag() {
 }

 func (c *Causal) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }

@@ -433,13 +439,19 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContext()
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+	}

+	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
 		}
 	}
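
Note on the change above: keys and values move from slices to maps, so SetLayer becomes a plain assignment instead of growing slices on demand, and all allocation happens in Put the first time a layer stores data. Since Go map iteration order is unspecified, the moveCells and Close loops now visit layers in arbitrary order; each layer is handled independently, so ordering does not matter here.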
 
 

kvcache/encoder.go (+22 −14)

@@ -35,13 +35,17 @@ type EncoderCache struct {
 	encoderPos int32

 	// ** cache data storage **
-
-	cacheCtx     ml.Context
-	keys, values []ml.Tensor
+	backend      ml.Backend
+	ctxs         map[int]ml.Context
+	keys, values map[int]ml.Tensor
 }

 func NewEncoderCache() *EncoderCache {
-	return &EncoderCache{}
+	return &EncoderCache{
+		ctxs:   make(map[int]ml.Context),
+		keys:   make(map[int]ml.Tensor),
+		values: make(map[int]ml.Tensor),
+	}
 }

 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
@@ -57,7 +61,7 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
 	}

-	c.cacheCtx = backend.NewContext()
+	c.backend = backend
 }

 func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
@@ -69,7 +73,9 @@ func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
 }

 func (c *EncoderCache) Close() {
-	c.cacheCtx.Close()
+	for _, ctx := range c.ctxs {
+		ctx.Close()
+	}
 }

 func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
@@ -80,11 +86,6 @@ func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []in
 }

 func (c *EncoderCache) SetLayer(layer int) {
-	if layer >= len(c.keys) {
-		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
-		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
-	}
-
 	c.curLayer = layer
 }

@@ -104,9 +105,16 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 		value = value.Permute(ctx, 1, 2, 0, 3)
 	}

-	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
-		c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
+	if _, ok := c.ctxs[c.curLayer]; !ok {
+		c.ctxs[c.curLayer] = c.backend.NewContext()
+	}
+
+	if _, ok := c.keys[c.curLayer]; !ok {
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
+	}
+
+	if _, ok := c.values[c.curLayer]; !ok {
+		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
 	}

 	ctx.Forward(
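
The `if _, ok := m[k]; !ok` test used throughout these hunks is the standard Go comma-ok idiom for lazy map initialization. As a standalone illustration (the getOrCreate helper is hypothetical, not part of the patch):

```go
package main

import "fmt"

// getOrCreate returns m[k], constructing and storing a value on first
// access. It mirrors the `if _, ok := c.ctxs[c.curLayer]; !ok` pattern
// used in this commit.
func getOrCreate[K comparable, V any](m map[K]V, k K, newV func() V) V {
	if v, ok := m[k]; ok {
		return v
	}
	v := newV()
	m[k] = v
	return v
}

func main() {
	ctxs := make(map[int]string)
	for _, layer := range []int{0, 1, 0, 2} {
		getOrCreate(ctxs, layer, func() string {
			fmt.Printf("allocating context for layer %d\n", layer)
			return fmt.Sprintf("ctx-%d", layer)
		})
	}
	fmt.Println("total contexts:", len(ctxs)) // prints 3: layer 0 was reused
}
```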

ml/backend.go (+1 −1)

@@ -99,7 +99,7 @@ type Context interface {

 	Forward(...Tensor) Context
 	Compute(...Tensor)
-	MaxTensors() int
+	MaxGraphNodes() int
 	Close()
 }
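
The interface rename from MaxTensors to MaxGraphNodes tracks what the value actually bounds: as the ggml.go change below shows, the context's memory is sized for a compute graph of that many nodes (via ggml_graph_overhead_custom), not for a raw tensor count. Callers such as Causal.defrag use it to cap how many move operations fit into a single graph.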
 
 

ml/backend/ggml/ggml.go (+13 −11)

@@ -339,14 +339,15 @@ func (b *Backend) Get(name string) ml.Tensor {
 }

 func (b *Backend) NewContext() ml.Context {
-	maxTensors := max(8192, len(b.meta.Tensors().Items())*5)
+	maxGraphNodes := max(8192, len(b.meta.Tensors().Items())*5)
 	return &Context{
 		b:          b,
-		maxTensors: maxTensors,
 		ctx: C.ggml_init(C.struct_ggml_init_params{
-			mem_size: C.size_t(maxTensors)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxTensors), false),
+			mem_size: C.size_t(maxGraphNodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxGraphNodes), false),
 			no_alloc: true,
 		}),
+		backend:       C.ggml_backend_sched_get_backend(b.sched, 0),
+		maxGraphNodes: maxGraphNodes,
 	}
 }

@@ -363,13 +364,14 @@ type Context struct {

 	ctx   *C.struct_ggml_context
 	graph *C.struct_ggml_cgraph
+	backend *C.struct_ggml_backend

-	maxTensors int
+	maxGraphNodes int
 }

 func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	if c.graph == nil {
-		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxTensors), false)
+		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
 	}

 	for _, tensor := range tensors {
@@ -399,8 +401,8 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
 	}
 }

-func (c *Context) MaxTensors() int {
-	return c.maxTensors
+func (c *Context) MaxGraphNodes() int {
+	return c.maxGraphNodes
 }

 func shapeToGGML(shape []int) *C.int64_t {
@@ -435,7 +437,7 @@ func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
 		panic("unsupported dtype")
 	}

-	b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t))
+	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	C.ggml_set_input(t)
 	return &Tensor{b: ctx.b, t: t}
@@ -469,7 +471,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
 	}

 	t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
-	b := C.ggml_backend_alloc_buffer(C.ggml_backend_sched_get_backend(ctx.b.sched, 0), C.ggml_nbytes(t))
+	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
 	C.ggml_set_input(t)
@@ -484,8 +486,8 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return fromSlice(c, s, shape, C.GGML_TYPE_I32)
 }

-func (c *Context) Close() {
-	if c != nil {
+func (c Context) Close() {
+	if c.ctx != nil {
 		C.ggml_free(c.ctx)
 	}
 }
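
Two details worth noting in this last file: the scheduler backend is now looked up once in NewContext and cached on the Context (ctx.backend) instead of being fetched on every tensor allocation, and Close switches to a value receiver that guards on the inner C pointer (c.ctx != nil) rather than on the Go pointer, so a zero-valued Context can be closed safely. Both matter more now that the kvcache creates one context per layer and frees each of them through this path.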