Browse Source

ml/backend/ggml: stable sort devices by score (#9081)

Jeffrey Morgan 2 months ago
parent
commit
6600bd7d91

+ 11 - 11
llama/patches/0014-sort-devices-by-score.patch

@@ -8,7 +8,7 @@ Subject: [PATCH] sort devices by score
  1 file changed, 13 insertions(+), 8 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 899d16f2..ac5cda07 100644
+index 899d16f2..135f7df0 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
@@ -29,7 +29,7 @@ index 899d16f2..ac5cda07 100644
          if (!reg) {
              return;
          }
-@@ -206,15 +206,15 @@ struct ggml_backend_registry {
+@@ -206,15 +206,20 @@ struct ggml_backend_registry {
  #endif
          backends.push_back({ reg, std::move(handle) });
          for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
@@ -45,10 +45,15 @@ index 899d16f2..ac5cda07 100644
  #endif
 -        devices.push_back(device);
 +        devices.push_back({device, score});
++        std::stable_sort(devices.begin(), devices.end(),
++            [](const auto & a, const auto & b) {
++                return a.second > b.second;
++            }
++        );
      }
  
      ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
-@@ -257,7 +257,7 @@ struct ggml_backend_registry {
+@@ -257,7 +262,7 @@ struct ggml_backend_registry {
  
          GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
  
@@ -57,7 +62,7 @@ index 899d16f2..ac5cda07 100644
  
          return reg;
      }
-@@ -280,7 +280,7 @@ struct ggml_backend_registry {
+@@ -280,7 +285,7 @@ struct ggml_backend_registry {
          // remove devices
          devices.erase(
              std::remove_if(devices.begin(), devices.end(),
@@ -66,17 +71,12 @@ index 899d16f2..ac5cda07 100644
              devices.end());
  
          // remove backend
-@@ -338,7 +338,12 @@ size_t ggml_backend_dev_count() {
+@@ -338,7 +343,7 @@ size_t ggml_backend_dev_count() {
  
  ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
      GGML_ASSERT(index < ggml_backend_dev_count());
 -    return get_reg().devices[index];
-+    auto devices = get_reg().devices;
-+    if (!std::is_heap(devices.begin(), devices.end())) {
-+        std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
-+    }
-+
-+    return devices[index].first;
++    return get_reg().devices[index].first;
  }
  
  ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {

+ 1 - 1
llama/patches/0017-try-catch-backend-load.patch

@@ -8,7 +8,7 @@ Subject: [PATCH] try/catch backend load
  1 file changed, 23 insertions(+), 22 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index ac5cda07..374c3b21 100644
+index 135f7df0..84b21dd8 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,

+ 6 - 6
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp

@@ -215,6 +215,11 @@ struct ggml_backend_registry {
         GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
 #endif
         devices.push_back({device, score});
+        std::stable_sort(devices.begin(), devices.end(),
+            [](const auto & a, const auto & b) {
+                return a.second > b.second;
+            }
+        );
     }
 
     ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
@@ -338,12 +343,7 @@ size_t ggml_backend_dev_count() {
 
 ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
     GGML_ASSERT(index < ggml_backend_dev_count());
-    auto devices = get_reg().devices;
-    if (!std::is_heap(devices.begin(), devices.end())) {
-        std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
-    }
-
-    return devices[index].first;
+    return get_reg().devices[index].first;
 }
 
 ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {