|
@@ -9,8 +9,6 @@ package ggml
|
|
|
import "C"
|
|
|
|
|
|
import (
|
|
|
- "bytes"
|
|
|
- "encoding/binary"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log/slog"
|
|
@@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
|
|
|
func (c *Context) Compute(tensors ...ml.Tensor) {
|
|
|
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
|
|
|
|
|
|
- for _, t := range tensors {
|
|
|
- if C.ggml_nbytes(t.(*Tensor).t) != 0 {
|
|
|
- backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t)
|
|
|
+ needSync := true
|
|
|
+ sync := func() {
|
|
|
+ if needSync {
|
|
|
+ C.ggml_backend_sched_synchronize(c.sched)
|
|
|
+ needSync = false
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
- t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t))
|
|
|
- C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
|
|
+ for _, t := range tensors {
|
|
|
+ if C.ggml_nbytes(t.(*Tensor).t) > 0 {
|
|
|
+ t.(*Tensor).sync = sync
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -330,7 +333,7 @@ func (c *Context) Close() {
|
|
|
|
|
|
type Tensor struct {
|
|
|
t *C.struct_ggml_tensor
|
|
|
- data []byte
|
|
|
+ sync func()
|
|
|
}
|
|
|
|
|
|
func (t *Tensor) LogValue() slog.Value {
|
|
@@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
|
|
|
return shape
|
|
|
}
|
|
|
|
|
|
-func (t *Tensor) Bytes() []byte {
|
|
|
- return t.data
|
|
|
+func (t *Tensor) Bytes() (data []byte) {
|
|
|
+ if t.sync != nil {
|
|
|
+ data = make([]byte, C.ggml_nbytes(t.t))
|
|
|
+
|
|
|
+ t.sync()
|
|
|
+ C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
|
|
+ }
|
|
|
+
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
-func (t *Tensor) Floats() (f32s []float32) {
|
|
|
- if t.data != nil {
|
|
|
- f32s = make([]float32, C.ggml_nelements(t.t))
|
|
|
- _ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
|
|
|
+func (t *Tensor) Floats() (data []float32) {
|
|
|
+ if t.sync != nil {
|
|
|
+ data = make([]float32, C.ggml_nelements(t.t))
|
|
|
+
|
|
|
+ t.sync()
|
|
|
+ C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
|
|
|
}
|
|
|
|
|
|
return
|