Browse Source

model: load non-repeated tensors into multiple backends

Some tensors are expected to be used in repeating layers but are not
themselves repeated. This change copies these tensors into the same
backends as their repeating counterparts to minimize copying tensors
between backends.
Michael Yang 2 months ago
parent
commit
bfce55db3d
2 changed files with 57 additions and 41 deletions
  1. 52 27
      ml/backend/ggml/ggml.go
  2. 5 14
      ml/backend/ggml/ggml/src/ggml-backend-reg.cpp

+ 52 - 27
ml/backend/ggml/ggml.go

@@ -25,11 +25,13 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/format"
 	fs "github.com/ollama/ollama/fs/ggml"
 	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
 )
 )
 
 
 func devices() iter.Seq[*C.struct_ggml_backend_device] {
 func devices() iter.Seq[*C.struct_ggml_backend_device] {
 	return func(yield func(*C.struct_ggml_backend_device) bool) {
 	return func(yield func(*C.struct_ggml_backend_device) bool) {
+		ggml.OnceLoad()
 		for i := range C.ggml_backend_dev_count() {
 		for i := range C.ggml_backend_dev_count() {
 			if !yield(C.ggml_backend_dev_get(i)) {
 			if !yield(C.ggml_backend_dev_get(i)) {
 				return
 				return
@@ -146,8 +148,15 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 
 
 	slog.Info("max tensors", "max_tensors", maxTensors)
 	slog.Info("max tensors", "max_tensors", maxTensors)
 
 
+	type tensor struct {
+		source *fs.Tensor
+		target string
+	}
+
+	targets := make(map[string][]string)
+
 	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
 	ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
-	createTensor := func(t *fs.Tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
+	createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
 		for _, bt := range bts {
 		for _, bt := range bts {
 			if _, ok := ctxs[bt]; !ok {
 			if _, ok := ctxs[bt]; !ok {
 				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
 				ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
@@ -156,16 +165,23 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 				})
 				})
 			}
 			}
 
 
-			cname := C.CString(t.Name)
+			targets[t.source.Name] = append(targets[t.source.Name], t.target)
+
+			name := t.source.Name
+			if t.target != "" {
+				name = t.target
+			}
+
+			cname := C.CString(name)
 			defer C.free(unsafe.Pointer(cname))
 			defer C.free(unsafe.Pointer(cname))
 			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
 			if tt := C.ggml_get_tensor(ctxs[bt], cname); tt != nil {
 				return tt
 				return tt
 			}
 			}
 
 
-			tt := C.ggml_new_tensor(ctxs[bt], t.Kind, C.int(len(t.Shape)), (*C.int64_t)(unsafe.Pointer(&t.Shape[0])))
+			tt := C.ggml_new_tensor(ctxs[bt], t.source.Kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
 			C.ggml_set_name(tt, cname)
 			C.ggml_set_name(tt, cname)
 
 
-			slog.Debug("created tensor", "name", t.Name, "shape", t.Shape, "dtype", t.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
+			slog.Debug("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
 			//nolint:staticcheck // TODO: check if buffer type supports this tensor
 			//nolint:staticcheck // TODO: check if buffer type supports this tensor
 			return tt
 			return tt
 		}
 		}
@@ -187,9 +203,9 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	for _, t := range meta.Tensors().Items() {
 	for _, t := range meta.Tensors().Items() {
 		switch {
 		switch {
 		case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
 		case hasPart(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
-			createTensor(t, input.bts)
+			createTensor(tensor{source: t}, input.bts)
 		case hasPart(t.Name, "cls", "output", "output_norm"):
 		case hasPart(t.Name, "cls", "output", "output_norm"):
-			createTensor(t, output.bts)
+			createTensor(tensor{source: t}, output.bts)
 		default:
 		default:
 			if i := func() int {
 			if i := func() int {
 				if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
 				if fields := strings.FieldsFunc(t.Name, func(r rune) bool { return !unicode.IsNumber(r) }); len(fields) > 0 {
@@ -200,10 +216,13 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 
 
 				return -1
 				return -1
 			}(); i >= 0 {
 			}(); i >= 0 {
-				createTensor(t, layers[i].bts)
+				createTensor(tensor{source: t}, layers[i].bts)
 			} else {
 			} else {
-				for _, layer := range layers {
-					createTensor(t, layer.bts)
+				for i, layer := range layers {
+					createTensor(tensor{
+						source: t,
+						target: "blk." + strconv.Itoa(i) + "." + t.Name,
+					}, layer.bts)
 				}
 				}
 			}
 			}
 		}
 		}
@@ -237,28 +256,34 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
 	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
 	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
 	var g errgroup.Group
 	var g errgroup.Group
 	for _, t := range meta.Tensors().Items() {
 	for _, t := range meta.Tensors().Items() {
-		g.Go(func() error {
-			tt, ok := tensors[t.Name]
-			if !ok {
-				return fmt.Errorf("unassigned tensor: %s", t.Name)
-			}
+		for _, target := range targets[t.Name] {
+			g.Go(func() error {
+				if target == "" {
+					target = t.Name
+				}
 
 
-			bts := make([]byte, t.Size())
-			n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
-			if err != nil {
-				return err
-			}
+				tt, ok := tensors[target]
+				if !ok {
+					return fmt.Errorf("unassigned tensor: %s", t.Name)
+				}
 
 
-			if n != len(bts) {
-				return errors.New("short read")
-			}
+				bts := make([]byte, t.Size())
+				n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
+				if err != nil {
+					return err
+				}
 
 
-			cname := C.CString(t.Name)
-			C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
-			C.free(unsafe.Pointer(cname))
+				if n != len(bts) {
+					return errors.New("short read")
+				}
 
 
-			return nil
-		})
+				cname := C.CString(t.Name)
+				C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
+				C.free(unsafe.Pointer(cname))
+
+				return nil
+			})
+		}
 	}
 	}
 
 
 	if g.Wait() != nil {
 	if g.Wait() != nil {

+ 5 - 14
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp

@@ -207,27 +207,18 @@ struct ggml_backend_registry {
         for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
         for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
             register_device(ggml_backend_reg_dev_get(reg, i), score);
             register_device(ggml_backend_reg_dev_get(reg, i), score);
         }
         }
-
-        std::stable_sort(devices.begin(), devices.end(),
-            [](const auto & a, const auto & b) {
-                return a.second > b.second;
-            }
-        );
     }
     }
 
 
     void register_device(ggml_backend_dev_t device, int score = -1) {
     void register_device(ggml_backend_dev_t device, int score = -1) {
-        switch (ggml_backend_dev_type(device)) {
-        case GGML_BACKEND_DEVICE_TYPE_CPU:
-        case GGML_BACKEND_DEVICE_TYPE_GPU:
-            score += 1 << 16;
-        case GGML_BACKEND_DEVICE_TYPE_ACCEL:
-            score += 1 << 20;
-        }
-
 #ifndef NDEBUG
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
         GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
 #endif
 #endif
         devices.push_back({device, score});
         devices.push_back({device, score});
+        std::stable_sort(devices.begin(), devices.end(),
+            [](const auto & a, const auto & b) {
+                return a.second > b.second;
+            }
+        );
     }
     }
 
 
     ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {
     ggml_backend_reg_t load_backend(const std::filesystem::path & path, bool silent) {