浏览代码

fix more LSP errors

Michael Poluektov 8 月之前
父节点
当前提交
ff9d899f9c
共有 1 个文件被更改,包括 51 次插入79 次删除
  1. 51 79
      backend/main.py

+ 51 - 79
backend/main.py

@@ -261,6 +261,7 @@ def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
 
@@ -322,14 +323,7 @@ async def call_tool_from_completion(
 
 
 async def get_function_call_response(
-    messages,
-    files,
-    tool_id,
-    template,
-    task_model_id,
-    user,
-    __event_emitter__=None,
-    __event_call__=None,
+    messages, files, tool_id, template, task_model_id, user, extra_params
 ) -> tuple[Optional[str], Optional[dict], bool]:
     tool = Tools.get_tool_by_id(tool_id)
     if tool is None:
@@ -373,32 +367,22 @@ async def get_function_call_response(
         toolkit_module, _ = load_toolkit_module_by_id(tool_id)
         webui_app.state.TOOLS[tool_id] = toolkit_module
 
-    __user__ = {
-        "id": user.id,
-        "email": user.email,
-        "name": user.name,
-        "role": user.role,
+    custom_params = {
+        **extra_params,
+        "__model__": app.state.MODELS[task_model_id],
+        "__id__": tool_id,
+        "__messages__": messages,
+        "__files__": files,
     }
-
     try:
         if hasattr(toolkit_module, "UserValves"):
-            __user__["valves"] = toolkit_module.UserValves(
+            custom_params["__user__"]["valves"] = toolkit_module.UserValves(
                 **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
             )
 
     except Exception as e:
         print(e)
 
-    extra_params = {
-        "__model__": app.state.MODELS[task_model_id],
-        "__id__": tool_id,
-        "__messages__": messages,
-        "__files__": files,
-        "__event_emitter__": __event_emitter__,
-        "__event_call__": __event_call__,
-        "__user__": __user__,
-    }
-
     file_handler = hasattr(toolkit_module, "file_handler")
 
     if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
@@ -417,7 +401,7 @@ async def get_function_call_response(
         result = json.loads(content)
 
         function_result = await call_tool_from_completion(
-            result, extra_params, toolkit_module
+            result, custom_params, toolkit_module
         )
 
         if hasattr(toolkit_module, "citation") and toolkit_module.citation:
@@ -438,9 +422,7 @@ async def get_function_call_response(
     return None, None, False
 
 
-async def chat_completion_inlets_handler(
-    body, model, user, __event_emitter__, __event_call__
-):
+async def chat_completion_inlets_handler(body, model, extra_params):
     skip_files = None
 
     filter_ids = get_filter_function_ids(model)
@@ -476,38 +458,18 @@ async def chat_completion_inlets_handler(
             params = {"body": body}
 
             # Extra parameters to be passed to the function
-            extra_params = {
-                "__model__": model,
-                "__id__": filter_id,
-                "__event_emitter__": __event_emitter__,
-                "__event_call__": __event_call__,
-            }
+            custom_params = {**extra_params, "__model__": model, "__id__": filter_id}
+            if hasattr(function_module, "UserValves") and "__user__" in sig.parameters:
+                uid = custom_params["__user__"]["id"]
+                custom_params["__user__"]["valves"] = function_module.UserValves(
+                    **Functions.get_user_valves_by_id_and_user_id(filter_id, uid)
+                )
 
             # Add extra params in contained in function signature
-            for key, value in extra_params.items():
+            for key, value in custom_params.items():
                 if key in sig.parameters:
                     params[key] = value
 
-            if "__user__" in sig.parameters:
-                __user__ = {
-                    "id": user.id,
-                    "email": user.email,
-                    "name": user.name,
-                    "role": user.role,
-                }
-
-                try:
-                    if hasattr(function_module, "UserValves"):
-                        __user__["valves"] = function_module.UserValves(
-                            **Functions.get_user_valves_by_id_and_user_id(
-                                filter_id, user.id
-                            )
-                        )
-                except Exception as e:
-                    print(e)
-
-                params = {**params, "__user__": __user__}
-
             if inspect.iscoroutinefunction(inlet):
                 body = await inlet(**params)
             else:
@@ -524,7 +486,7 @@ async def chat_completion_inlets_handler(
     return body, {}
 
 
-async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
+async def chat_completion_tools_handler(body, user, extra_params):
     skip_files = None
 
     contexts = []
@@ -547,8 +509,7 @@ async def chat_completion_tools_handler(body, user, __event_emitter__, __event_c
                 template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                 task_model_id=task_model_id,
                 user=user,
-                __event_emitter__=__event_emitter__,
-                __event_call__=__event_call__,
+                extra_params=extra_params,
             )
 
             print(file_handler)
@@ -584,10 +545,7 @@ async def chat_completion_files_handler(body):
     contexts = []
     citations = None
 
-    if "files" in body:
-        files = body["files"]
-        del body["files"]
-
+    if files := body.pop("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
@@ -634,8 +592,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             "valves": body.pop("valves", None),
         }
 
-        __event_emitter__ = get_event_emitter(metadata)
-        __event_call__ = get_event_call(metadata)
+        __user__ = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        extra_params = {
+            "__user__": __user__,
+            "__event_emitter__": get_event_emitter(metadata),
+            "__event_call__": get_event_call(metadata),
+        }
 
         # Initialize data_items to store additional data to be sent to the client
         # Initalize contexts and citation
@@ -645,7 +613,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
         try:
             body, flags = await chat_completion_inlets_handler(
-                body, model, user, __event_emitter__, __event_call__
+                body, model, extra_params
             )
         except Exception as e:
             return JSONResponse(
@@ -654,10 +622,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             )
 
         try:
-            body, flags = await chat_completion_tools_handler(
-                body, user, __event_emitter__, __event_call__
-            )
-
+            body, flags = await chat_completion_tools_handler(body, user, extra_params)
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -666,7 +631,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
         try:
             body, flags = await chat_completion_files_handler(body)
-
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -713,7 +677,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         response = await call_next(request)
         if isinstance(response, StreamingResponse):
             # If it's a streaming response, inject it as SSE event or NDJSON line
-            content_type = response.headers.get("Content-Type")
+            content_type = response.headers["Content-Type"]
             if "text/event-stream" in content_type:
                 return StreamingResponse(
                     self.openai_stream_wrapper(response.body_iterator, data_items),
@@ -832,7 +796,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
 
         user = get_current_user(
             request,
-            get_http_authorization_cred(request.headers.get("Authorization")),
+            get_http_authorization_cred(request.headers["Authorization"]),
         )
 
         try:
@@ -1015,6 +979,8 @@ async def get_all_models():
         model["actions"] = []
         for action_id in action_ids:
             action = Functions.get_function_by_id(action_id)
+            if action is None:
+                raise Exception(f"Action not found: {action_id}")
 
             if action_id in webui_app.state.FUNCTIONS:
                 function_module = webui_app.state.FUNCTIONS[action_id]
@@ -1022,6 +988,10 @@ async def get_all_models():
                 function_module, _, _ = load_function_module_by_id(action_id)
                 webui_app.state.FUNCTIONS[action_id] = function_module
 
+            icon_url = None
+            if action.meta.manifest is not None:
+                icon_url = action.meta.manifest.get("icon_url", None)
+
             if hasattr(function_module, "actions"):
                 actions = function_module.actions
                 model["actions"].extend(
@@ -1032,9 +1002,7 @@ async def get_all_models():
                                 "name", f"{action.name} ({_action['id']})"
                             ),
                             "description": action.meta.description,
-                            "icon_url": _action.get(
-                                "icon_url", action.meta.manifest.get("icon_url", None)
-                            ),
+                            "icon_url": _action.get("icon_url", icon_url),
                         }
                         for _action in actions
                     ]
@@ -1045,7 +1013,7 @@ async def get_all_models():
                         "id": action_id,
                         "name": action.name,
                         "description": action.meta.description,
-                        "icon_url": action.meta.manifest.get("icon_url", None),
+                        "icon_url": icon_url,
                     }
                 )
 
@@ -1175,6 +1143,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
+            # TODO: Fix FunctionModel to include vavles
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0
 
@@ -1631,7 +1600,7 @@ async def upload_pipeline(
 ):
     print("upload_pipeline", urlIdx, file.filename)
     # Check if the uploaded file is a python file
-    if not file.filename.endswith(".py"):
+    if not (file.filename and file.filename.endswith(".py")):
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Only Python (.py) files are allowed.",
@@ -2080,7 +2049,10 @@ async def oauth_login(provider: str, request: Request):
     redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
         "oauth_callback", provider=provider
     )
-    return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+    client = oauth.create_client(provider)
+    if client is None:
+        raise HTTPException(404)
+    return await client.authorize_redirect(request, redirect_uri)
 
 
 # OAuth login logic is as follows: