Browse Source

ggml-backend: Ensure data is available after async computation

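After an async computation, we need to synchronize with the scheduler
before retrieving tensor data. It is also important to ensure that the
Go buffer is not moved by the GC while a copy into it is still pending
after the function call returns, so we replace the async copy with a
synchronous one.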
Jesse Gross 3 months ago
parent
commit
60830695c2
1 changed file with 26 additions and 14 deletions
  1. 26 14
      ml/backend/ggml/ggml.go

+ 26 - 14
ml/backend/ggml/ggml.go

@@ -9,8 +9,6 @@ package ggml
 import "C"
 
 import (
-	"bytes"
-	"encoding/binary"
 	"fmt"
 	"io"
 	"log/slog"
@@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
 func (c *Context) Compute(tensors ...ml.Tensor) {
 	C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
 
-	for _, t := range tensors {
-		if C.ggml_nbytes(t.(*Tensor).t) != 0 {
-			backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
+	needSync := true
+	sync := func() {
+		if needSync {
+			C.ggml_backend_sched_synchronize(c.sched)
+			needSync = false
+		}
+	}
 
-			t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
-			C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	for _, t := range tensors {
+		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
+			t.(*Tensor).sync = sync
 		}
 	}
 }
@@ -330,7 +333,7 @@ func (c *Context) Close() {
 
 type Tensor struct {
 	t    *C.struct_ggml_tensor
-	data []byte
+	sync func()
 }
 
 func (t *Tensor) LogValue() slog.Value {
@@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
 	return shape
 }
 
-func (t *Tensor) Bytes() []byte {
-	return t.data
+func (t *Tensor) Bytes() (data []byte) {
+	if t.sync != nil {
+		data = make([]byte, C.ggml_nbytes(t.t))
+
+		t.sync()
+		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
+	}
+
+	return
 }
 
-func (t *Tensor) Floats() (f32s []float32) {
-	if t.data != nil {
-		f32s = make([]float32, C.ggml_nelements(t.t))
-		_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
+func (t *Tensor) Floats() (data []float32) {
+	if t.sync != nil {
+		data = make([]float32, C.ggml_nelements(t.t))
+
+		t.sync()
+		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
 	}
 
 	return