|
@@ -16,7 +16,6 @@ from typing import Optional
|
|
import aiohttp
|
|
import aiohttp
|
|
import requests
|
|
import requests
|
|
|
|
|
|
-
|
|
|
|
from open_webui.apps.audio.main import app as audio_app
|
|
from open_webui.apps.audio.main import app as audio_app
|
|
from open_webui.apps.images.main import app as images_app
|
|
from open_webui.apps.images.main import app as images_app
|
|
from open_webui.apps.ollama.main import app as ollama_app
|
|
from open_webui.apps.ollama.main import app as ollama_app
|
|
@@ -47,11 +46,9 @@ from open_webui.apps.webui.models.models import Models
|
|
from open_webui.apps.webui.models.users import UserModel, Users
|
|
from open_webui.apps.webui.models.users import UserModel, Users
|
|
from open_webui.apps.webui.utils import load_function_module_by_id
|
|
from open_webui.apps.webui.utils import load_function_module_by_id
|
|
|
|
|
|
-
|
|
|
|
from authlib.integrations.starlette_client import OAuth
|
|
from authlib.integrations.starlette_client import OAuth
|
|
from authlib.oidc.core import UserInfo
|
|
from authlib.oidc.core import UserInfo
|
|
|
|
|
|
-
|
|
|
|
from open_webui.config import (
|
|
from open_webui.config import (
|
|
CACHE_DIR,
|
|
CACHE_DIR,
|
|
CORS_ALLOW_ORIGIN,
|
|
CORS_ALLOW_ORIGIN,
|
|
@@ -151,7 +148,6 @@ if SAFE_MODE:
|
|
print("SAFE MODE ENABLED")
|
|
print("SAFE MODE ENABLED")
|
|
Functions.deactivate_all_functions()
|
|
Functions.deactivate_all_functions()
|
|
|
|
|
|
-
|
|
|
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
|
log = logging.getLogger(__name__)
|
|
log = logging.getLogger(__name__)
|
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|
@@ -210,7 +206,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|
|
|
|
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
|
|
|
|
|
-
|
|
|
|
app.state.config.TASK_MODEL = TASK_MODEL
|
|
app.state.config.TASK_MODEL = TASK_MODEL
|
|
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
|
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
|
@@ -238,14 +233,14 @@ def get_task_model_id(default_model_id):
|
|
# Check if the user has a custom task model and use that model
|
|
# Check if the user has a custom task model and use that model
|
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
if (
|
|
if (
|
|
- app.state.config.TASK_MODEL
|
|
|
|
- and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
|
|
|
+ app.state.config.TASK_MODEL
|
|
|
|
+ and app.state.config.TASK_MODEL in app.state.MODELS
|
|
):
|
|
):
|
|
task_model_id = app.state.config.TASK_MODEL
|
|
task_model_id = app.state.config.TASK_MODEL
|
|
else:
|
|
else:
|
|
if (
|
|
if (
|
|
- app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
- and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
+ and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
):
|
|
):
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
|
|
@@ -382,7 +377,7 @@ async def get_content_from_response(response) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
async def chat_completion_tools_handler(
|
|
async def chat_completion_tools_handler(
|
|
- body: dict, user: UserModel, extra_params: dict
|
|
|
|
|
|
+ body: dict, user: UserModel, extra_params: dict
|
|
) -> tuple[dict, dict]:
|
|
) -> tuple[dict, dict]:
|
|
# If tool_ids field is present, call the functions
|
|
# If tool_ids field is present, call the functions
|
|
metadata = body.get("metadata", {})
|
|
metadata = body.get("metadata", {})
|
|
@@ -608,8 +603,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
if prompt is None:
|
|
if prompt is None:
|
|
raise Exception("No user message found")
|
|
raise Exception("No user message found")
|
|
if (
|
|
if (
|
|
- rag_app.state.config.RELEVANCE_THRESHOLD == 0
|
|
|
|
- and context_string.strip() == ""
|
|
|
|
|
|
+ rag_app.state.config.RELEVANCE_THRESHOLD == 0
|
|
|
|
+ and context_string.strip() == ""
|
|
):
|
|
):
|
|
log.debug(
|
|
log.debug(
|
|
f"With a 0 relevancy threshold for RAG, the context cannot be empty"
|
|
f"With a 0 relevancy threshold for RAG, the context cannot be empty"
|
|
@@ -676,6 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
app.add_middleware(ChatCompletionMiddleware)
|
|
app.add_middleware(ChatCompletionMiddleware)
|
|
|
|
|
|
|
|
+
|
|
##################################
|
|
##################################
|
|
#
|
|
#
|
|
# Pipeline Middleware
|
|
# Pipeline Middleware
|
|
@@ -688,15 +684,15 @@ def get_sorted_filters(model_id):
|
|
model
|
|
model
|
|
for model in app.state.MODELS.values()
|
|
for model in app.state.MODELS.values()
|
|
if "pipeline" in model
|
|
if "pipeline" in model
|
|
- and "type" in model["pipeline"]
|
|
|
|
- and model["pipeline"]["type"] == "filter"
|
|
|
|
- and (
|
|
|
|
- model["pipeline"]["pipelines"] == ["*"]
|
|
|
|
- or any(
|
|
|
|
- model_id == target_model_id
|
|
|
|
- for target_model_id in model["pipeline"]["pipelines"]
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
|
|
+ and "type" in model["pipeline"]
|
|
|
|
+ and model["pipeline"]["type"] == "filter"
|
|
|
|
+ and (
|
|
|
|
+ model["pipeline"]["pipelines"] == ["*"]
|
|
|
|
+ or any(
|
|
|
|
+ model_id == target_model_id
|
|
|
|
+ for target_model_id in model["pipeline"]["pipelines"]
|
|
|
|
+ )
|
|
|
|
+ )
|
|
]
|
|
]
|
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
return sorted_filters
|
|
return sorted_filters
|
|
@@ -798,7 +794,6 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
app.add_middleware(PipelineMiddleware)
|
|
app.add_middleware(PipelineMiddleware)
|
|
|
|
|
|
-
|
|
|
|
app.add_middleware(
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
CORSMiddleware,
|
|
allow_origins=CORS_ALLOW_ORIGIN,
|
|
allow_origins=CORS_ALLOW_ORIGIN,
|
|
@@ -844,8 +839,8 @@ async def update_embedding_function(request: Request, call_next):
|
|
@app.middleware("http")
|
|
@app.middleware("http")
|
|
async def inspect_websocket(request: Request, call_next):
|
|
async def inspect_websocket(request: Request, call_next):
|
|
if (
|
|
if (
|
|
- "/ws/socket.io" in request.url.path
|
|
|
|
- and request.query_params.get("transport") == "websocket"
|
|
|
|
|
|
+ "/ws/socket.io" in request.url.path
|
|
|
|
+ and request.query_params.get("transport") == "websocket"
|
|
):
|
|
):
|
|
upgrade = (request.headers.get("Upgrade") or "").lower()
|
|
upgrade = (request.headers.get("Upgrade") or "").lower()
|
|
connection = (request.headers.get("Connection") or "").lower().split(",")
|
|
connection = (request.headers.get("Connection") or "").lower().split(",")
|
|
@@ -913,8 +908,8 @@ async def get_all_models():
|
|
if custom_model.base_model_id is None:
|
|
if custom_model.base_model_id is None:
|
|
for model in models:
|
|
for model in models:
|
|
if (
|
|
if (
|
|
- custom_model.id == model["id"]
|
|
|
|
- or custom_model.id == model["id"].split(":")[0]
|
|
|
|
|
|
+ custom_model.id == model["id"]
|
|
|
|
+ or custom_model.id == model["id"].split(":")[0]
|
|
):
|
|
):
|
|
model["name"] = custom_model.name
|
|
model["name"] = custom_model.name
|
|
model["info"] = custom_model.model_dump()
|
|
model["info"] = custom_model.model_dump()
|
|
@@ -931,8 +926,8 @@ async def get_all_models():
|
|
|
|
|
|
for model in models:
|
|
for model in models:
|
|
if (
|
|
if (
|
|
- custom_model.base_model_id == model["id"]
|
|
|
|
- or custom_model.base_model_id == model["id"].split(":")[0]
|
|
|
|
|
|
+ custom_model.base_model_id == model["id"]
|
|
|
|
+ or custom_model.base_model_id == model["id"].split(":")[0]
|
|
):
|
|
):
|
|
owned_by = model["owned_by"]
|
|
owned_by = model["owned_by"]
|
|
if "pipe" in model:
|
|
if "pipe" in model:
|
|
@@ -1727,7 +1722,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
|
|
|
|
|
|
@app.post("/api/pipelines/upload")
|
|
@app.post("/api/pipelines/upload")
|
|
async def upload_pipeline(
|
|
async def upload_pipeline(
|
|
- urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
|
|
|
|
|
+ urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
|
):
|
|
):
|
|
print("upload_pipeline", urlIdx, file.filename)
|
|
print("upload_pipeline", urlIdx, file.filename)
|
|
# Check if the uploaded file is a python file
|
|
# Check if the uploaded file is a python file
|
|
@@ -1904,9 +1899,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
|
|
|
|
|
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
|
async def get_pipeline_valves(
|
|
async def get_pipeline_valves(
|
|
- urlIdx: Optional[int],
|
|
|
|
- pipeline_id: str,
|
|
|
|
- user=Depends(get_admin_user),
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -1942,9 +1937,9 @@ async def get_pipeline_valves(
|
|
|
|
|
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
|
async def get_pipeline_valves_spec(
|
|
async def get_pipeline_valves_spec(
|
|
- urlIdx: Optional[int],
|
|
|
|
- pipeline_id: str,
|
|
|
|
- user=Depends(get_admin_user),
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -1979,10 +1974,10 @@ async def get_pipeline_valves_spec(
|
|
|
|
|
|
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
|
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
|
async def update_pipeline_valves(
|
|
async def update_pipeline_valves(
|
|
- urlIdx: Optional[int],
|
|
|
|
- pipeline_id: str,
|
|
|
|
- form_data: dict,
|
|
|
|
- user=Depends(get_admin_user),
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ form_data: dict,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -2106,7 +2101,7 @@ class ModelFilterConfigForm(BaseModel):
|
|
|
|
|
|
@app.post("/api/config/model/filter")
|
|
@app.post("/api/config/model/filter")
|
|
async def update_model_filter_config(
|
|
async def update_model_filter_config(
|
|
- form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
|
|
|
|
|
+ form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
|
):
|
|
):
|
|
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
|
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
|
app.state.config.MODEL_FILTER_LIST = form_data.models
|
|
app.state.config.MODEL_FILTER_LIST = form_data.models
|
|
@@ -2155,7 +2150,7 @@ async def get_app_latest_release_version():
|
|
try:
|
|
try:
|
|
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
async with aiohttp.ClientSession(trust_env=True) as session:
|
|
async with session.get(
|
|
async with session.get(
|
|
- "https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
|
|
|
|
|
+ "https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
|
) as response:
|
|
) as response:
|
|
response.raise_for_status()
|
|
response.raise_for_status()
|
|
data = await response.json()
|
|
data = await response.json()
|
|
@@ -2198,6 +2193,53 @@ if len(OAUTH_PROVIDERS) > 0:
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+def get_user_role(user: UserModel, user_data: UserInfo) -> str:
|
|
|
|
+ if user and Users.get_num_users() == 1:
|
|
|
|
+ # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
|
|
|
+ return "admin"
|
|
|
|
+ if not user and Users.get_num_users() == 0:
|
|
|
|
+ # If there are no users, assign the role "admin", as the first user will be an admin
|
|
|
|
+ return "admin"
|
|
|
|
+
|
|
|
|
+ if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT:
|
|
|
|
+ oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
|
|
|
|
+ oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES
|
|
|
|
+ oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES
|
|
|
|
+ oauth_roles = None
|
|
|
|
+ role = "pending" # Default/fallback role if no matching roles are found
|
|
|
|
+
|
|
|
|
+ # Next block extracts the roles from the user data, accepting nested claims of any depth
|
|
|
|
+ if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
|
|
|
|
+ claim_data = user_data
|
|
|
|
+ nested_claims = oauth_claim.split(".")
|
|
|
|
+ for nested_claim in nested_claims:
|
|
|
|
+ claim_data = claim_data.get(nested_claim, {})
|
|
|
|
+ oauth_roles = claim_data if isinstance(claim_data, list) else None
|
|
|
|
+
|
|
|
|
+ # If any roles are found, check if they match the allowed or admin roles
|
|
|
|
+ if oauth_roles:
|
|
|
|
+ # If role management is enabled, and matching roles are provided, use the roles
|
|
|
|
+ for allowed_role in oauth_allowed_roles:
|
|
|
|
+ # If the user has any of the allowed roles, assign the role "user"
|
|
|
|
+ if allowed_role in oauth_roles:
|
|
|
|
+ role = "user"
|
|
|
|
+ break
|
|
|
|
+ for admin_role in oauth_admin_roles:
|
|
|
|
+ # If the user has any of the admin roles, assign the role "admin"
|
|
|
|
+ if admin_role in oauth_roles:
|
|
|
|
+ role = "admin"
|
|
|
|
+ break
|
|
|
|
+ else:
|
|
|
|
+ if not user:
|
|
|
|
+ # If role management is disabled, use the default role for new users
|
|
|
|
+ role = webui_app.state.config.DEFAULT_USER_ROLE
|
|
|
|
+ else:
|
|
|
|
+ # If role management is disabled, use the existing role for existing users
|
|
|
|
+ role = user.role
|
|
|
|
+
|
|
|
|
+ return role
|
|
|
|
+
|
|
|
|
+
|
|
@app.get("/oauth/{provider}/login")
|
|
@app.get("/oauth/{provider}/login")
|
|
async def oauth_login(provider: str, request: Request):
|
|
async def oauth_login(provider: str, request: Request):
|
|
if provider not in OAUTH_PROVIDERS:
|
|
if provider not in OAUTH_PROVIDERS:
|
|
@@ -2244,34 +2286,6 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|
|
|
|
|
# Check if the user exists
|
|
# Check if the user exists
|
|
user = Users.get_user_by_oauth_sub(provider_sub)
|
|
user = Users.get_user_by_oauth_sub(provider_sub)
|
|
- # print all user data content for debugging
|
|
|
|
- log.info(f"User data: {user_data}")
|
|
|
|
- if user:
|
|
|
|
- role = user.role
|
|
|
|
- if Users.get_num_users() == 1:
|
|
|
|
- role = "admin"
|
|
|
|
- elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING:
|
|
|
|
- oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
|
|
|
|
- oauth_roles = None
|
|
|
|
-
|
|
|
|
- if oauth_claim:
|
|
|
|
- claim_data = user_data
|
|
|
|
- nested_claims = oauth_claim.split(".")
|
|
|
|
- for nested_claim in nested_claims:
|
|
|
|
- claim_data = claim_data.get(nested_claim, {})
|
|
|
|
- oauth_roles = claim_data if isinstance(claim_data, list) else None
|
|
|
|
-
|
|
|
|
- log.info(f"User {user.name} has OAuth roles: {oauth_roles}")
|
|
|
|
- if oauth_roles:
|
|
|
|
- for allowed_role in ["pending", "user", "admin"]:
|
|
|
|
- role = allowed_role if allowed_role in oauth_roles else role
|
|
|
|
- else:
|
|
|
|
- # If role mapping is enabled, but no roles are provided, fall back to pending
|
|
|
|
- role = "pending"
|
|
|
|
- log.info(f"Applied role: {role} to user {user.name}")
|
|
|
|
-
|
|
|
|
- if role != user.role:
|
|
|
|
- Users.update_user_role_by_id(user.id, role)
|
|
|
|
|
|
|
|
if not user:
|
|
if not user:
|
|
# If the user does not exist, check if merging is enabled
|
|
# If the user does not exist, check if merging is enabled
|
|
@@ -2282,6 +2296,11 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|
# Update the user with the new oauth sub
|
|
# Update the user with the new oauth sub
|
|
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
|
|
|
|
|
|
|
+ if user:
|
|
|
|
+ determined_role = get_user_role(user, user_data)
|
|
|
|
+ if user.role != determined_role:
|
|
|
|
+ Users.update_user_role_by_id(user.id, determined_role)
|
|
|
|
+
|
|
if not user:
|
|
if not user:
|
|
# If the user does not exist, check if signups are enabled
|
|
# If the user does not exist, check if signups are enabled
|
|
if ENABLE_OAUTH_SIGNUP.value:
|
|
if ENABLE_OAUTH_SIGNUP.value:
|
|
@@ -2313,17 +2332,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|
picture_url = "/user.png"
|
|
picture_url = "/user.png"
|
|
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
|
|
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
|
|
|
|
|
|
- role = webui_app.state.config.DEFAULT_USER_ROLE
|
|
|
|
- if Users.get_num_users() == 0:
|
|
|
|
- role = "admin"
|
|
|
|
- elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING:
|
|
|
|
- oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM)
|
|
|
|
- if oauth_roles:
|
|
|
|
- for allowed_role in ["pending", "user", "admin"]:
|
|
|
|
- role = allowed_role if allowed_role in oauth_roles else role
|
|
|
|
- else:
|
|
|
|
- # If role mapping is enabled, but no roles are provided, fall back to pending
|
|
|
|
- role = "pending"
|
|
|
|
|
|
+ role = get_user_role(None, user_data)
|
|
|
|
|
|
user = Auths.insert_new_auth(
|
|
user = Auths.insert_new_auth(
|
|
email=email,
|
|
email=email,
|