Selaa lähdekoodia

ml/backend/ggml: create tensor on specific backend

some tensors should be created on specific backends to reduce number of
copies and improve performance
Michael Yang 2 kuukautta sitten
vanhempi
commit
7bae7fa5ce
6 muutettua tiedostoa jossa 129 lisäystä ja 60 poistoa
  1. 3 3
      kvcache/causal.go
  2. 1 1
      kvcache/encoder.go
  3. 10 0
      ml/backend.go
  4. 106 47
      ml/backend/ggml/ggml.go
  5. 3 3
      model/models/llama/model.go
  6. 6 6
      model/models/mllama/model.go

+ 3 - 3
kvcache/causal.go

@@ -237,13 +237,13 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
 		mask[i] = float32(math.Inf(-1))
 	}
 
-	maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
+	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
 	if err != nil {
 		return nil, err
 	}
 
 	if c.config.MaskDType != ml.DTypeF32 {
-		out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
+		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
 		ctx.Forward(maskTensor.Copy(ctx, out))
 		maskTensor = out
 	}
@@ -440,7 +440,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 	}
 
 	if _, ok := c.ctxs[c.curLayer]; !ok {
-		c.ctxs[c.curLayer] = c.backend.NewContext()
+		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
 	}
 
 	if _, ok := c.keys[c.curLayer]; !ok {

+ 1 - 1
kvcache/encoder.go

@@ -106,7 +106,7 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 	}
 
 	if _, ok := c.ctxs[c.curLayer]; !ok {
-		c.ctxs[c.curLayer] = c.backend.NewContext()
+		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
 	}
 
 	if _, ok := c.keys[c.curLayer]; !ok {

+ 10 - 0
ml/backend.go

@@ -24,6 +24,7 @@ type Backend interface {
 	Config() Config
 	Get(name string) Tensor
 	NewContext() Context
+	NewContextSize(size int) Context
 }
 
 // BackendCacheConfig should be implemented by backends that need special output
@@ -101,6 +102,15 @@ type Context interface {
 	Compute(...Tensor)
 	MaxGraphNodes() int
 	Close()
+
+	// Input returns a context appropriate for creating input tensors
+	Input() Context
+
+	// Output returns a context appropriate for creating output tensors
+	Output() Context
+
+	// Layer returns a context appropriate for creating intermediate tensors
+	Layer(int) Context
 }
 
 type Tensor interface {

+ 106 - 47
ml/backend/ggml/ggml.go

@@ -41,16 +41,14 @@ func devices() iter.Seq[*C.struct_ggml_backend_device] {
 }
 
 type Backend struct {
-	meta *fs.GGML
+	meta    *fs.GGML
+	sched   *C.struct_ggml_backend_sched
+	tensors map[string]*C.struct_ggml_tensor
+	input   *C.struct_ggml_backend
+	output  *C.struct_ggml_backend
+	layers  map[int]*C.struct_ggml_backend
 
 	flashAttention bool
-
-	sched *C.struct_ggml_backend_sched
-
-	tensors  map[string]*C.struct_ggml_tensor
-	ctxs     []*C.struct_ggml_context
-	backends []*C.struct_ggml_backend
-	bufts    []*C.struct_ggml_backend_buffer_type
 }
 
 func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
@@ -118,7 +116,6 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	}
 
 	input := dbt{C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU), cpuBufferTypes}
-	slog.Info("input layer", "device", C.GoString(C.ggml_backend_dev_name(input.d)))
 
 	var blocks int
 	for key, value := range meta.KV() {
@@ -136,18 +133,14 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	layers := make([]dbt, blocks)
 	for i := range layers {
 		layers[i] = gpuBufferTypes[slices.IndexFunc(splits, indexFunc(i))]
-		slog.Info("layer", "i", i, "device", C.GoString(C.ggml_backend_dev_name(layers[i].d)))
 	}
 
 	output := gpuBufferTypes[slices.IndexFunc(splits, indexFunc(blocks))]
-	slog.Info("output layer", "device", C.GoString(C.ggml_backend_dev_name(output.d)))
 
 	maxTensors := len(meta.Tensors().Items())
 	maxTensors += 1
 	maxTensors += blocks * 2
 
-	slog.Info("max tensors", "max_tensors", maxTensors)
-
 	type tensor struct {
 		source *fs.Tensor
 		target string
@@ -242,7 +235,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 
 	for bs := range maps.Values(bbs) {
 		for _, b := range bs {
-			slog.Info("model", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
+			slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(b)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(b))))
 		}
 	}
 
@@ -290,11 +283,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		return nil, err
 	}
 
+	deviceBackends := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend)
 	var backends []*C.struct_ggml_backend
 	var bufts []*C.struct_ggml_backend_buffer_type
 	for _, d := range append(gpus, append(accels, cpus...)...) {
 		b := C.ggml_backend_dev_init(d, nil)
 		backends = append(backends, b)
+		deviceBackends[d] = b
 
 		bt := C.ggml_backend_get_default_buffer_type(b)
 		if d := C.ggml_backend_get_device(b); C.ggml_backend_dev_type(d) == C.GGML_BACKEND_DEVICE_TYPE_CPU && len(gpus) > 0 {
@@ -305,13 +300,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 
 		bufts = append(bufts, bt)
 
-		slog.Info("compute buffer", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+		slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
 	}
 
 	return &Backend{
 		flashAttention: params.FlashAttention,
-		meta:              meta,
-		tensors:           tensors,
+		meta:           meta,
+		tensors:        tensors,
 		sched: C.ggml_backend_sched_new(
 			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
 			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -319,6 +314,15 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 			C.size_t(max(8192, len(meta.Tensors().Items())*5)),
 			true,
 		),
+		input:  deviceBackends[input.d],
+		output: deviceBackends[output.d],
+		layers: func() map[int]*C.struct_ggml_backend {
+			m := make(map[int]*C.struct_ggml_backend)
+			for i, layer := range layers {
+				m[i] = deviceBackends[layer.d]
+			}
+			return m
+		}(),
 	}, nil
 }
 
@@ -339,15 +343,21 @@ func (b *Backend) Get(name string) ml.Tensor {
 }
 
 func (b *Backend) NewContext() ml.Context {
-	maxGraphNodes := max(8192, len(b.meta.Tensors().Items())*5)
+	return b.NewContextSize(max(8192, len(b.meta.Tensors().Items())*5))
+}
+
+func (b *Backend) NewContextSize(n int) ml.Context {
 	return &Context{
-		b:          b,
+		b: b,
 		ctx: C.ggml_init(C.struct_ggml_init_params{
-			mem_size: C.size_t(maxGraphNodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(maxGraphNodes), false),
+			mem_size: C.size_t(n)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(n), false),
 			no_alloc: true,
 		}),
 		backend:       C.ggml_backend_sched_get_backend(b.sched, 0),
-		maxGraphNodes: maxGraphNodes,
+		maxGraphNodes: n,
+		input:         b.input,
+		output:        b.output,
+		layers:        b.layers,
 	}
 }
 
@@ -364,11 +374,61 @@ type Context struct {
 
 	ctx   *C.struct_ggml_context
 	graph *C.struct_ggml_cgraph
+
+	// backend is the backend used for new tensors
 	backend *C.struct_ggml_backend
 
+	// input is the backend used for inputs
+	input *C.struct_ggml_backend
+
+	// output is the backend used for outputs
+	output *C.struct_ggml_backend
+
+	// output is the backend used for repeating layers
+	layers map[int]*C.struct_ggml_backend
+
 	maxGraphNodes int
 }
 
+func (c *Context) Input() ml.Context {
+	if c.input != nil {
+		return &Context{
+			b:             c.b,
+			ctx:           c.ctx,
+			backend:       c.input,
+			maxGraphNodes: c.maxGraphNodes,
+		}
+	}
+
+	return c
+}
+
+func (c *Context) Output() ml.Context {
+	if c.output != nil {
+		return &Context{
+			b:             c.b,
+			ctx:           c.ctx,
+			backend:       c.output,
+			maxGraphNodes: c.maxGraphNodes,
+		}
+	}
+
+	return c
+}
+
+func (c *Context) Layer(i int) ml.Context {
+	if backend, ok := c.layers[i]; ok {
+		return &Context{
+			b:             c.b,
+			ctx:           c.ctx,
+			backend:       backend,
+			maxGraphNodes: c.maxGraphNodes,
+		}
+	}
+
+	return c
+}
+
 func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	if c.graph == nil {
 		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.maxGraphNodes), false)
@@ -414,7 +474,7 @@ func shapeToGGML(shape []int) *C.int64_t {
 	return &sh[0]
 }
 
-func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
+func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 	if len(shape) < 1 || len(shape) > 4 {
 		panic("unsupported number of dimensions")
 	}
@@ -428,62 +488,61 @@ func newTensor(ctx Context, dtype ml.DType, shape []int) ml.Tensor {
 	var t *C.struct_ggml_tensor
 	switch dtype {
 	case ml.DTypeF32:
-		t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
+		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
 	case ml.DTypeF16:
-		t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
+		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
 	case ml.DTypeI32:
-		t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
+		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
 	default:
 		panic("unsupported dtype")
 	}
 
-	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
+	b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	C.ggml_set_input(t)
-	return &Tensor{b: ctx.b, t: t}
+	return &Tensor{b: c.b, t: t}
 }
 
 func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-	return newTensor(c, dtype, shape)
+	return c.newTensor(dtype, shape)
 }
 
 func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	t := newTensor(c, dtype, shape)
+	t := c.newTensor(dtype, shape)
 	C.ggml_set_zero(t.(*Tensor).t)
 	return t
 }
 
-func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
+func checkShape[S ~[]E, E any](s S, shape ...int) error {
 	n := len(s)
-
-	if n == 0 {
-		var shape C.int64_t = 0
-		t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
-		return &Tensor{b: ctx.b, t: t}, nil
-	}
-
 	for _, v := range shape {
 		n /= v
 	}
 
 	if n != 1 {
-		return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
+		return fmt.Errorf("invalid shape: %v", shape)
 	}
 
-	t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
-	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
-	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
-	C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
-	C.ggml_set_input(t)
-	return &Tensor{b: ctx.b, t: t}, nil
+	return nil
 }
 
 func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-	return fromSlice(c, s, shape, C.GGML_TYPE_F32)
+	if err := checkShape(s, shape...); err != nil {
+		return nil, err
+	}
+
+	t := c.newTensor(ml.DTypeF32, shape)
+	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	return t, nil
 }
 
 func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
-	return fromSlice(c, s, shape, C.GGML_TYPE_I32)
+	if err := checkShape(s, shape...); err != nil {
+		return nil, err
+	}
+
+	t := c.newTensor(ml.DTypeI32, shape)
+	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	return t, nil
 }
 
 func (c Context) Close() {

+ 3 - 3
model/models/llama/model.go

@@ -138,17 +138,17 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 }
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 6 - 6
model/models/mllama/model.go

@@ -72,7 +72,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	pixelValues, err := ctx.FromFloatSlice(f32s,
+	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
 		m.ImageProcessor.imageSize,
 		m.ImageProcessor.imageSize,
 		m.ImageProcessor.numChannels,
@@ -82,7 +82,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
+	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
 	if err != nil {
 		return nil, err
 	}
@@ -92,7 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		positions[i] = int32(i)
 	}
 
-	positionIDs, err := ctx.FromIntSlice(positions, len(positions))
+	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
 	if err != nil {
 		return nil, err
 	}
@@ -136,17 +136,17 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
 	}
 
-	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
+	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}