浏览代码

refac: backend/main.py

Michael Poluektov 10 月之前
父节点
当前提交
e3e02e04e8
共有 1 个文件被更改,包括 214 次插入280 次删除
  1. 214 280
      backend/main.py

+ 214 - 280
backend/main.py

@@ -1,13 +1,10 @@
 import base64
 import base64
 import uuid
 import uuid
-import subprocess
 from contextlib import asynccontextmanager
 from contextlib import asynccontextmanager
 
 
 from authlib.integrations.starlette_client import OAuth
 from authlib.integrations.starlette_client import OAuth
 from authlib.oidc.core import UserInfo
 from authlib.oidc.core import UserInfo
-from bs4 import BeautifulSoup
 import json
 import json
-import markdown
 import time
 import time
 import os
 import os
 import sys
 import sys
@@ -19,14 +16,11 @@ import shutil
 import os
 import os
 import uuid
 import uuid
 import inspect
 import inspect
-import asyncio
 
 
-from fastapi.concurrency import run_in_threadpool
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import JSONResponse
 from fastapi.responses import JSONResponse
 from fastapi import HTTPException
 from fastapi import HTTPException
-from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from sqlalchemy import text
 from sqlalchemy import text
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.exceptions import HTTPException as StarletteHTTPException
@@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
 from apps.socket.main import sio, app as socket_app
 from apps.socket.main import sio, app as socket_app
 from apps.ollama.main import (
 from apps.ollama.main import (
     app as ollama_app,
     app as ollama_app,
-    OpenAIChatCompletionForm,
     get_all_models as get_ollama_models,
     get_all_models as get_ollama_models,
     generate_openai_chat_completion as generate_ollama_chat_completion,
     generate_openai_chat_completion as generate_ollama_chat_completion,
 )
 )
@@ -56,14 +49,14 @@ from apps.webui.main import (
     get_pipe_models,
     get_pipe_models,
     generate_function_chat_completion,
     generate_function_chat_completion,
 )
 )
-from apps.webui.internal.db import Session, SessionLocal
+from apps.webui.internal.db import Session
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import List, Optional, Iterator, Generator, Union
+from typing import List, Optional
 
 
 from apps.webui.models.auths import Auths
 from apps.webui.models.auths import Auths
-from apps.webui.models.models import Models, ModelModel
+from apps.webui.models.models import Models
 from apps.webui.models.tools import Tools
 from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
 from apps.webui.models.functions import Functions
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
@@ -86,14 +79,12 @@ from utils.task import (
 from utils.misc import (
 from utils.misc import (
     get_last_user_message,
     get_last_user_message,
     add_or_update_system_message,
     add_or_update_system_message,
-    stream_message_template,
     parse_duration,
     parse_duration,
 )
 )
 
 
 from apps.rag.utils import get_rag_context, rag_template
 from apps.rag.utils import get_rag_context, rag_template
 
 
 from config import (
 from config import (
-    CONFIG_DATA,
     WEBUI_NAME,
     WEBUI_NAME,
     WEBUI_URL,
     WEBUI_URL,
     WEBUI_AUTH,
     WEBUI_AUTH,
@@ -101,7 +92,6 @@ from config import (
     VERSION,
     VERSION,
     CHANGELOG,
     CHANGELOG,
     FRONTEND_BUILD_DIR,
     FRONTEND_BUILD_DIR,
-    UPLOAD_DIR,
     CACHE_DIR,
     CACHE_DIR,
     STATIC_DIR,
     STATIC_DIR,
     DEFAULT_LOCALE,
     DEFAULT_LOCALE,
@@ -128,9 +118,8 @@ from config import (
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_SESSION_COOKIE_SECURE,
     AppConfig,
     AppConfig,
-    BACKEND_DIR,
-    DATABASE_URL,
 )
 )
+
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
 from utils.webhook import post_webhook
 from utils.webhook import post_webhook
 
 
@@ -355,121 +344,94 @@ async def get_function_call_response(
         else:
         else:
             content = response["choices"][0]["message"]["content"]
             content = response["choices"][0]["message"]["content"]
 
 
+        if content is None:
+            return None, None, False
+
         # Parse the function response
         # Parse the function response
-        if content is not None:
-            print(f"content: {content}")
-            result = json.loads(content)
-            print(result)
-
-            citation = None
-            # Call the function
-            if "name" in result:
-                if tool_id in webui_app.state.TOOLS:
-                    toolkit_module = webui_app.state.TOOLS[tool_id]
-                else:
-                    toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
-                    webui_app.state.TOOLS[tool_id] = toolkit_module
+        print(f"content: {content}")
+        result = json.loads(content)
+        print(result)
 
 
-                file_handler = False
-                # check if toolkit_module has file_handler self variable
-                if hasattr(toolkit_module, "file_handler"):
-                    file_handler = True
-                    print("file_handler: ", file_handler)
+        citation = None
 
 
-                if hasattr(toolkit_module, "valves") and hasattr(
-                    toolkit_module, "Valves"
-                ):
-                    valves = Tools.get_tool_valves_by_id(tool_id)
-                    toolkit_module.valves = toolkit_module.Valves(
-                        **(valves if valves else {})
-                    )
+        if "name" not in result:
+            return None, None, False
 
 
-                function = getattr(toolkit_module, result["name"])
-                function_result = None
-                try:
-                    # Get the signature of the function
-                    sig = inspect.signature(function)
-                    params = result["parameters"]
+        # Call the function
+        if tool_id in webui_app.state.TOOLS:
+            toolkit_module = webui_app.state.TOOLS[tool_id]
+        else:
+            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
+            webui_app.state.TOOLS[tool_id] = toolkit_module
 
 
-                    if "__user__" in sig.parameters:
-                        # Call the function with the '__user__' parameter included
-                        __user__ = {
-                            "id": user.id,
-                            "email": user.email,
-                            "name": user.name,
-                            "role": user.role,
-                        }
-
-                        try:
-                            if hasattr(toolkit_module, "UserValves"):
-                                __user__["valves"] = toolkit_module.UserValves(
-                                    **Tools.get_user_valves_by_id_and_user_id(
-                                        tool_id, user.id
-                                    )
-                                )
-                        except Exception as e:
-                            print(e)
-
-                        params = {**params, "__user__": __user__}
-                    if "__messages__" in sig.parameters:
-                        # Call the function with the '__messages__' parameter included
-                        params = {
-                            **params,
-                            "__messages__": messages,
-                        }
-
-                    if "__files__" in sig.parameters:
-                        # Call the function with the '__files__' parameter included
-                        params = {
-                            **params,
-                            "__files__": files,
-                        }
-
-                    if "__model__" in sig.parameters:
-                        # Call the function with the '__model__' parameter included
-                        params = {
-                            **params,
-                            "__model__": model,
-                        }
-
-                    if "__id__" in sig.parameters:
-                        # Call the function with the '__id__' parameter included
-                        params = {
-                            **params,
-                            "__id__": tool_id,
-                        }
-
-                    if "__event_emitter__" in sig.parameters:
-                        # Call the function with the '__event_emitter__' parameter included
-                        params = {
-                            **params,
-                            "__event_emitter__": __event_emitter__,
-                        }
-
-                    if "__event_call__" in sig.parameters:
-                        # Call the function with the '__event_call__' parameter included
-                        params = {
-                            **params,
-                            "__event_call__": __event_call__,
-                        }
-
-                    if inspect.iscoroutinefunction(function):
-                        function_result = await function(**params)
-                    else:
-                        function_result = function(**params)
-
-                    if hasattr(toolkit_module, "citation") and toolkit_module.citation:
-                        citation = {
-                            "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
-                            "document": [function_result],
-                            "metadata": [{"source": result["name"]}],
-                        }
+        file_handler = False
+        # check if toolkit_module has file_handler self variable
+        if hasattr(toolkit_module, "file_handler"):
+            file_handler = True
+            print("file_handler: ", file_handler)
+
+        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
+            valves = Tools.get_tool_valves_by_id(tool_id)
+            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
+
+        function = getattr(toolkit_module, result["name"])
+        function_result = None
+        try:
+            # Get the signature of the function
+            sig = inspect.signature(function)
+            params = result["parameters"]
+
+            # Extra parameters to be passed to the function
+            extra_params = {
+                "__model__": model,
+                "__id__": tool_id,
+                "__messages__": messages,
+                "__files__": files,
+                "__event_emitter__": __event_emitter__,
+                "__event_call__": __event_call__,
+            }
+
+            # Add extra params in contained in function signature
+            for key, value in extra_params.items():
+                if key in sig.parameters:
+                    params[key] = value
+
+            if "__user__" in sig.parameters:
+                # Call the function with the '__user__' parameter included
+                __user__ = {
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                }
+
+                try:
+                    if hasattr(toolkit_module, "UserValves"):
+                        __user__["valves"] = toolkit_module.UserValves(
+                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
+                        )
                 except Exception as e:
                 except Exception as e:
                     print(e)
                     print(e)
 
 
-                # Add the function result to the system prompt
-                if function_result is not None:
-                    return function_result, citation, file_handler
+                params = {**params, "__user__": __user__}
+
+            if inspect.iscoroutinefunction(function):
+                function_result = await function(**params)
+            else:
+                function_result = function(**params)
+
+            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
+                citation = {
+                    "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
+                    "document": [function_result],
+                    "metadata": [{"source": result["name"]}],
+                }
+        except Exception as e:
+            print(e)
+
+        # Add the function result to the system prompt
+        if function_result is not None:
+            return function_result, citation, file_handler
     except Exception as e:
     except Exception as e:
         print(f"Error: {e}")
         print(f"Error: {e}")
 
 
@@ -484,87 +446,74 @@ async def chat_completion_functions_handler(
     filter_ids = get_filter_function_ids(model)
     filter_ids = get_filter_function_ids(model)
     for filter_id in filter_ids:
     for filter_id in filter_ids:
         filter = Functions.get_function_by_id(filter_id)
         filter = Functions.get_function_by_id(filter_id)
-        if filter:
-            if filter_id in webui_app.state.FUNCTIONS:
-                function_module = webui_app.state.FUNCTIONS[filter_id]
-            else:
-                function_module, function_type, frontmatter = (
-                    load_function_module_by_id(filter_id)
-                )
-                webui_app.state.FUNCTIONS[filter_id] = function_module
-
-            # Check if the function has a file_handler variable
-            if hasattr(function_module, "file_handler"):
-                skip_files = function_module.file_handler
-
-            if hasattr(function_module, "valves") and hasattr(
-                function_module, "Valves"
-            ):
-                valves = Functions.get_function_valves_by_id(filter_id)
-                function_module.valves = function_module.Valves(
-                    **(valves if valves else {})
-                )
+        if not filter:
+            continue
 
 
-            try:
-                if hasattr(function_module, "inlet"):
-                    inlet = function_module.inlet
+        if filter_id in webui_app.state.FUNCTIONS:
+            function_module = webui_app.state.FUNCTIONS[filter_id]
+        else:
+            function_module, _, _ = load_function_module_by_id(filter_id)
+            webui_app.state.FUNCTIONS[filter_id] = function_module
 
 
-                    # Get the signature of the function
-                    sig = inspect.signature(inlet)
-                    params = {"body": body}
+        # Check if the function has a file_handler variable
+        if hasattr(function_module, "file_handler"):
+            skip_files = function_module.file_handler
 
 
-                    if "__user__" in sig.parameters:
-                        __user__ = {
-                            "id": user.id,
-                            "email": user.email,
-                            "name": user.name,
-                            "role": user.role,
-                        }
-
-                        try:
-                            if hasattr(function_module, "UserValves"):
-                                __user__["valves"] = function_module.UserValves(
-                                    **Functions.get_user_valves_by_id_and_user_id(
-                                        filter_id, user.id
-                                    )
+        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+            valves = Functions.get_function_valves_by_id(filter_id)
+            function_module.valves = function_module.Valves(
+                **(valves if valves else {})
+            )
+
+        try:
+            if hasattr(function_module, "inlet"):
+                inlet = function_module.inlet
+
+                # Get the signature of the function
+                sig = inspect.signature(inlet)
+                params = {"body": body}
+
+                # Extra parameters to be passed to the function
+                extra_params = {
+                    "__model__": model,
+                    "__id__": filter_id,
+                    "__event_emitter__": __event_emitter__,
+                    "__event_call__": __event_call__,
+                }
+
+                # Add extra params in contained in function signature
+                for key, value in extra_params.items():
+                    if key in sig.parameters:
+                        params[key] = value
+
+                if "__user__" in sig.parameters:
+                    __user__ = {
+                        "id": user.id,
+                        "email": user.email,
+                        "name": user.name,
+                        "role": user.role,
+                    }
+
+                    try:
+                        if hasattr(function_module, "UserValves"):
+                            __user__["valves"] = function_module.UserValves(
+                                **Functions.get_user_valves_by_id_and_user_id(
+                                    filter_id, user.id
                                 )
                                 )
-                        except Exception as e:
-                            print(e)
-
-                        params = {**params, "__user__": __user__}
-
-                    if "__id__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__id__": filter_id,
-                        }
-
-                    if "__model__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__model__": model,
-                        }
-
-                    if "__event_emitter__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__event_emitter__": __event_emitter__,
-                        }
-
-                    if "__event_call__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__event_call__": __event_call__,
-                        }
-
-                    if inspect.iscoroutinefunction(inlet):
-                        body = await inlet(**params)
-                    else:
-                        body = inlet(**params)
+                            )
+                    except Exception as e:
+                        print(e)
 
 
-            except Exception as e:
-                print(f"Error: {e}")
-                raise e
+                    params = {**params, "__user__": __user__}
+
+                if inspect.iscoroutinefunction(inlet):
+                    body = await inlet(**params)
+                else:
+                    body = inlet(**params)
+
+        except Exception as e:
+            print(f"Error: {e}")
+            raise e
 
 
     if skip_files:
     if skip_files:
         if "files" in body:
         if "files" in body:
@@ -1220,86 +1169,73 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
 
 
     for filter_id in filter_ids:
     for filter_id in filter_ids:
         filter = Functions.get_function_by_id(filter_id)
         filter = Functions.get_function_by_id(filter_id)
-        if filter:
-            if filter_id in webui_app.state.FUNCTIONS:
-                function_module = webui_app.state.FUNCTIONS[filter_id]
-            else:
-                function_module, function_type, frontmatter = (
-                    load_function_module_by_id(filter_id)
-                )
-                webui_app.state.FUNCTIONS[filter_id] = function_module
-
-            if hasattr(function_module, "valves") and hasattr(
-                function_module, "Valves"
-            ):
-                valves = Functions.get_function_valves_by_id(filter_id)
-                function_module.valves = function_module.Valves(
-                    **(valves if valves else {})
-                )
+        if not filter:
+            continue
 
 
-            try:
-                if hasattr(function_module, "outlet"):
-                    outlet = function_module.outlet
+        if filter_id in webui_app.state.FUNCTIONS:
+            function_module = webui_app.state.FUNCTIONS[filter_id]
+        else:
+            function_module, _, _ = load_function_module_by_id(filter_id)
+            webui_app.state.FUNCTIONS[filter_id] = function_module
 
 
-                    # Get the signature of the function
-                    sig = inspect.signature(outlet)
-                    params = {"body": data}
+        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+            valves = Functions.get_function_valves_by_id(filter_id)
+            function_module.valves = function_module.Valves(
+                **(valves if valves else {})
+            )
 
 
-                    if "__user__" in sig.parameters:
-                        __user__ = {
-                            "id": user.id,
-                            "email": user.email,
-                            "name": user.name,
-                            "role": user.role,
-                        }
-
-                        try:
-                            if hasattr(function_module, "UserValves"):
-                                __user__["valves"] = function_module.UserValves(
-                                    **Functions.get_user_valves_by_id_and_user_id(
-                                        filter_id, user.id
-                                    )
+        try:
+            if hasattr(function_module, "outlet"):
+                outlet = function_module.outlet
+
+                # Get the signature of the function
+                sig = inspect.signature(outlet)
+                params = {"body": data}
+
+                # Extra parameters to be passed to the function
+                extra_params = {
+                    "__model__": model,
+                    "__id__": filter_id,
+                    "__event_emitter__": __event_emitter__,
+                    "__event_call__": __event_call__,
+                }
+
+                # Add extra params in contained in function signature
+                for key, value in extra_params.items():
+                    if key in sig.parameters:
+                        params[key] = value
+
+                if "__user__" in sig.parameters:
+                    __user__ = {
+                        "id": user.id,
+                        "email": user.email,
+                        "name": user.name,
+                        "role": user.role,
+                    }
+
+                    try:
+                        if hasattr(function_module, "UserValves"):
+                            __user__["valves"] = function_module.UserValves(
+                                **Functions.get_user_valves_by_id_and_user_id(
+                                    filter_id, user.id
                                 )
                                 )
-                        except Exception as e:
-                            print(e)
-
-                        params = {**params, "__user__": __user__}
-
-                    if "__id__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__id__": filter_id,
-                        }
-
-                    if "__model__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__model__": model,
-                        }
-
-                    if "__event_emitter__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__event_emitter__": __event_emitter__,
-                        }
-
-                    if "__event_call__" in sig.parameters:
-                        params = {
-                            **params,
-                            "__event_call__": __event_call__,
-                        }
-
-                    if inspect.iscoroutinefunction(outlet):
-                        data = await outlet(**params)
-                    else:
-                        data = outlet(**params)
+                            )
+                    except Exception as e:
+                        print(e)
 
 
-            except Exception as e:
-                print(f"Error: {e}")
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+                    params = {**params, "__user__": __user__}
+
+                if inspect.iscoroutinefunction(outlet):
+                    data = await outlet(**params)
+                else:
+                    data = outlet(**params)
+
+        except Exception as e:
+            print(f"Error: {e}")
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
 
     return data
     return data
 
 
@@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
                 model_id = task_model_id
                 model_id = task_model_id
 
 
     print(model_id)
     print(model_id)
-    model = app.state.MODELS[model_id]
 
 
     template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
     template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
 
 
@@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
                 model_id = task_model_id
                 model_id = task_model_id
 
 
     print(model_id)
     print(model_id)
-    model = app.state.MODELS[model_id]
 
 
     template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
     template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 
 
@@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
                 model_id = task_model_id
                 model_id = task_model_id
 
 
     print(model_id)
     print(model_id)
-    model = app.state.MODELS[model_id]
 
 
     template = '''
     template = '''
 Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
 Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 
 
     try:
     try:
-        context, citation, file_handler = await get_function_call_response(
+        context, _, _ = await get_function_call_response(
             form_data["messages"],
             form_data["messages"],
             form_data.get("files", []),
             form_data.get("files", []),
             form_data["tool_id"],
             form_data["tool_id"],
@@ -1647,6 +1580,7 @@ async def upload_pipeline(
     os.makedirs(upload_folder, exist_ok=True)
     os.makedirs(upload_folder, exist_ok=True)
     file_path = os.path.join(upload_folder, file.filename)
     file_path = os.path.join(upload_folder, file.filename)
 
 
+    r = None
     try:
     try:
         # Save the uploaded file
         # Save the uploaded file
         with open(file_path, "wb") as buffer:
         with open(file_path, "wb") as buffer:
@@ -1670,7 +1604,9 @@ async def upload_pipeline(
         print(f"Connection error: {e}")
         print(f"Connection error: {e}")
 
 
         detail = "Pipeline not found"
         detail = "Pipeline not found"
+        status_code = status.HTTP_404_NOT_FOUND
         if r is not None:
         if r is not None:
+            status_code = r.status_code
             try:
             try:
                 res = r.json()
                 res = r.json()
                 if "detail" in res:
                 if "detail" in res:
@@ -1679,7 +1615,7 @@ async def upload_pipeline(
                 pass
                 pass
 
 
         raise HTTPException(
         raise HTTPException(
-            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
+            status_code=status_code,
             detail=detail,
             detail=detail,
         )
         )
     finally:
     finally:
@@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
 async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
 async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
     r = None
     r = None
     try:
     try:
-        urlIdx
-
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
         key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
         key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]