浏览代码

ggml-backend: Ensure data is available after async computation

We need to sync before retrieving data after async computation.
It is also important to ensure that the Go buffer is not moved by
the GC across function calls so we do a synchronous copy.
Jesse Gross 2 月之前
父节点
当前提交
60830695c2
共有 1 个文件被更改,包括 26 次插入14 次删除
  1. 26 14
      ml/backend/ggml/ggml.go

+ 26 - 14
ml/backend/ggml/ggml.go

@@ -9,8 +9,6 @@ package ggml
 import "C"
 import "C"
 
 
 import (
 import (
-	"bytes"
-	"encoding/binary"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"log/slog"
 	"log/slog"
@@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
 func (c *Context) Compute(tensors ...ml.Tensor) {
 func (c *Context) Compute(tensors ...ml.Tensor) {
 	C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
 	C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
 
 
-	for _, t := range tensors {
-		if C.ggml_nbytes(t.(*Tensor).t) != 0 {
-			backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
+	needSync := true
+	sync := func() {
+		if needSync {
+			C.ggml_backend_sched_synchronize(c.sched)
+			needSync = false
+		}
+	}
 
 
-			t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
-			C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	for _, t := range tensors {
+		if C.ggml_nbytes(t.(*Tensor).t) > 0 {
+			t.(*Tensor).sync = sync
 		}
 		}
 	}
 	}
 }
 }
@@ -330,7 +333,7 @@ func (c *Context) Close() {
 
 
 type Tensor struct {
 type Tensor struct {
 	t    *C.struct_ggml_tensor
 	t    *C.struct_ggml_tensor
-	data []byte
+	sync func()
 }
 }
 
 
 func (t *Tensor) LogValue() slog.Value {
 func (t *Tensor) LogValue() slog.Value {
@@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
 	return shape
 	return shape
 }
 }
 
 
-func (t *Tensor) Bytes() []byte {
-	return t.data
+func (t *Tensor) Bytes() (data []byte) {
+	if t.sync != nil {
+		data = make([]byte, C.ggml_nbytes(t.t))
+
+		t.sync()
+		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
+	}
+
+	return
 }
 }
 
 
-func (t *Tensor) Floats() (f32s []float32) {
-	if t.data != nil {
-		f32s = make([]float32, C.ggml_nelements(t.t))
-		_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
+func (t *Tensor) Floats() (data []float32) {
+	if t.sync != nil {
+		data = make([]float32, C.ggml_nelements(t.t))
+
+		t.sync()
+		C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
 	}
 	}
 
 
 	return
 	return