浏览代码

fix: harden backend loading (#9024)

* wrap ggml_backend_load_best in try/catch
* ignore non-ollama paths
Michael Yang 2 月之前
父节点
当前提交
49df03da9a

+ 69 - 0
llama/patches/0017-try-catch-backend-load.patch

@@ -0,0 +1,69 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Michael Yang <mxyng@pm.me>
+Date: Tue, 11 Feb 2025 14:06:36 -0800
+Subject: [PATCH] try/catch backend load
+
+---
+ ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
+ 1 file changed, 23 insertions(+), 22 deletions(-)
+
+diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
+index ac5cda07..374c3b21 100644
+--- a/ggml/src/ggml-backend-reg.cpp
++++ b/ggml/src/ggml-backend-reg.cpp
+@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+         }
+         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
+         for (const auto & entry : dir_it) {
+-            if (entry.is_regular_file()) {
+-                std::wstring filename = entry.path().filename().wstring();
+-                std::wstring ext = entry.path().extension().wstring();
+-                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+-                    dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+-                    if (!handle && !silent) {
+-                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+-                    }
+-                    if (handle) {
++            try {
++                if (entry.is_regular_file()) {
++                    std::wstring filename = entry.path().filename().wstring();
++                    std::wstring ext = entry.path().extension().wstring();
++                    if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
++                        dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
++                        if (!handle) {
++                            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
++                            continue;
++                        }
++
+                         auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+-                        if (score_fn) {
+-                            int s = score_fn();
+-#ifndef NDEBUG
+-                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+-#endif
+-                            if (s > best_score) {
+-                                best_score = s;
+-                                best_path = entry.path().wstring();
+-                            }
+-                        } else {
+-                            if (!silent) {
+-                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+-                            }
++                        if (!score_fn) {
++                            GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
++                            continue;
++                        }
++
++                        int s = score_fn();
++                        GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
++                        if (s > best_score) {
++                            best_score = s;
++                            best_path = entry.path().wstring();
+                         }
+                     }
+                 }
++            } catch (const std::exception & e) {
++                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+             }
+         }
+     }

+ 23 - 22
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp

@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
         for (const auto & entry : dir_it) {
-            if (entry.is_regular_file()) {
-                std::wstring filename = entry.path().filename().wstring();
-                std::wstring ext = entry.path().extension().wstring();
-                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-                    dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
-                    if (!handle && !silent) {
-                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-                    }
-                    if (handle) {
+            try {
+                if (entry.is_regular_file()) {
+                    std::wstring filename = entry.path().filename().wstring();
+                    std::wstring ext = entry.path().extension().wstring();
+                    if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+                        dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+                        if (!handle) {
+                            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+                            continue;
+                        }
+
                         auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-                        if (score_fn) {
-                            int s = score_fn();
-#ifndef NDEBUG
-                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
-                            if (s > best_score) {
-                                best_score = s;
-                                best_path = entry.path().wstring();
-                            }
-                        } else {
-                            if (!silent) {
-                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-                            }
+                        if (!score_fn) {
+                            GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+                            continue;
+                        }
+
+                        int s = score_fn();
+                        GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+                        if (s > best_score) {
+                            best_score = s;
+                            best_path = entry.path().wstring();
                         }
                     }
                 }
+            } catch (const std::exception & e) {
+                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
             }
         }
     }

+ 5 - 0
ml/backend/ggml/ggml/src/ggml.go

@@ -79,6 +79,11 @@ var OnceLoad = sync.OnceFunc(func() {
 			continue
 		}
 
+		if abspath != filepath.Dir(exe) && !strings.Contains(abspath, filepath.FromSlash("lib/ollama")) {
+			slog.Debug("skipping path which is not part of ollama", "path", abspath)
+			continue
+		}
+
 		if _, ok := visited[abspath]; !ok {
 			func() {
 				slog.Debug("ggml backend load all from path", "path", abspath)