Timothy J. Baek 10 月之前
父節點
當前提交
ae567796ee
共有 2 個文件被更改,包括 42 次插入29 次删除
  1. 3 1
      backend/apps/webui/main.py
  2. 39 28
      backend/main.py

+ 3 - 1
backend/apps/webui/main.py

@@ -130,7 +130,9 @@ async def get_pipe_models():
                     manifold_pipe_name = p["name"]
 
                     if hasattr(function_module, "name"):
-                        manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
+                        manifold_pipe_name = (
+                            f"{function_module.name}{manifold_pipe_name}"
+                        )
 
                     pipe_models.append(
                         {

+ 39 - 28
backend/main.py

@@ -389,26 +389,31 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                             if hasattr(function_module, "inlet"):
                                 inlet = function_module.inlet
 
-                                if inspect.iscoroutinefunction(inlet):
-                                    data = await inlet(
-                                        data,
-                                        {
+                                # Get the signature of the function
+                                sig = inspect.signature(inlet)
+                                param = {"body": data}
+
+                                if "__user__" in sig.parameters:
+                                    param = {
+                                        **param,
+                                        "__user__": {
                                             "id": user.id,
                                             "email": user.email,
                                             "name": user.name,
                                             "role": user.role,
                                         },
-                                    )
+                                    }
+
+                                if "__id__" in sig.parameters:
+                                    param = {
+                                        **param,
+                                        "__id__": filter_id,
+                                    }
+
+                                if inspect.iscoroutinefunction(inlet):
+                                    data = await inlet(**param)
                                 else:
-                                    data = inlet(
-                                        data,
-                                        {
-                                            "id": user.id,
-                                            "email": user.email,
-                                            "name": user.name,
-                                            "role": user.role,
-                                        },
-                                    )
+                                    data = inlet(**param)
 
                         except Exception as e:
                             print(f"Error: {e}")
@@ -1031,26 +1036,32 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                 try:
                     if hasattr(function_module, "outlet"):
                         outlet = function_module.outlet
-                        if inspect.iscoroutinefunction(outlet):
-                            data = await outlet(
-                                data,
-                                {
+
+                        # Get the signature of the function
+                        sig = inspect.signature(outlet)
+                        param = {"body": data}
+
+                        if "__user__" in sig.parameters:
+                            param = {
+                                **param,
+                                "__user__": {
                                     "id": user.id,
                                     "email": user.email,
                                     "name": user.name,
                                     "role": user.role,
                                 },
-                            )
+                            }
+
+                        if "__id__" in sig.parameters:
+                            param = {
+                                **param,
+                                "__id__": filter_id,
+                            }
+
+                        if inspect.iscoroutinefunction(outlet):
+                            data = await outlet(**param)
                         else:
-                            data = outlet(
-                                data,
-                                {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                },
-                            )
+                            data = outlet(**param)
 
                 except Exception as e:
                     print(f"Error: {e}")