Timothy J. Baek 10 달 전
부모
커밋
0cf936f9e8
5개의 변경된 파일36개의 추가작업 그리고 8개의 파일을 삭제
  1. 2 2
      backend/apps/webui/models/functions.py
  2. 2 2
      backend/apps/webui/models/tools.py
  3. 2 2
      backend/apps/webui/routers/functions.py
  4. 2 2
      backend/apps/webui/routers/tools.py
  5. 28 0
      backend/main.py

+ 2 - 2
backend/apps/webui/models/functions.py

@@ -143,10 +143,10 @@ class FunctionsTable:
                 for function in Function.select().where(Function.type == type)
             ]
 
-    def get_function_valves_by_id(self, id: str) -> Optional[FunctionValves]:
+    def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         try:
             function = Function.get(Function.id == id)
-            return FunctionValves(**model_to_dict(function))
+            return function.valves if "valves" in function and function.valves else {}
         except Exception as e:
             print(f"An error occurred: {e}")
             return None

+ 2 - 2
backend/apps/webui/models/tools.py

@@ -114,10 +114,10 @@ class ToolsTable:
     def get_tools(self) -> List[ToolModel]:
         return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
 
-    def get_tool_valves_by_id(self, id: str) -> Optional[ToolValves]:
+    def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
             tool = Tool.get(Tool.id == id)
-            return ToolValves(**model_to_dict(tool))
+            return tool.valves if "valves" in tool and tool.valves else {}
         except Exception as e:
             print(f"An error occurred: {e}")
             return None

+ 2 - 2
backend/apps/webui/routers/functions.py

@@ -127,8 +127,8 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
     function = Functions.get_function_by_id(id)
     if function:
         try:
-            function_valves = Functions.get_function_valves_by_id(id)
-            return function_valves.valves
+            valves = Functions.get_function_valves_by_id(id)
+            return valves
         except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,

+ 2 - 2
backend/apps/webui/routers/tools.py

@@ -133,8 +133,8 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
     toolkit = Tools.get_tool_by_id(id)
     if toolkit:
         try:
-            tool_valves = Tools.get_tool_valves_by_id(id)
-            return tool_valves.valves
+            valves = Tools.get_tool_valves_by_id(id)
+            return valves
         except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,

+ 28 - 0
backend/main.py

@@ -262,6 +262,13 @@ async def get_function_call_response(
                     file_handler = True
                     print("file_handler: ", file_handler)
 
+                if hasattr(toolkit_module, "valves") and hasattr(
+                    toolkit_module, "Valves"
+                ):
+                    toolkit_module.valves = toolkit_module.Valves(
+                        **Tools.get_tool_valves_by_id(tool_id)
+                    )
+
                 function = getattr(toolkit_module, result["name"])
                 function_result = None
                 try:
@@ -402,6 +409,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     if hasattr(function_module, "file_handler"):
                         skip_files = function_module.file_handler
 
+                    if hasattr(function_module, "valves") and hasattr(
+                        function_module, "Valves"
+                    ):
+                        function_module.valves = function_module.Valves(
+                            **Functions.get_function_valves_by_id(filter_id)
+                        )
+
                     try:
                         if hasattr(function_module, "inlet"):
                             inlet = function_module.inlet
@@ -884,6 +898,13 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             else:
                 function_module = webui_app.state.FUNCTIONS[pipe_id]
 
+            if hasattr(function_module, "valves") and hasattr(
+                function_module, "Valves"
+            ):
+                function_module.valves = function_module.Valves(
+                    **Functions.get_function_valves_by_id(pipe_id)
+                )
+
             pipe = function_module.pipe
 
             # Get the signature of the function
@@ -1105,6 +1126,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                     )
                     webui_app.state.FUNCTIONS[filter_id] = function_module
 
+                if hasattr(function_module, "valves") and hasattr(
+                    function_module, "Valves"
+                ):
+                    function_module.valves = function_module.Valves(
+                        **Functions.get_function_valves_by_id(filter_id)
+                    )
+
                 try:
                     if hasattr(function_module, "outlet"):
                         outlet = function_module.outlet