|
@@ -9,15 +9,17 @@ package ggml
|
|
|
import "C"
|
|
|
|
|
|
import (
|
|
|
- "errors"
|
|
|
+ "context"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log/slog"
|
|
|
"maps"
|
|
|
"os"
|
|
|
+ "runtime"
|
|
|
"slices"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
+ "sync/atomic"
|
|
|
"unicode"
|
|
|
"unsafe"
|
|
|
|
|
@@ -58,7 +60,7 @@ type Backend struct {
|
|
|
maxGraphNodes int
|
|
|
}
|
|
|
|
|
|
-func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|
|
+func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|
|
meta, n, err := fs.Decode(r, -1)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // concurrently read in tensor data. uses a section reader which is safe for concurrent reads
|
|
|
- sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
|
|
|
- var g errgroup.Group
|
|
|
+ var doneBytes atomic.Uint64
|
|
|
+ totalBytes := uint64(n) - meta.Tensors().Offset
|
|
|
+
|
|
|
+ g, ctx := errgroup.WithContext(ctx)
|
|
|
+ g.SetLimit(runtime.GOMAXPROCS(0))
|
|
|
for _, t := range meta.Tensors().Items() {
|
|
|
- for _, target := range targets[t.Name] {
|
|
|
- g.Go(func() error {
|
|
|
+ g.Go(func() error {
|
|
|
+ tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
|
|
|
+ for i := range tts {
|
|
|
+ target := targets[t.Name][i]
|
|
|
if target == "" {
|
|
|
target = t.Name
|
|
|
}
|
|
@@ -312,24 +318,43 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
|
|
return fmt.Errorf("unassigned tensor: %s", t.Name)
|
|
|
}
|
|
|
|
|
|
- bts := C.malloc(C.size_t(t.Size()))
|
|
|
- if bts == nil {
|
|
|
- return errors.New("failed to allocate tensor buffer")
|
|
|
+ tts[i] = tt
|
|
|
+ }
|
|
|
+
|
|
|
+ sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
|
|
+ bts := make([]byte, 128*format.KibiByte)
|
|
|
+
|
|
|
+ var s uint64
|
|
|
+ for s < t.Size() {
|
|
|
+ n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
- defer C.free(bts)
|
|
|
|
|
|
- buf := unsafe.Slice((*byte)(bts), t.Size())
|
|
|
- n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
|
|
|
- if err != nil || n != len(buf) {
|
|
|
- return errors.New("read failed")
|
|
|
+ for _, tt := range tts {
|
|
|
+ C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
|
|
|
}
|
|
|
|
|
|
- C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
|
|
|
- return nil
|
|
|
- })
|
|
|
- }
|
|
|
+ s += uint64(n)
|
|
|
+
|
|
|
+ if params.Progress != nil {
|
|
|
+ done := doneBytes.Add(uint64(n))
|
|
|
+ params.Progress(float32(done) / float32(totalBytes))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
+ // start a goroutine to cancel the errgroup if the parent context is done
|
|
|
+ go func() {
|
|
|
+ <-ctx.Done()
|
|
|
+ g.Go(func() error {
|
|
|
+ return ctx.Err()
|
|
|
+ })
|
|
|
+ }()
|
|
|
+
|
|
|
if err := g.Wait(); err != nil {
|
|
|
return nil, err
|
|
|
}
|