Ver Fonte

refac: tools & functions

Timothy J. Baek há 8 meses atrás
pai
commit
cf86ba7786

+ 8 - 19
backend/open_webui/apps/webui/routers/functions.py

@@ -8,7 +8,7 @@ from open_webui.apps.webui.models.functions import (
     FunctionResponse,
     Functions,
 )
-from open_webui.apps.webui.utils import load_function_module_by_id
+from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
 from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -55,14 +55,12 @@ async def create_new_function(
 
     function = Functions.get_function_by_id(form_data.id)
     if function is None:
-        function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
         try:
-            with open(function_path, "w") as function_file:
-                function_file.write(form_data.content)
-
             function_module, function_type, frontmatter = load_function_module_by_id(
-                form_data.id
+                form_data.id,
+                content=form_data.content,
             )
+            form_data.content = replace_imports(form_data.content)
             form_data.meta.manifest = frontmatter
 
             FUNCTIONS = request.app.state.FUNCTIONS
@@ -174,13 +172,11 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
 async def update_function_by_id(
     request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
 ):
-    function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
-
     try:
-        with open(function_path, "w") as function_file:
-            function_file.write(form_data.content)
-
-        function_module, function_type, frontmatter = load_function_module_by_id(id)
+        function_module, function_type, frontmatter = load_function_module_by_id(
+            id, content=form_data.content
+        )
+        form_data.content = replace_imports(form_data.content)
         form_data.meta.manifest = frontmatter
 
         FUNCTIONS = request.app.state.FUNCTIONS
@@ -222,13 +218,6 @@ async def delete_function_by_id(
         if id in FUNCTIONS:
             del FUNCTIONS[id]
 
-        # delete the function file
-        function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
-        try:
-            os.remove(function_path)
-        except Exception:
-            pass
-
     return result
 
 

+ 9 - 16
backend/open_webui/apps/webui/routers/tools.py

@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import Optional
 
 from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
-from open_webui.apps.webui.utils import load_toolkit_module_by_id
+from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports
 from open_webui.config import CACHE_DIR, DATA_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -59,12 +59,11 @@ async def create_new_toolkit(
 
     toolkit = Tools.get_tool_by_id(form_data.id)
     if toolkit is None:
-        toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         try:
-            with open(toolkit_path, "w") as tool_file:
-                tool_file.write(form_data.content)
-
-            toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
+            toolkit_module, frontmatter = load_toolkit_module_by_id(
+                form_data.id, content=form_data.content
+            )
+            form_data.content = replace_imports(form_data.content)
             form_data.meta.manifest = frontmatter
 
             TOOLS = request.app.state.TOOLS
@@ -126,13 +125,11 @@ async def update_toolkit_by_id(
     form_data: ToolForm,
     user=Depends(get_admin_user),
 ):
-    toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
-
     try:
-        with open(toolkit_path, "w") as tool_file:
-            tool_file.write(form_data.content)
-
-        toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+        toolkit_module, frontmatter = load_toolkit_module_by_id(
+            id, content=form_data.content
+        )
+        form_data.content = replace_imports(form_data.content)
         form_data.meta.manifest = frontmatter
 
         TOOLS = request.app.state.TOOLS
@@ -177,10 +174,6 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
         if id in TOOLS:
             del TOOLS[id]
 
-        # delete the toolkit file
-        toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
-        os.remove(toolkit_path)
-
     return result
 
 

+ 69 - 49
backend/open_webui/apps/webui/utils.py

@@ -3,6 +3,8 @@ import re
 import subprocess
 import sys
 from importlib import util
+import types
+
 
 from open_webui.apps.webui.models.functions import Functions
 from open_webui.apps.webui.models.tools import Tools
@@ -49,75 +51,92 @@ def extract_frontmatter(file_path):
     return frontmatter
 
 
-def load_toolkit_module_by_id(toolkit_id):
-    toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
+def replace_imports(content):
+    """
+    Replace the import paths in the content.
+    """
+    replacements = {
+        "from utils": "from open_webui.utils",
+        "from apps": "from open_webui.apps",
+        "from main": "from open_webui.main",
+        "from config": "from open_webui.config",
+    }
+
+    for old, new in replacements.items():
+        content = content.replace(old, new)
+
+    return content
 
-    if not os.path.exists(toolkit_path):
+
+def load_toolkit_module_by_id(toolkit_id, content=None):
+    if content is None:
         tool = Tools.get_tool_by_id(toolkit_id)
-        if tool:
-            with open(toolkit_path, "w") as file:
-                content = tool.content
-                content = content.replace("from utils", "from open_webui.utils")
-                content = content.replace("from apps", "from open_webui.apps")
-                content = content.replace("from main", "from open_webui.main")
-                content = content.replace("from config", "from open_webui.config")
-
-                if tool.content != content:
-                    print(f"Replaced imports for: {toolkit_id}")
-                    Tools.update_tool_by_id(toolkit_id, {"content": content})
-
-                file.write(content)
-        else:
+        if not tool:
             raise Exception(f"Toolkit not found: {toolkit_id}")
 
-    spec = util.spec_from_file_location(toolkit_id, toolkit_path)
-    module = util.module_from_spec(spec)
-    frontmatter = extract_frontmatter(toolkit_path)
+        content = tool.content
+
+    content = replace_imports(content)
+    Tools.update_tool_by_id(toolkit_id, {"content": content})
+
+    module_name = f"{toolkit_id}"
+    module = types.ModuleType(module_name)
+    sys.modules[module_name] = module
 
     try:
+        # Executing the modified content in the created module's namespace
+        exec(content, module.__dict__)
+
+        # Extract frontmatter, assuming content can be treated directly as a string
+        frontmatter = extract_frontmatter(
+            content
+        )  # Ensure this method is adaptable to handle content strings
+
+        # Install required packages found within the frontmatter
         install_frontmatter_requirements(frontmatter.get("requirements", ""))
-        spec.loader.exec_module(module)
+
         print(f"Loaded module: {module.__name__}")
+        # Create and return the object if the class 'Tools' is found in the module
         if hasattr(module, "Tools"):
             return module.Tools(), frontmatter
         else:
-            raise Exception("No Tools class found")
+            raise Exception("No Tools class found in the module")
     except Exception as e:
         print(f"Error loading module: {toolkit_id}")
-        # Move the file to the error folder
-        os.rename(toolkit_path, f"{toolkit_path}.error")
+        del sys.modules[module_name]  # Clean up
         raise e
 
 
-def load_function_module_by_id(function_id):
-    function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
-
-    if not os.path.exists(function_path):
+def load_function_module_by_id(function_id, content=None):
+    if content is None:
         function = Functions.get_function_by_id(function_id)
-        if function:
-            with open(function_path, "w") as file:
-                content = function.content
-                content = content.replace("from utils", "from open_webui.utils")
-                content = content.replace("from apps", "from open_webui.apps")
-                content = content.replace("from main", "from open_webui.main")
-                content = content.replace("from config", "from open_webui.config")
-
-                if function.content != content:
-                    print(f"Replaced imports for: {function_id}")
-                    Functions.update_function_by_id(function_id, {"content": content})
-
-                file.write(content)
-        else:
+        if not function:
             raise Exception(f"Function not found: {function_id}")
+        content = function.content
+
+    # Replace the module paths in the function content
+    content = replace_imports(content)
+    Functions.update_function_by_id(function_id, {"content": content})
 
-    spec = util.spec_from_file_location(function_id, function_path)
-    module = util.module_from_spec(spec)
-    frontmatter = extract_frontmatter(function_path)
+    module_name = f"{function_id}"
+    module = types.ModuleType(module_name)
+    sys.modules[module_name] = module
 
     try:
+        # Execute the modified content in the created module's namespace
+        exec(content, module.__dict__)
+
+        # Extract the frontmatter from the content, simulate file-like behaviour
+        frontmatter = extract_frontmatter(
+            content
+        )  # This function needs to handle string inputs
+
+        # Install necessary requirements specified in frontmatter
         install_frontmatter_requirements(frontmatter.get("requirements", ""))
-        spec.loader.exec_module(module)
+
         print(f"Loaded module: {module.__name__}")
+
+        # Create appropriate object based on available class type in the module
         if hasattr(module, "Pipe"):
             return module.Pipe(), "pipe", frontmatter
         elif hasattr(module, "Filter"):
@@ -125,11 +144,12 @@ def load_function_module_by_id(function_id):
         elif hasattr(module, "Action"):
             return module.Action(), "action", frontmatter
         else:
-            raise Exception("No Function class found")
+            raise Exception("No Function class found in the module")
     except Exception as e:
         print(f"Error loading module: {function_id}")
-        # Move the file to the error folder
-        os.rename(function_path, f"{function_path}.error")
+        del sys.modules[module_name]  # Cleanup by removing the module in case of error
+
+        Functions.update_function_by_id(function_id, {"is_active": False})
         raise e