Browse Source

load tensors concurrently

Michael Yang 5 months ago
parent
commit
e699b8f5b9
2 changed files with 26 additions and 20 deletions
  1. 5 5
      ml/backend.go
  2. 21 15
      ml/backend/ggml/backend.go

+ 5 - 5
ml/backend.go

@@ -4,7 +4,7 @@ import (
 	"bytes"
 	"encoding/binary"
 	"fmt"
-	"io"
+	"os"
 	"strings"
 )
 
@@ -24,9 +24,9 @@ type Backend interface {
 	NewContext() Context
 }
 
-var backends = make(map[string]func(io.ReadSeeker) (Backend, error))
+var backends = make(map[string]func(*os.File) (Backend, error))
 
-func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) {
+func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
 	if _, ok := backends[name]; ok {
 		panic("backend: backend already registered")
 	}
@@ -34,9 +34,9 @@ func RegisterBackend(name string, f func(io.ReadSeeker) (Backend, error)) {
 	backends[name] = f
 }
 
-func NewBackend(r io.ReadSeeker) (Backend, error) {
+func NewBackend(f *os.File) (Backend, error) {
 	if backend, ok := backends["ggml"]; ok {
-		return backend(r)
+		return backend(f)
 	}
 
 	return nil, fmt.Errorf("unsupported backend")

+ 21 - 15
ml/backend/ggml/backend.go

@@ -12,8 +12,11 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"os"
 	"unsafe"
 
+	"golang.org/x/sync/errgroup"
+
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml"
@@ -28,7 +31,7 @@ type Backend struct {
 	ggml.Tensors
 }
 
-func New(r io.ReadSeeker) (ml.Backend, error) {
+func New(r *os.File) (ml.Backend, error) {
 	f, _, err := ggml.Decode(r, -1)
 	if err != nil {
 		return nil, err
@@ -62,22 +65,20 @@ func New(r io.ReadSeeker) (ml.Backend, error) {
 
 	b := newBackend()
 	bb := C.ggml_backend_alloc_ctx_tensors(c, b)
-	for _, t := range f.Tensors().Items {
-		if _, err := r.Seek(int64(f.Tensors().Offset+t.Offset), io.SeekStart); err != nil {
-			return nil, err
-		}
 
-		var b bytes.Buffer
-		n, err := io.CopyN(&b, r, int64(t.Size()))
-		if err != nil {
-			return nil, err
-		}
+	var g errgroup.Group
+	for _, t := range f.Tensors().Items {
+		g.Go(func() error {
+			var b bytes.Buffer
+			n, err := io.Copy(&b, io.NewSectionReader(r, int64(f.Tensors().Offset+t.Offset), int64(t.Size())))
+			if err != nil {
+				return err
+			}
 
-		if n != int64(t.Size()) {
-			return nil, fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
-		}
+			if n != int64(t.Size()) {
+				return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
+			}
 
-		func() {
 			cname := C.CString(t.Name)
 			defer C.free(unsafe.Pointer(cname))
 
@@ -85,7 +86,12 @@ func New(r io.ReadSeeker) (ml.Backend, error) {
 			defer C.free(cbytes)
 
 			C.ggml_backend_tensor_set(C.ggml_get_tensor(c, cname), cbytes, 0, C.size_t(n))
-		}()
+			return nil
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		return nil, err
 	}
 
 	return &Backend{c, b, bb, f.KV(), f.Tensors()}, nil