Browse Source

feat: async filter support

Timothy J. Baek 10 months ago
parent
commit
5621025c12
1 changed files with 45 additions and 18 deletions
  1. 45 18
      backend/main.py

+ 45 - 18
backend/main.py

@@ -384,15 +384,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
                         try:
                         try:
                             if hasattr(function_module, "inlet"):
                             if hasattr(function_module, "inlet"):
-                                data = function_module.inlet(
-                                    data,
-                                    {
-                                        "id": user.id,
-                                        "email": user.email,
-                                        "name": user.name,
-                                        "role": user.role,
-                                    },
-                                )
+                                inlet = function_module.inlet
+
+                                if inspect.iscoroutinefunction(inlet):
+                                    data = await inlet(
+                                        data,
+                                        {
+                                            "id": user.id,
+                                            "email": user.email,
+                                            "name": user.name,
+                                            "role": user.role,
+                                        },
+                                    )
+                                else:
+                                    data = inlet(
+                                        data,
+                                        {
+                                            "id": user.id,
+                                            "email": user.email,
+                                            "name": user.name,
+                                            "role": user.role,
+                                        },
+                                    )
+
                         except Exception as e:
                         except Exception as e:
                             print(f"Error: {e}")
                             print(f"Error: {e}")
                             return JSONResponse(
                             return JSONResponse(
@@ -1007,15 +1021,28 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
 
 
                 try:
                 try:
                     if hasattr(function_module, "outlet"):
                     if hasattr(function_module, "outlet"):
-                        data = function_module.outlet(
-                            data,
-                            {
-                                "id": user.id,
-                                "email": user.email,
-                                "name": user.name,
-                                "role": user.role,
-                            },
-                        )
+                        outlet = function_module.outlet
+                        if inspect.iscoroutinefunction(outlet):
+                            data = await outlet(
+                                data,
+                                {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            )
+                        else:
+                            data = outlet(
+                                data,
+                                {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            )
+
                 except Exception as e:
                 except Exception as e:
                     print(f"Error: {e}")
                     print(f"Error: {e}")
                     return JSONResponse(
                     return JSONResponse(