Timothy J. Baek 8 月之前
父節點
當前提交
63ba8145b9
共有 2 個文件被更改,包括 40 次插入37 次删除
  1. 7 7
      backend/apps/webui/routers/files.py
  2. 33 30
      backend/main.py

+ 7 - 7
backend/apps/webui/routers/files.py

@@ -26,7 +26,7 @@ from apps.webui.models.files import (
     FileModel,
     FileModelResponse,
 )
-from utils.utils import get_verified_user, get_admin_user
+from utils.utils import get_current_user, get_admin_user
 from constants import ERROR_MESSAGES
 
 from importlib import util
@@ -50,7 +50,7 @@ router = APIRouter()
 
 
 @router.post("/")
-def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
+def upload_file(file: UploadFile = File(...), user=Depends(get_current_user)):
     log.info(f"file.content_type: {file.content_type}")
     try:
         unsanitized_filename = file.filename
@@ -105,7 +105,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
 
 
 @router.get("/", response_model=list[FileModel])
-async def list_files(user=Depends(get_verified_user)):
+async def list_files(user=Depends(get_current_user)):
     files = Files.get_files()
     return files
 
@@ -153,7 +153,7 @@ async def delete_all_files(user=Depends(get_admin_user)):
 
 
 @router.get("/{id}", response_model=Optional[FileModel])
-async def get_file_by_id(id: str, user=Depends(get_verified_user)):
+async def get_file_by_id(id: str, user=Depends(get_current_user)):
     file = Files.get_file_by_id(id)
 
     if file:
@@ -171,7 +171,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
 
 
 @router.get("/{id}/content", response_model=Optional[FileModel])
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+async def get_file_content_by_id(id: str, user=Depends(get_current_user)):
     file = Files.get_file_by_id(id)
 
     if file:
@@ -194,7 +194,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
 @router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+async def get_file_content_by_id(id: str, user=Depends(get_current_user)):
     file = Files.get_file_by_id(id)
 
     if file:
@@ -222,7 +222,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
 @router.delete("/{id}")
-async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
+async def delete_file_by_id(id: str, user=Depends(get_current_user)):
     file = Files.get_file_by_id(id)
 
     if file:

+ 33 - 30
backend/main.py

@@ -299,24 +299,26 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
 
             # Get the signature of the function
             sig = inspect.signature(inlet)
-            params = {"body": body}
+            params = {"body": body} | {
+                k: v
+                for k, v in {
+                    **extra_params,
+                    "__model__": model,
+                    "__id__": filter_id,
+                }.items()
+                if k in sig.parameters
+            }
 
-            # Extra parameters to be passed to the function
-            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
-            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
+            if "__user__" in params and hasattr(function_module, "UserValves"):
                 try:
-                    uid = custom_params["__user__"]["id"]
-                    custom_params["__user__"]["valves"] = function_module.UserValves(
-                        **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                    params["__user__"]["valves"] = function_module.UserValves(
+                        **Functions.get_user_valves_by_id_and_user_id(
+                            filter_id, params["__user__"]["id"]
+                        )
                     )
                 except Exception as e:
                     print(e)
 
-            # Add extra params in contained in function signature
-            for key, value in custom_params.items():
-                if key in sig.parameters:
-                    params[key] = value
-
             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
             else:
@@ -372,7 +374,9 @@ async def chat_completion_tools_handler(
 ) -> tuple[dict, dict]:
     # If tool_ids field is present, call the functions
     metadata = body.get("metadata", {})
+
     tool_ids = metadata.get("tool_ids", None)
+    log.debug(f"{tool_ids=}")
     if not tool_ids:
         return body, {}
 
@@ -381,16 +385,17 @@ async def chat_completion_tools_handler(
     citations = []
 
     task_model_id = get_task_model_id(body["model"])
-
-    log.debug(f"{tool_ids=}")
-
-    custom_params = {
-        **extra_params,
-        "__model__": app.state.MODELS[task_model_id],
-        "__messages__": body["messages"],
-        "__files__": metadata.get("files", []),
-    }
-    tools = get_tools(webui_app, tool_ids, user, custom_params)
+    tools = get_tools(
+        webui_app,
+        tool_ids,
+        user,
+        {
+            **extra_params,
+            "__model__": app.state.MODELS[task_model_id],
+            "__messages__": body["messages"],
+            "__files__": metadata.get("files", []),
+        },
+    )
     log.info(f"{tools=}")
 
     specs = [tool["spec"] for tool in tools.values()]
@@ -530,17 +535,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         }
         body["metadata"] = metadata
 
-        __user__ = {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-        }
-
         extra_params = {
-            "__user__": __user__,
             "__event_emitter__": get_event_emitter(metadata),
             "__event_call__": get_event_call(metadata),
+            "__user__": {
+                "id": user.id,
+                "email": user.email,
+                "name": user.name,
+                "role": user.role,
+            },
         }
 
         # Initialize data_items to store additional data to be sent to the client