Timothy J. Baek 10 ماه پیش
والد
کامیت
a8a451344c
1فایلهای تغییر یافته به همراه20 افزوده شده و 5 حذف شده
  1. 20 5
      backend/main.py

+ 20 - 5
backend/main.py

@@ -858,13 +858,28 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             print(pipe_id)
             print(pipe_id)
 
 
             pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
             pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
-            if form_data["stream"]:
 
 
+            # Get the signature of the function
+            sig = inspect.signature(pipe)
+            param = {"body": form_data}
+
+            if "__user__" in sig.parameters:
+                param = {
+                    **param,
+                    "__user__": {
+                        "id": user.id,
+                        "email": user.email,
+                        "name": user.name,
+                        "role": user.role,
+                    },
+                }
+
+            if form_data["stream"]:
                 async def stream_content():
                 async def stream_content():
                     if inspect.iscoroutinefunction(pipe):
                     if inspect.iscoroutinefunction(pipe):
-                        res = await pipe(body=form_data)
+                        res = await pipe(**param)
                     else:
                     else:
-                        res = pipe(body=form_data)
+                        res = pipe(**param)
 
 
                     if isinstance(res, str):
                     if isinstance(res, str):
                         message = stream_message_template(form_data["model"], res)
                         message = stream_message_template(form_data["model"], res)
@@ -910,9 +925,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
                 )
                 )
             else:
             else:
                 if inspect.iscoroutinefunction(pipe):
                 if inspect.iscoroutinefunction(pipe):
-                    res = await pipe(body=form_data)
+                    res = await pipe(**param)
                 else:
                 else:
-                    res = pipe(body=form_data)
+                    res = pipe(**param)
 
 
                 if isinstance(res, dict):
                 if isinstance(res, dict):
                     return res
                     return res