Kaynağa Gözat

Merge pull request #4815 from michaelpoluektov/fix-user-valves

fix: Fix user valves
Timothy Jaeryang Baek 8 ay önce
ebeveyn
işleme
99db82a161
1 değiştirilmiş dosya ile 35 ekleme ve 32 silme
  1. 35 32
      backend/apps/webui/main.py

+ 35 - 32
backend/apps/webui/main.py

@@ -56,12 +56,15 @@ from apps.socket.main import get_event_call, get_event_emitter
 
 import inspect
 import json
+import logging
 
 from typing import Iterator, Generator, AsyncGenerator
 from pydantic import BaseModel
 
 app = FastAPI()
 
+log = logging.getLogger(__name__)
+
 app.state.config = AppConfig()
 
 app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
@@ -243,44 +246,37 @@ def get_pipe_id(form_data: dict) -> str:
     return pipe_id
 
 
-def get_function_params(function_module, form_data, user, extra_params={}):
+def get_function_params(function_module, form_data, user, extra_params=None):
+    if extra_params is None:
+        extra_params = {}
+
     pipe_id = get_pipe_id(form_data)
+
     # Get the signature of the function
     sig = inspect.signature(function_module.pipe)
-    params = {"body": form_data}
-
-    for key, value in extra_params.items():
-        if key in sig.parameters:
-            params[key] = value
-
-    if "__user__" in sig.parameters:
-        __user__ = {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-        }
+    params = {"body": form_data} | {
+        k: v for k, v in extra_params.items() if k in sig.parameters
+    }
 
+    if "__user__" in params and hasattr(function_module, "UserValves"):
+        user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
         try:
-            if hasattr(function_module, "UserValves"):
-                __user__["valves"] = function_module.UserValves(
-                    **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
-                )
+            params["__user__"]["valves"] = function_module.UserValves(**user_valves)
         except Exception as e:
-            print(e)
+            log.exception(e)
+            params["__user__"]["valves"] = function_module.UserValves()
 
-        params["__user__"] = __user__
     return params
 
 
 async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
+
     metadata = form_data.pop("metadata", {})
 
     files = metadata.get("files", [])
     tool_ids = metadata.get("tool_ids", [])
-
     # Check if tool_ids is None
     if tool_ids is None:
         tool_ids = []
@@ -299,18 +295,25 @@ async def generate_function_chat_completion(form_data, user):
         "__event_emitter__": __event_emitter__,
         "__event_call__": __event_call__,
         "__task__": __task__,
-    }
-
-    extra_params["__tools__"] = get_tools(
-        app,
-        tool_ids,
-        user,
-        {
-            **extra_params,
-            "__model__": app.state.MODELS[form_data["model"]],
-            "__messages__": form_data["messages"],
-            "__files__": files,
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
         },
+    }
+    extra_params["__tools__"] = (
+        get_tools(
+            app,
+            tool_ids,
+            user,
+            {
+                **extra_params,
+                "__model__": app.state.MODELS[form_data["model"]],
+                "__messages__": form_data["messages"],
+                "__files__": files,
+            },
+        ),
     )
 
     if model_info: