Jelajahi Sumber

ggml-backend: Ensure data is available after async computation

We need to sync before retrieving data after async computation.
It is also important to ensure that the Go buffer is not moved by
the GC across function calls, so we do a synchronous copy.
Jesse Gross 2 bulan lalu
induk
komit
60830695c2
1 file diubah dengan 26 tambahan dan 14 penghapusan
  1. 26 14
      ml/backend/ggml/ggml.go

+ 26 - 14
ml/backend/ggml/ggml.go

@@ -9,8 +9,6 @@ package ggml
 import "C"
 
 import (
-	"bytes"
-	"encoding/binary"
 	"fmt"
 	"io"
 	"log/slog"
@@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
 func (c *Context) Compute(tensors ...ml.Tensor) {
 	C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
 
-	for _, t := range tensors {
-		if C.ggml_nbytes(t.(*Tensor).t) != 0 {
-			backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
+	needSync := true
+	sync := func() {
+		if needSync {
+			C.ggml_backend_sched_synchronize(c.sched)
+			needSync = false
+		}
+	}
 
-			t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
-			C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	for _, t := range tensors {
+		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
+			t.(*Tensor).sync = sync
 		}
 	}
 }
@@ -330,7 +333,7 @@ func (c *Context) Close() {
 
 type Tensor struct {
 	t    *C.struct_ggml_tensor
-	data []byte
+	sync func()
 }
 
 func (t *Tensor) LogValue() slog.Value {
@@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
 	return shape
 }
 
-func (t *Tensor) Bytes() []byte {
-	return t.data
+func (t *Tensor) Bytes() (data []byte) {
+	if t.sync != nil {
+		data = make([]byte, C.ggml_nbytes(t.t))
+
+		t.sync()
+		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
+	}
+
+	return
 }
 
-func (t *Tensor) Floats() (f32s []float32) {
-	if t.data != nil {
-		f32s = make([]float32, C.ggml_nelements(t.t))
-		_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
+func (t *Tensor) Floats() (data []float32) {
+	if t.sync != nil {
+		data = make([]float32, C.ggml_nelements(t.t))
+
+		t.sync()
+		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
 	}
 
 	return