12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
- From: Michael Yang <mxyng@pm.me>
- Date: Tue, 14 Jan 2025 12:01:24 -0800
- Subject: [PATCH] sort devices by score
- ---
- ggml/src/ggml-backend-reg.cpp | 21 +++++++++++++--------
- 1 file changed, 13 insertions(+), 8 deletions(-)
- diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
- index 95036ef8..98d5e14d 100644
- --- a/ggml/src/ggml-backend-reg.cpp
- +++ b/ggml/src/ggml-backend-reg.cpp
- @@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
-
- struct ggml_backend_registry {
- std::vector<ggml_backend_reg_entry> backends;
- - std::vector<ggml_backend_dev_t> devices;
- + std::vector<std::pair<ggml_backend_dev_t, int>> devices;
-
- ggml_backend_registry() {
- #ifdef GGML_USE_CUDA
- @@ -195,7 +195,7 @@ struct ggml_backend_registry {
- }
- }
-
- - void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {
- + void register_backend(ggml_backend_reg_t reg, int score = -1, dl_handle_ptr handle = nullptr) {
- if (!reg) {
- return;
- }
- @@ -206,15 +206,20 @@ struct ggml_backend_registry {
- #endif
- backends.push_back({ reg, std::move(handle) });
- for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
- - register_device(ggml_backend_reg_dev_get(reg, i));
- + register_device(ggml_backend_reg_dev_get(reg, i), score);
- }
- }
-
- - void register_device(ggml_backend_dev_t device) {
- + void register_device(ggml_backend_dev_t device, int score = -1) {
- #ifndef NDEBUG
- GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
- #endif
- - devices.push_back(device);
- + devices.push_back({device, score});
- + std::stable_sort(devices.begin(), devices.end(),
- + [](const auto & a, const auto & b) {
- + return a.second > b.second;
- + }
- + );
- }
-
- ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
- @@ -257,7 +262,7 @@ struct ggml_backend_registry {
-
- GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
-
- - register_backend(reg, std::move(handle));
- + register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
-
- return reg;
- }
- @@ -280,7 +285,7 @@ struct ggml_backend_registry {
- // remove devices
- devices.erase(
- std::remove_if(devices.begin(), devices.end(),
- - [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),
- + [reg](std::pair<ggml_backend_dev_t, int> dev) { return ggml_backend_dev_backend_reg(dev.first) == reg; }),
- devices.end());
-
- // remove backend
- @@ -338,7 +343,7 @@ size_t ggml_backend_dev_count() {
-
- ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
- GGML_ASSERT(index < ggml_backend_dev_count());
- - return get_reg().devices[index];
- + return get_reg().devices[index].first;
- }
-
- ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
|