Browse Source

ml/backend/ggml: load tensors in 32KiB chunks

Michael Yang 1 month ago
parent
commit
74bd09652d
4 changed files with 58 additions and 30 deletions
  1. 5 4
      ml/backend.go
  2. 44 19
      ml/backend/ggml/ggml.go
  3. 3 2
      model/model.go
  4. 6 5
      runner/ollamarunner/runner.go

+ 5 - 4
ml/backend.go

@@ -2,6 +2,7 @@ package ml
 
 import (
 	"bytes"
+	"context"
 	"encoding/binary"
 	"fmt"
 	"os"
@@ -80,9 +81,9 @@ type BackendParams struct {
 	FlashAttention bool
 }
 
-var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
+var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
 
-func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
+func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
 	if _, ok := backends[name]; ok {
 		panic("backend: backend already registered")
 	}
@@ -90,9 +91,9 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
 	backends[name] = f
 }
 
-func NewBackend(f *os.File, params BackendParams) (Backend, error) {
+func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
 	if backend, ok := backends["ggml"]; ok {
-		return backend(f, params)
+		return backend(ctx, f, params)
 	}
 
 	return nil, fmt.Errorf("unsupported backend")

+ 44 - 19
ml/backend/ggml/ggml.go

@@ -9,15 +9,17 @@ package ggml
 import "C"
 
 import (
-	"errors"
+	"context"
 	"fmt"
 	"io"
 	"log/slog"
 	"maps"
 	"os"
+	"runtime"
 	"slices"
 	"strconv"
 	"strings"
+	"sync/atomic"
 	"unicode"
 	"unsafe"
 
@@ -58,7 +60,7 @@ type Backend struct {
 	maxGraphNodes int
 }
 
-func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
+func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	meta, n, err := fs.Decode(r, -1)
 	if err != nil {
 		return nil, err
@@ -297,12 +299,16 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 		}
 	}
 
-	// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
-	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
-	var g errgroup.Group
+	var doneBytes atomic.Uint64
+	totalBytes := uint64(n) - meta.Tensors().Offset
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(runtime.GOMAXPROCS(0))
 	for _, t := range meta.Tensors().Items() {
-		for _, target := range targets[t.Name] {
-			g.Go(func() error {
+		g.Go(func() error {
+			tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
+			for i := range tts {
+				target := targets[t.Name][i]
 				if target == "" {
 					target = t.Name
 				}
@@ -312,24 +318,43 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 					return fmt.Errorf("unassigned tensor: %s", t.Name)
 				}
 
-				bts := C.malloc(C.size_t(t.Size()))
-				if bts == nil {
-					return errors.New("failed to allocate tensor buffer")
+				tts[i] = tt
+			}
+
+			sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
+			bts := make([]byte, 128*format.KibiByte)
+
+			var s uint64
+			for s < t.Size() {
+				n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
+				if err != nil {
+					return err
 				}
-				defer C.free(bts)
 
-				buf := unsafe.Slice((*byte)(bts), t.Size())
-				n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
-				if err != nil || n != len(buf) {
-					return errors.New("read failed")
+				for _, tt := range tts {
+					C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
 				}
 
-				C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
-				return nil
-			})
-		}
+				s += uint64(n)
+
+				if params.Progress != nil {
+					done := doneBytes.Add(uint64(n))
+					params.Progress(float32(done) / float32(totalBytes))
+				}
+			}
+
+			return nil
+		})
 	}
 
+	// start a goroutine to cancel the errgroup if the parent context is done
+	go func() {
+		<-ctx.Done()
+		g.Go(func() error {
+			return ctx.Err()
+		})
+	}()
+
 	if err := g.Wait(); err != nil {
 		return nil, err
 	}

+ 3 - 2
model/model.go

@@ -1,6 +1,7 @@
 package model
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	_ "image/jpeg"
@@ -94,14 +95,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
 }
 
 // New initializes a new model instance with the provided configuration based on the metadata in the model file
-func New(modelPath string, params ml.BackendParams) (Model, error) {
+func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
 	r, err := os.Open(modelPath)
 	if err != nil {
 		return nil, err
 	}
 	defer r.Close()
 
-	b, err := ml.NewBackend(r, params)
+	b, err := ml.NewBackend(ctx, r, params)
 	if err != nil {
 		return nil, err
 	}

+ 6 - 5
runner/ollamarunner/runner.go

@@ -678,6 +678,7 @@ func (m *multiLPath) String() string {
 }
 
 func (s *Server) loadModel(
+	ctx context.Context,
 	mpath string,
 	params ml.BackendParams,
 	lpath multiLPath,
@@ -687,7 +688,7 @@ func (s *Server) loadModel(
 	multiUserCache bool,
 ) {
 	var err error
-	s.model, err = model.New(mpath, params)
+	s.model, err = model.New(ctx, mpath, params)
 	if err != nil {
 		panic(err)
 	}
@@ -794,13 +795,13 @@ func Execute(args []string) error {
 	}
 
 	server.ready.Add(1)
-	go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-	server.cond = sync.NewCond(&server.mu)
-
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
+	go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
+
+	server.cond = sync.NewCond(&server.mu)
+
 	go server.run(ctx)
 
 	addr := "127.0.0.1:" + strconv.Itoa(*port)