Browse Source

refactor get_function_call_response

Michael Poluektov 8 months ago
parent
commit
a68b918cbb
1 changed files with 73 additions and 77 deletions
  1. 73 77
      backend/main.py

+ 73 - 77
backend/main.py

@@ -297,6 +297,30 @@ async def get_content_from_response(response) -> Optional[str]:
     return content
 
 
+async def call_tool_from_completion(
+    result: dict, extra_params: dict, toolkit_module
+) -> Optional[str]:
+    if "name" not in result:
+        return None
+
+    tool = getattr(toolkit_module, result["name"])
+    try:
+        # Get the signature of the function
+        sig = inspect.signature(tool)
+        params = result["parameters"]
+        for key, value in extra_params.items():
+            if key in sig.parameters:
+                params[key] = value
+
+        if inspect.iscoroutinefunction(tool):
+            return await tool(**params)
+        else:
+            return tool(**params)
+    except Exception as e:
+        print(f"Error: {e}")
+        return None
+
+
 async def get_function_call_response(
     messages,
     files,
@@ -306,7 +330,7 @@ async def get_function_call_response(
     user,
     __event_emitter__=None,
     __event_call__=None,
-):
+) -> tuple[Optional[str], Optional[dict], bool]:
     tool = Tools.get_tool_by_id(tool_id)
     if tool is None:
         return None, None, False
@@ -343,7 +367,43 @@ async def get_function_call_response(
     except Exception as e:
         raise e
 
-    model = app.state.MODELS[task_model_id]
+    if tool_id in webui_app.state.TOOLS:
+        toolkit_module = webui_app.state.TOOLS[tool_id]
+    else:
+        toolkit_module, _ = load_toolkit_module_by_id(tool_id)
+        webui_app.state.TOOLS[tool_id] = toolkit_module
+
+    __user__ = {
+        "id": user.id,
+        "email": user.email,
+        "name": user.name,
+        "role": user.role,
+    }
+
+    try:
+        if hasattr(toolkit_module, "UserValves"):
+            __user__["valves"] = toolkit_module.UserValves(
+                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+            )
+
+    except Exception as e:
+        print(e)
+
+    extra_params = {
+        "__model__": app.state.MODELS[task_model_id],
+        "__id__": tool_id,
+        "__messages__": messages,
+        "__files__": files,
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__user__": __user__,
+    }
+
+    file_handler = hasattr(toolkit_module, "file_handler")
+
+    if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
+        valves = Tools.get_tool_valves_by_id(tool_id)
+        toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
 
     try:
         response = await generate_chat_completions(form_data=payload, user=user)
@@ -353,85 +413,21 @@ async def get_function_call_response(
             return None, None, False
 
         # Parse the function response
-        print(f"content: {content}")
+        log.debug(f"content: {content}")
         result = json.loads(content)
-        print(result)
 
-        if "name" not in result:
-            return None, None, False
-
-        # Call the function
-        if tool_id in webui_app.state.TOOLS:
-            toolkit_module = webui_app.state.TOOLS[tool_id]
-        else:
-            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
-            webui_app.state.TOOLS[tool_id] = toolkit_module
-
-        file_handler = False
-        # check if toolkit_module has file_handler self variable
-        if hasattr(toolkit_module, "file_handler"):
-            file_handler = True
-            print("file_handler: ", file_handler)
-
-        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
-            valves = Tools.get_tool_valves_by_id(tool_id)
-            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
-
-        function = getattr(toolkit_module, result["name"])
-        function_result = None
-        citation = None
-        try:
-            # Get the signature of the function
-            sig = inspect.signature(function)
-            params = result["parameters"]
+        function_result = await call_tool_from_completion(
+            result, extra_params, toolkit_module
+        )
 
-            # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": tool_id,
-                "__messages__": messages,
-                "__files__": files,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
+        if hasattr(toolkit_module, "citation") and toolkit_module.citation:
+            citation = {
+                "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
+                "document": [function_result],
+                "metadata": [{"source": result["name"]}],
             }
-
-            # Add extra params in contained in function signature
-            for key, value in extra_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
-            if "__user__" in sig.parameters:
-                # Call the function with the '__user__' parameter included
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(toolkit_module, "UserValves"):
-                        __user__["valves"] = toolkit_module.UserValves(
-                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
-            if inspect.iscoroutinefunction(function):
-                function_result = await function(**params)
-            else:
-                function_result = function(**params)
-
-            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                citation = {
-                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                    "document": [function_result],
-                    "metadata": [{"source": result["name"]}],
-                }
-        except Exception as e:
-            print(e)
+        else:
+            citation = None
 
         # Add the function result to the system prompt
         if function_result is not None: