@@ -1,4 +1,4 @@
-import base64
+import asyncio
 import inspect
 import json
 import logging
@@ -7,20 +7,38 @@ import os
 import shutil
 import sys
 import time
-import uuid
-import asyncio
-
+import random
 from contextlib import asynccontextmanager
 from typing import Optional
 
 import aiohttp
 import requests
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, RedirectResponse
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+from sqlalchemy import text
+from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import Response, StreamingResponse
+
+from open_webui.apps.audio.main import app as audio_app
+from open_webui.apps.images.main import app as images_app
 from open_webui.apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
     generate_chat_completion as generate_ollama_chat_completion,
-    generate_openai_chat_completion as generate_ollama_openai_chat_completion,
     GenerateChatCompletionForm,
 )
 from open_webui.apps.openai.main import (
@@ -28,38 +46,24 @@ from open_webui.apps.openai.main import (
     generate_chat_completion as generate_openai_chat_completion,
     get_all_models as get_openai_models,
 )
-
 from open_webui.apps.retrieval.main import app as retrieval_app
 from open_webui.apps.retrieval.utils import get_rag_context, rag_template
-
 from open_webui.apps.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
     get_event_call,
     get_event_emitter,
 )
-
+from open_webui.apps.webui.internal.db import Session
 from open_webui.apps.webui.main import (
     app as webui_app,
     generate_function_chat_completion,
-    get_pipe_models,
+    get_all_models as get_open_webui_models,
 )
-from open_webui.apps.webui.internal.db import Session
-
-from open_webui.apps.webui.models.auths import Auths
 from open_webui.apps.webui.models.functions import Functions
 from open_webui.apps.webui.models.models import Models
 from open_webui.apps.webui.models.users import UserModel, Users
-
 from open_webui.apps.webui.utils import load_function_module_by_id
-
-from open_webui.apps.audio.main import app as audio_app
-from open_webui.apps.images.main import app as images_app
-
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
-
-
 from open_webui.config import (
     CACHE_DIR,
     CORS_ALLOW_ORIGIN,
@@ -67,13 +71,11 @@ from open_webui.config import (
     ENABLE_ADMIN_CHAT_ACCESS,
     ENABLE_ADMIN_EXPORT,
     ENABLE_MODEL_FILTER,
-    ENABLE_OAUTH_SIGNUP,
     ENABLE_OLLAMA_API,
     ENABLE_OPENAI_API,
     ENV,
     FRONTEND_BUILD_DIR,
     MODEL_FILTER_LIST,
-    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
     OAUTH_PROVIDERS,
     ENABLE_SEARCH_QUERY,
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -81,15 +83,15 @@ from open_webui.config import (
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
     TITLE_GENERATION_PROMPT_TEMPLATE,
+    TAGS_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     WEBHOOK_URL,
     WEBUI_AUTH,
     WEBUI_NAME,
     AppConfig,
-    run_migrations,
     reset_config,
 )
-from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
+from open_webui.constants import TASKS
 from open_webui.env import (
     CHANGELOG,
     GLOBAL_LOG_LEVEL,
@@ -102,64 +104,41 @@ from open_webui.env import (
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_URL,
     RESET_CONFIG_ON_START,
+    OFFLINE_MODE,
 )
-from fastapi import (
-    Depends,
-    FastAPI,
-    File,
-    Form,
-    HTTPException,
-    Request,
-    UploadFile,
-    status,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
-from fastapi.staticfiles import StaticFiles
-from pydantic import BaseModel
-from sqlalchemy import text
-from starlette.exceptions import HTTPException as StarletteHTTPException
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from starlette.responses import RedirectResponse, Response, StreamingResponse
-
-from open_webui.utils.security_headers import SecurityHeadersMiddleware
-
 from open_webui.utils.misc import (
     add_or_update_system_message,
     get_last_user_message,
-    parse_duration,
     prepend_to_first_user_message_content,
 )
+from open_webui.utils.oauth import oauth_manager
+from open_webui.utils.payload import convert_payload_openai_to_ollama
+from open_webui.utils.response import (
+    convert_response_ollama_to_openai,
+    convert_streaming_response_ollama_to_openai,
+)
+from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.task import (
     moa_response_generation_template,
+    tags_generation_template,
     search_query_generation_template,
+    emoji_generation_template,
     title_generation_template,
     tools_function_calling_generation_template,
 )
 from open_webui.utils.tools import get_tools
 from open_webui.utils.utils import (
-    create_token,
     decode_token,
     get_admin_user,
     get_current_user,
     get_http_authorization_cred,
-    get_password_hash,
     get_verified_user,
 )
-from open_webui.utils.webhook import post_webhook
-
-from open_webui.utils.payload import convert_payload_openai_to_ollama
-from open_webui.utils.response import (
-    convert_response_ollama_to_openai,
-    convert_streaming_response_ollama_to_openai,
-)
 
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
 
-
 logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
@@ -178,14 +157,14 @@ class SPAStaticFiles(StaticFiles):
 print(
     rf"""
-  ___                    __        __   _     _   _ ___ 
+  ___                    __        __   _     _   _ ___
  / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
-| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || | 
-| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || | 
+| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
+| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || |
  \___/| .__/ \___|_| |_|   \_/\_/ \___|_.__/ \___/|___|
-      |_|                                               
+      |_|
 
+
 
-
 v{VERSION} - building the best open-source AI user interface.
 {f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
 https://github.com/open-webui/open-webui
@@ -216,10 +195,10 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 
-
 app.state.config.TASK_MODEL = TASK_MODEL
 app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
+app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
 app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 )
@@ -577,7 +556,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         }
 
         # Initialize data_items to store additional data to be sent to the client
-        # Initalize contexts and citation
+        # Initialize contexts and citation
         data_items = []
         contexts = []
         citations = []
@@ -689,6 +668,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 app.add_middleware(ChatCompletionMiddleware)
 
 
+
 ##################################
 #
 # Pipeline Middleware
@@ -824,6 +804,32 @@ class PipelineMiddleware(BaseHTTPMiddleware):
 app.add_middleware(PipelineMiddleware)
 
 
+from urllib.parse import urlencode, parse_qs, urlparse
+
+
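+# Note: RedirectMiddleware below rewrites GET /watch?v=<id> requests (e.g. pasted YouTube share links)
+# so the SPA receives the video id as a /?youtube=<id> query parameter.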
+class RedirectMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        # Check if the request is a GET request
+        if request.method == "GET":
+            path = request.url.path
+            query_params = dict(parse_qs(urlparse(str(request.url)).query))
+
+            # Check for the specific watch path and the presence of 'v' parameter
+            if path.endswith("/watch") and "v" in query_params:
+                video_id = query_params["v"][0]  # Extract the first 'v' parameter
+                encoded_video_id = urlencode({"youtube": video_id})
+                redirect_url = f"/?{encoded_video_id}"
+                return RedirectResponse(url=redirect_url)
+
+        # Proceed with the normal flow of other requests
+        response = await call_next(request)
+        return response
+
+
+# Add the middleware to the app
+app.add_middleware(RedirectMiddleware)
+
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=CORS_ALLOW_ORIGIN,
@@ -900,12 +906,10 @@ webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
 
 async def get_all_models():
     # TODO: Optimize this function
-    pipe_models = []
+    open_webui_models = []
     openai_models = []
     ollama_models = []
 
-    pipe_models = await get_pipe_models()
-
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
         openai_models = openai_models["data"]
@@ -924,7 +928,13 @@ async def get_all_models():
            for model in ollama_models["models"]
         ]
 
-    models = pipe_models + openai_models + ollama_models
+    open_webui_models = await get_open_webui_models()
+
+    models = open_webui_models + openai_models + ollama_models
+
+    # If there are no models, return an empty list
+    if len([model for model in models if model["owned_by"] != "arena"]) == 0:
+        return []
 
     global_action_ids = [
         function.id for function in Functions.get_global_action_functions()
@@ -963,11 +973,13 @@ async def get_all_models():
                     owned_by = model["owned_by"]
                     if "pipe" in model:
                         pipe = model["pipe"]
-
-                if "info" in model and "meta" in model["info"]:
-                    action_ids.extend(model["info"]["meta"].get("actionIds", []))
                     break
 
+        if custom_model.meta:
+            meta = custom_model.meta.model_dump()
+            if "actionIds" in meta:
+                action_ids.extend(meta["actionIds"])
+
         models.append(
             {
                 "id": custom_model.id,
@@ -1070,7 +1082,9 @@ async def get_models(user=Depends(get_verified_user)):
 
 
 @app.post("/api/chat/completions")
-async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+async def generate_chat_completions(
+    form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
+):
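+    # bypass_filter lets internal re-dispatches (e.g. the arena model selection below)
+    # skip the user-facing model filter check.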
     model_id = form_data["model"]
 
     if model_id not in app.state.MODELS:
@@ -1079,7 +1093,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
 
-    if app.state.config.ENABLE_MODEL_FILTER:
+    if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
@@ -1087,6 +1101,53 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             )
 
     model = app.state.MODELS[model_id]
+
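+    # "Arena" models are virtual: resolve them to a randomly selected underlying model
+    # (honoring include/exclude lists and hidden flags) and re-dispatch the request.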
+    if model["owned_by"] == "arena":
+        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+        if model_ids and filter_mode == "exclude":
+            model_ids = [
+                model["id"]
+                for model in await get_all_models()
+                if model.get("owned_by") != "arena"
+                and not model.get("info", {}).get("meta", {}).get("hidden", False)
+                and model["id"] not in model_ids
+            ]
+
+        selected_model_id = None
+        if isinstance(model_ids, list) and model_ids:
+            selected_model_id = random.choice(model_ids)
+        else:
+            model_ids = [
+                model["id"]
+                for model in await get_all_models()
+                if model.get("owned_by") != "arena"
+                and not model.get("info", {}).get("meta", {}).get("hidden", False)
+            ]
+            selected_model_id = random.choice(model_ids)
+
+        form_data["model"] = selected_model_id
+
+        if form_data.get("stream") == True:
+
+            async def stream_wrapper(stream):
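+                # Emit the chosen model id as the first SSE event, then pass the
+                # upstream chunks through unchanged.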
+                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                async for chunk in stream:
+                    yield chunk
+
+            response = await generate_chat_completions(
+                form_data, user, bypass_filter=True
+            )
+            return StreamingResponse(
+                stream_wrapper(response.body_iterator), media_type="text/event-stream"
+            )
+        else:
+            return {
+                **(
+                    await generate_chat_completions(form_data, user, bypass_filter=True)
+                ),
+                "selected_model_id": selected_model_id,
+            }
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
@@ -1398,6 +1459,7 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TASK_MODEL": app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -1408,6 +1470,7 @@ class TaskConfigForm(BaseModel):
     TASK_MODEL: Optional[str]
     TASK_MODEL_EXTERNAL: Optional[str]
     TITLE_GENERATION_PROMPT_TEMPLATE: str
+    TAGS_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
     ENABLE_SEARCH_QUERY: bool
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@@ -1420,6 +1483,10 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
     app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
         form_data.TITLE_GENERATION_PROMPT_TEMPLATE
     )
+    app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
+        form_data.TAGS_GENERATION_PROMPT_TEMPLATE
+    )
+
     app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
         form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
     )
@@ -1432,6 +1499,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TASK_MODEL": app.state.config.TASK_MODEL,
         "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "TAGS_GENERATION_PROMPT_TEMPLATE": app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_SEARCH_QUERY": app.state.config.ENABLE_SEARCH_QUERY,
         "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -1459,7 +1527,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
     else:
-        template = """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
+        template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
 
 Examples of titles:
 📉 Stock Market Trends
@@ -1469,11 +1537,13 @@ Remote Work Productivity Tips
 Artificial Intelligence in Healthcare
 🎮 Video Game Development Insights
 
-Prompt: {{prompt:middletruncate:8000}}"""
+<chat_history>
+{{MESSAGES:END:2}}
+</chat_history>"""
 
     content = title_generation_template(
         template,
-        form_data["prompt"],
+        form_data["messages"],
         {
             "name": user.name,
             "location": user.info.get("location") if user.info else None,
@@ -1516,6 +1586,75 @@ Prompt: {{prompt:middletruncate:8000}}"""
     return await generate_chat_completions(form_data=payload, user=user)
 
 
+@app.post("/api/task/tags/completions")
+async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
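+    # Asks the task model for 1-3 broad theme tags (plus subtopic tags) describing the
+    # chat history, returned as JSON.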
+    print("generate_chat_tags")
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    task_model_id = get_task_model_id(model_id)
+    print(task_model_id)
+
+    if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
+        template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
+    else:
+        template = """### Task:
+Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.
+
+### Guidelines:
+- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
+- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
+- If content is too short (less than 3 messages) or too diverse, use only ["General"]
+- Use the chat's primary language; default to English if multilingual
+- Prioritize accuracy over specificity
+
+### Output:
+JSON format: { "tags": ["tag1", "tag2", "tag3"] }
+
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>"""
+
+    content = tags_generation_template(
+        template, form_data["messages"], {"name": user.name}
+    )
+
+    print("content", content)
+    payload = {
+        "model": task_model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data},
+    }
+    log.debug(payload)
+
+    # Handle pipeline filters
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        if len(e.args) > 1:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+        else:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
+    return await generate_chat_completions(form_data=payload, user=user)
+
+
 @app.post("/api/task/query/completions")
 async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
     print("generate_search_query")
@@ -1616,7 +1755,7 @@ Your task is to reflect the speaker's likely facial expression through a fitting
 
 Message: """{{prompt}}"""
 '''
-    content = title_generation_template(
+    content = emoji_generation_template(
         template,
         form_data["prompt"],
         {
@@ -2181,6 +2320,11 @@ async def get_app_changelog():
 
 @app.get("/api/version/updates")
 async def get_app_latest_release_version():
+    if OFFLINE_MODE:
+        log.debug(
+            f"Offline mode is enabled, returning current version as latest version"
+        )
+        return {"current": VERSION, "latest": VERSION}
     try:
         timeout = aiohttp.ClientTimeout(total=1)
         async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
@@ -2201,20 +2345,6 @@ async def get_app_latest_release_version():
 # OAuth Login & Callback
 ############################
 
-oauth = OAuth()
-
-for provider_name, provider_config in OAUTH_PROVIDERS.items():
-    oauth.register(
-        name=provider_name,
-        client_id=provider_config["client_id"],
-        client_secret=provider_config["client_secret"],
-        server_metadata_url=provider_config["server_metadata_url"],
-        client_kwargs={
-            "scope": provider_config["scope"],
-        },
-        redirect_uri=provider_config["redirect_uri"],
-    )
-
 # SessionMiddleware is used by authlib for oauth
 if len(OAUTH_PROVIDERS) > 0:
     app.add_middleware(
@@ -2228,16 +2358,7 @@ if len(OAUTH_PROVIDERS) > 0:
 
 @app.get("/oauth/{provider}/login")
 async def oauth_login(provider: str, request: Request):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    # If the provider has a custom redirect URL, use that, otherwise automatically generate one
-    redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
-        "oauth_callback", provider=provider
-    )
-    client = oauth.create_client(provider)
-    if client is None:
-        raise HTTPException(404)
-    return await client.authorize_redirect(request, redirect_uri)
+    return await oauth_manager.handle_login(provider, request)
 
 
 # OAuth login logic is as follows:
@@ -2245,119 +2366,10 @@ async def oauth_login(provider: str, request: Request):
 # 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
 #    - This is considered insecure in general, as OAuth providers do not always verify email addresses
 # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
-#    - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
+#    - Email addresses are considered unique, so we fail registration if the email address is already taken
 @app.get("/oauth/{provider}/callback")
 async def oauth_callback(provider: str, request: Request, response: Response):
-    if provider not in OAUTH_PROVIDERS:
-        raise HTTPException(404)
-    client = oauth.create_client(provider)
-    try:
-        token = await client.authorize_access_token(request)
-    except Exception as e:
-        log.warning(f"OAuth callback error: {e}")
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-    user_data: UserInfo = token["userinfo"]
-
-    sub = user_data.get("sub")
-    if not sub:
-        log.warning(f"OAuth callback failed, sub is missing: {user_data}")
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-    provider_sub = f"{provider}@{sub}"
-    email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM
-    email = user_data.get(email_claim, "").lower()
-    # We currently mandate that email addresses are provided
-    if not email:
-        log.warning(f"OAuth callback failed, email is missing: {user_data}")
-        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
-
-    # Check if the user exists
-    user = Users.get_user_by_oauth_sub(provider_sub)
-
-    if not user:
-        # If the user does not exist, check if merging is enabled
-        if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
-            # Check if the user exists by email
-            user = Users.get_user_by_email(email)
-            if user:
-                # Update the user with the new oauth sub
-                Users.update_user_oauth_sub_by_id(user.id, provider_sub)
-
-    if not user:
-        # If the user does not exist, check if signups are enabled
-        if ENABLE_OAUTH_SIGNUP.value:
-            # Check if an existing user with the same email already exists
-            existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
-            if existing_user:
-                raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
-
-            picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
-            picture_url = user_data.get(picture_claim, "")
-            if picture_url:
-                # Download the profile image into a base64 string
-                try:
-                    async with aiohttp.ClientSession() as session:
-                        async with session.get(picture_url) as resp:
-                            picture = await resp.read()
-                            base64_encoded_picture = base64.b64encode(picture).decode(
-                                "utf-8"
-                            )
-                            guessed_mime_type = mimetypes.guess_type(picture_url)[0]
-                            if guessed_mime_type is None:
-                                # assume JPG, browsers are tolerant enough of image formats
-                                guessed_mime_type = "image/jpeg"
-                            picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
-                except Exception as e:
-                    log.error(f"Error downloading profile image '{picture_url}': {e}")
-                    picture_url = ""
-            if not picture_url:
-                picture_url = "/user.png"
-            username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
-            role = (
-                "admin"
-                if Users.get_num_users() == 0
-                else webui_app.state.config.DEFAULT_USER_ROLE
-            )
-            user = Auths.insert_new_auth(
-                email=email,
-                password=get_password_hash(
-                    str(uuid.uuid4())
-                ),  # Random password, not used
-                name=user_data.get(username_claim, "User"),
-                profile_image_url=picture_url,
-                role=role,
-                oauth_sub=provider_sub,
-            )
-
-            if webui_app.state.config.WEBHOOK_URL:
-                post_webhook(
-                    webui_app.state.config.WEBHOOK_URL,
-                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                    {
-                        "action": "signup",
-                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
-                        "user": user.model_dump_json(exclude_none=True),
-                    },
-                )
-        else:
-            raise HTTPException(
-                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
-            )
-
-    jwt_token = create_token(
-        data={"id": user.id},
-        expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
-    )
-
-    # Set the cookie token
-    response.set_cookie(
-        key="token",
-        value=jwt_token,
-        httponly=True,  # Ensures the cookie is not accessible via JavaScript
-    )
-
-    # Redirect back to the frontend with the JWT token
-    redirect_url = f"{request.base_url}auth#token={jwt_token}"
-    return RedirectResponse(url=redirect_url)
+    return await oauth_manager.handle_callback(provider, request, response)
 
 
 @app.get("/manifest.json")
@@ -2416,6 +2428,7 @@ async def healthcheck_with_db():
 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
 
+
 if os.path.exists(FRONTEND_BUILD_DIR):
     mimetypes.add_type("text/javascript", ".js")
     app.mount(