Timothy J. Baek 10 月之前
父節點
當前提交
646832ba8c
共有 1 個文件被更改,包括 50 次插入15 次删除
  1. 50 15
      backend/main.py

+ 50 - 15
backend/main.py

@@ -278,8 +278,16 @@ async def get_function_call_response(
                                 "email": user.email,
                                 "name": user.name,
                                 "role": user.role,
-                                "valves": Tools.get_user_valves_by_id_and_user_id(
-                                    tool_id, user.id
+                                **(
+                                    {
+                                        "valves": toolkit_module.UserValves(
+                                            Tools.get_user_valves_by_id_and_user_id(
+                                                tool_id, user.id
+                                            )
+                                        )
+                                    }
+                                    if hasattr(toolkit_module, "UserValves")
+                                    else {}
                                 ),
                             },
                         }
@@ -404,8 +412,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                                             "email": user.email,
                                             "name": user.name,
                                             "role": user.role,
-                                            "valves": Functions.get_user_valves_by_id_and_user_id(
-                                                filter_id, user.id
+                                            **(
+                                                {
+                                                    "valves": function_module.UserValves(
+                                                        Functions.get_user_valves_by_id_and_user_id(
+                                                            filter_id, user.id
+                                                        )
+                                                    )
+                                                }
+                                                if hasattr(
+                                                    function_module, "UserValves"
+                                                )
+                                                else {}
                                             ),
                                         },
                                     }
@@ -850,12 +868,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 
     pipe = model.get("pipe")
     if pipe:
-        form_data["user"] = {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-        }
 
         async def job():
             pipe_id = form_data["model"]
@@ -863,7 +875,14 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
                 pipe_id, sub_pipe_id = pipe_id.split(".", 1)
             print(pipe_id)
 
-            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            # Check if function is already loaded
+            if pipe_id not in app.state.FUNCTIONS:
+                function_module, function_type = load_function_module_by_id(pipe_id)
+                app.state.FUNCTIONS[pipe_id] = function_module
+            else:
+                function_module = app.state.FUNCTIONS[pipe_id]
+
+            pipe = function_module.pipe
 
             # Get the signature of the function
             sig = inspect.signature(pipe)
@@ -877,8 +896,16 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
                         "email": user.email,
                         "name": user.name,
                         "role": user.role,
-                        "valves": Functions.get_user_valves_by_id_and_user_id(
-                            pipe_id, user.id
+                        **(
+                            {
+                                "valves": pipe.UserValves(
+                                    Functions.get_user_valves_by_id_and_user_id(
+                                        pipe_id, user.id
+                                    )
+                                )
+                            }
+                            if hasattr(function_module, "UserValves")
+                            else {}
                         ),
                     },
                 }
@@ -1079,8 +1106,16 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                                     "email": user.email,
                                     "name": user.name,
                                     "role": user.role,
-                                    "valves": Functions.get_user_valves_by_id_and_user_id(
-                                        filter_id, user.id
+                                    **(
+                                        {
+                                            "valves": function_module.UserValves(
+                                                Functions.get_user_valves_by_id_and_user_id(
+                                                    filter_id, user.id
+                                                )
+                                            )
+                                        }
+                                        if hasattr(function_module, "UserValves")
+                                        else {}
                                     ),
                                 },
                             }