Timothy Jaeryang Baek hai 2 meses
pai
achega
63cf80a456

+ 1 - 0
backend/open_webui/env.py

@@ -113,6 +113,7 @@ if WEBUI_NAME != "Open WebUI":
 
 
 WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
 WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
 
 
+TRUSTED_SIGNATURE_KEY = os.environ.get("TRUSTED_SIGNATURE_KEY", "")
 
 
 ####################################
 ####################################
 # ENV (dev,test,prod)
 # ENV (dev,test,prod)

+ 54 - 9
backend/open_webui/main.py

@@ -88,6 +88,7 @@ from open_webui.models.models import Models
 from open_webui.models.users import UserModel, Users
 from open_webui.models.users import UserModel, Users
 
 
 from open_webui.config import (
 from open_webui.config import (
+    LICENSE_KEY,
     # Ollama
     # Ollama
     ENABLE_OLLAMA_API,
     ENABLE_OLLAMA_API,
     OLLAMA_BASE_URLS,
     OLLAMA_BASE_URLS,
@@ -314,15 +315,17 @@ from open_webui.utils.middleware import process_chat_payload, process_chat_respo
 from open_webui.utils.access_control import has_access
 from open_webui.utils.access_control import has_access
 
 
 from open_webui.utils.auth import (
 from open_webui.utils.auth import (
+    verify_signature,
     decode_token,
     decode_token,
     get_admin_user,
     get_admin_user,
     get_verified_user,
     get_verified_user,
 )
 )
-from open_webui.utils.oauth import oauth_manager
+from open_webui.utils.oauth import OAuthManager
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 
 
 from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
 from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
 
 
+
 if SAFE_MODE:
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
     print("SAFE MODE ENABLED")
     Functions.deactivate_all_functions()
     Functions.deactivate_all_functions()
@@ -369,10 +372,47 @@ async def lifespan(app: FastAPI):
     if RESET_CONFIG_ON_START:
     if RESET_CONFIG_ON_START:
         reset_config()
         reset_config()
 
 
+    license_key = app.state.config.LICENSE_KEY
+    if license_key:
+        try:
+            response = requests.post(
+                "https://api.openwebui.com/api/v1/license",
+                json={"key": license_key, "version": "1"},
+                timeout=5,
+            )
+            if response.ok:
+                data = response.json()
+                if "payload" in data and "auth" in data:
+                    if verify_signature(data["payload"], data["auth"]):
+                        exec(
+                            data["payload"],
+                            {
+                                "__builtins__": {},
+                                "override_static": override_static,
+                                "USER_COUNT": app.state.USER_COUNT,
+                                "WEBUI_NAME": app.state.WEBUI_NAME,
+                            },
+                        )  # noqa
+            else:
+                log.error(f"Error fetching license: {response.text}")
+        except Exception as e:
+            log.error(f"Error during license check: {e}")
+            pass
+
     asyncio.create_task(periodic_usage_pool_cleanup())
     asyncio.create_task(periodic_usage_pool_cleanup())
     yield
     yield
 
 
 
 
+def override_static(path: str, content: str):
+    # Ensure path is safe
+    if "/" in path:
+        log.error(f"Invalid path: {path}")
+        return
+
+    with open(f"{STATIC_DIR}/{path}", "wb") as f:
+        shutil.copyfileobj(content, f)
+
+
 app = FastAPI(
 app = FastAPI(
     docs_url="/docs" if ENV == "dev" else None,
     docs_url="/docs" if ENV == "dev" else None,
     openapi_url="/openapi.json" if ENV == "dev" else None,
     openapi_url="/openapi.json" if ENV == "dev" else None,
@@ -380,8 +420,13 @@ app = FastAPI(
     lifespan=lifespan,
     lifespan=lifespan,
 )
 )
 
 
+oauth_manager = OAuthManager(app)
+
 app.state.config = AppConfig()
 app.state.config = AppConfig()
 
 
+app.state.config.LICENSE_KEY = LICENSE_KEY
+
+app.state.WEBUI_NAME = WEBUI_NAME
 
 
 ########################################
 ########################################
 #
 #
@@ -483,10 +528,10 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
 app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
 app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
 
 
+app.state.USER_COUNT = None
 app.state.TOOLS = {}
 app.state.TOOLS = {}
 app.state.FUNCTIONS = {}
 app.state.FUNCTIONS = {}
 
 
-
 ########################################
 ########################################
 #
 #
 # RETRIEVAL
 # RETRIEVAL
@@ -1071,7 +1116,7 @@ async def get_app_config(request: Request):
     return {
     return {
         **({"onboarding": True} if onboarding else {}),
         **({"onboarding": True} if onboarding else {}),
         "status": True,
         "status": True,
-        "name": WEBUI_NAME,
+        "name": app.state.WEBUI_NAME,
         "version": VERSION,
         "version": VERSION,
         "default_locale": str(DEFAULT_LOCALE),
         "default_locale": str(DEFAULT_LOCALE),
         "oauth": {
         "oauth": {
@@ -1206,7 +1251,7 @@ if len(OAUTH_PROVIDERS) > 0:
 
 
 @app.get("/oauth/{provider}/login")
 @app.get("/oauth/{provider}/login")
 async def oauth_login(provider: str, request: Request):
 async def oauth_login(provider: str, request: Request):
-    return await oauth_manager.handle_login(provider, request)
+    return await oauth_manager.handle_login(request, provider)
 
 
 
 
 # OAuth login logic is as follows:
 # OAuth login logic is as follows:
@@ -1217,14 +1262,14 @@ async def oauth_login(provider: str, request: Request):
 #    - Email addresses are considered unique, so we fail registration if the email address is already taken
 #    - Email addresses are considered unique, so we fail registration if the email address is already taken
 @app.get("/oauth/{provider}/callback")
 @app.get("/oauth/{provider}/callback")
 async def oauth_callback(provider: str, request: Request, response: Response):
 async def oauth_callback(provider: str, request: Request, response: Response):
-    return await oauth_manager.handle_callback(provider, request, response)
+    return await oauth_manager.handle_callback(request, provider, response)
 
 
 
 
 @app.get("/manifest.json")
 @app.get("/manifest.json")
 async def get_manifest_json():
 async def get_manifest_json():
     return {
     return {
-        "name": WEBUI_NAME,
-        "short_name": WEBUI_NAME,
+        "name": app.state.WEBUI_NAME,
+        "short_name": app.state.WEBUI_NAME,
         "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
         "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
         "start_url": "/",
         "start_url": "/",
         "display": "standalone",
         "display": "standalone",
@@ -1251,8 +1296,8 @@ async def get_manifest_json():
 async def get_opensearch_xml():
 async def get_opensearch_xml():
     xml_content = rf"""
     xml_content = rf"""
     <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
     <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
-    <ShortName>{WEBUI_NAME}</ShortName>
-    <Description>Search {WEBUI_NAME}</Description>
+    <ShortName>{app.state.WEBUI_NAME}</ShortName>
+    <Description>Search {app.state.WEBUI_NAME}</Description>
     <InputEncoding>UTF-8</InputEncoding>
     <InputEncoding>UTF-8</InputEncoding>
     <Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
     <Image width="16" height="16" type="image/x-icon">{app.state.config.WEBUI_URL}/static/favicon.png</Image>
     <Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>
     <Url type="text/html" method="get" template="{app.state.config.WEBUI_URL}/?q={"{searchTerms}"}"/>

+ 22 - 6
backend/open_webui/routers/auths.py

@@ -251,9 +251,19 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             user = Users.get_user_by_email(mail)
             user = Users.get_user_by_email(mail)
             if not user:
             if not user:
                 try:
                 try:
+                    user_count = Users.get_num_users()
+                    if (
+                        request.app.state.USER_COUNT
+                        and user_count >= request.app.state.USER_COUNT
+                    ):
+                        raise HTTPException(
+                            status.HTTP_403_FORBIDDEN,
+                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+                        )
+
                     role = (
                     role = (
                         "admin"
                         "admin"
-                        if Users.get_num_users() == 0
+                        if user_count == 0
                         else request.app.state.config.DEFAULT_USER_ROLE
                         else request.app.state.config.DEFAULT_USER_ROLE
                     )
                     )
 
 
@@ -413,6 +423,8 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
 
 
 @router.post("/signup", response_model=SessionUserResponse)
 @router.post("/signup", response_model=SessionUserResponse)
 async def signup(request: Request, response: Response, form_data: SignupForm):
 async def signup(request: Request, response: Response, form_data: SignupForm):
+    user_count = Users.get_num_users()
+
     if WEBUI_AUTH:
     if WEBUI_AUTH:
         if (
         if (
             not request.app.state.config.ENABLE_SIGNUP
             not request.app.state.config.ENABLE_SIGNUP
@@ -422,11 +434,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             )
             )
     else:
     else:
-        if Users.get_num_users() != 0:
+        if user_count != 0:
             raise HTTPException(
             raise HTTPException(
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                 status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             )
             )
 
 
+    if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT:
+        raise HTTPException(
+            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+        )
+
     if not validate_email_format(form_data.email.lower()):
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -437,12 +454,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
     try:
     try:
         role = (
         role = (
-            "admin"
-            if Users.get_num_users() == 0
-            else request.app.state.config.DEFAULT_USER_ROLE
+            "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
         )
         )
 
 
-        if Users.get_num_users() == 0:
+        if user_count == 0:
             # Disable signup after the first user is created
             # Disable signup after the first user is created
             request.app.state.config.ENABLE_SIGNUP = False
             request.app.state.config.ENABLE_SIGNUP = False
 
 
@@ -484,6 +499,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
             if request.app.state.config.WEBHOOK_URL:
             if request.app.state.config.WEBHOOK_URL:
                 post_webhook(
                 post_webhook(
+                    request.app.state.WEBUI_NAME,
                     request.app.state.config.WEBHOOK_URL,
                     request.app.state.config.WEBHOOK_URL,
                     WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                     WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                     {
                     {

+ 3 - 1
backend/open_webui/routers/channels.py

@@ -192,7 +192,7 @@ async def get_channel_messages(
 ############################
 ############################
 
 
 
 
-async def send_notification(webui_url, channel, message, active_user_ids):
+async def send_notification(name, webui_url, channel, message, active_user_ids):
     users = get_users_with_access("read", channel.access_control)
     users = get_users_with_access("read", channel.access_control)
 
 
     for user in users:
     for user in users:
@@ -206,6 +206,7 @@ async def send_notification(webui_url, channel, message, active_user_ids):
 
 
                 if webhook_url:
                 if webhook_url:
                     post_webhook(
                     post_webhook(
+                        name,
                         webhook_url,
                         webhook_url,
                         f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
                         f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
                         {
                         {
@@ -302,6 +303,7 @@ async def post_new_message(
 
 
             background_tasks.add_task(
             background_tasks.add_task(
                 send_notification,
                 send_notification,
+                request.app.state.WEBUI_NAME,
                 request.app.state.config.WEBUI_URL,
                 request.app.state.config.WEBUI_URL,
                 channel,
                 channel,
                 message,
                 message,

+ 21 - 1
backend/open_webui/utils/auth.py

@@ -1,6 +1,9 @@
 import logging
 import logging
 import uuid
 import uuid
 import jwt
 import jwt
+import base64
+import hmac
+import hashlib
 
 
 from datetime import UTC, datetime, timedelta
 from datetime import UTC, datetime, timedelta
 from typing import Optional, Union, List, Dict
 from typing import Optional, Union, List, Dict
@@ -8,7 +11,7 @@ from typing import Optional, Union, List, Dict
 from open_webui.models.users import Users
 from open_webui.models.users import Users
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import WEBUI_SECRET_KEY
+from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
 
 
 from fastapi import Depends, HTTPException, Request, Response, status
 from fastapi import Depends, HTTPException, Request, Response, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
@@ -24,6 +27,23 @@ ALGORITHM = "HS256"
 # Auth Utils
 # Auth Utils
 ##############
 ##############
 
 
+
+def verify_signature(payload: str, signature: str) -> bool:
+    """
+    Verifies the HMAC signature of the received payload.
+    """
+    try:
+        expected_signature = base64.b64encode(
+            hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
+        ).decode()
+
+        # Compare securely to prevent timing attacks
+        return hmac.compare_digest(expected_signature, signature)
+
+    except Exception:
+        return False
+
+
 bearer_security = HTTPBearer(auto_error=False)
 bearer_security = HTTPBearer(auto_error=False)
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
 
 

+ 2 - 0
backend/open_webui/utils/middleware.py

@@ -1008,6 +1008,7 @@ async def process_chat_response(
                         webhook_url = Users.get_user_webhook_url_by_id(user.id)
                         webhook_url = Users.get_user_webhook_url_by_id(user.id)
                         if webhook_url:
                         if webhook_url:
                             post_webhook(
                             post_webhook(
+                                request.app.state.WEBUI_NAME,
                                 webhook_url,
                                 webhook_url,
                                 f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                                 f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                                 {
                                 {
@@ -1873,6 +1874,7 @@ async def process_chat_response(
                     webhook_url = Users.get_user_webhook_url_by_id(user.id)
                     webhook_url = Users.get_user_webhook_url_by_id(user.id)
                     if webhook_url:
                     if webhook_url:
                         post_webhook(
                         post_webhook(
+                            request.app.state.WEBUI_NAME,
                             webhook_url,
                             webhook_url,
                             f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                             f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                             {
                             {

+ 21 - 7
backend/open_webui/utils/oauth.py

@@ -36,7 +36,11 @@ from open_webui.config import (
     AppConfig,
     AppConfig,
 )
 )
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
-from open_webui.env import WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE
+from open_webui.env import (
+    WEBUI_NAME,
+    WEBUI_AUTH_COOKIE_SAME_SITE,
+    WEBUI_AUTH_COOKIE_SECURE,
+)
 from open_webui.utils.misc import parse_duration
 from open_webui.utils.misc import parse_duration
 from open_webui.utils.auth import get_password_hash, create_token
 from open_webui.utils.auth import get_password_hash, create_token
 from open_webui.utils.webhook import post_webhook
 from open_webui.utils.webhook import post_webhook
@@ -66,8 +70,9 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
 
 
 
 
 class OAuthManager:
 class OAuthManager:
-    def __init__(self):
+    def __init__(self, app):
         self.oauth = OAuth()
         self.oauth = OAuth()
+        self.app = app
         for _, provider_config in OAUTH_PROVIDERS.items():
         for _, provider_config in OAUTH_PROVIDERS.items():
             provider_config["register"](self.oauth)
             provider_config["register"](self.oauth)
 
 
@@ -200,7 +205,7 @@ class OAuthManager:
                     id=group_model.id, form_data=update_form, overwrite=False
                     id=group_model.id, form_data=update_form, overwrite=False
                 )
                 )
 
 
-    async def handle_login(self, provider, request):
+    async def handle_login(self, request, provider):
         if provider not in OAUTH_PROVIDERS:
         if provider not in OAUTH_PROVIDERS:
             raise HTTPException(404)
             raise HTTPException(404)
         # If the provider has a custom redirect URL, use that, otherwise automatically generate one
         # If the provider has a custom redirect URL, use that, otherwise automatically generate one
@@ -212,7 +217,7 @@ class OAuthManager:
             raise HTTPException(404)
             raise HTTPException(404)
         return await client.authorize_redirect(request, redirect_uri)
         return await client.authorize_redirect(request, redirect_uri)
 
 
-    async def handle_callback(self, provider, request, response):
+    async def handle_callback(self, request, provider, response):
         if provider not in OAUTH_PROVIDERS:
         if provider not in OAUTH_PROVIDERS:
             raise HTTPException(404)
             raise HTTPException(404)
         client = self.get_client(provider)
         client = self.get_client(provider)
@@ -266,6 +271,17 @@ class OAuthManager:
                 Users.update_user_role_by_id(user.id, determined_role)
                 Users.update_user_role_by_id(user.id, determined_role)
 
 
         if not user:
         if not user:
+            user_count = Users.get_num_users()
+
+            if (
+                request.app.state.USER_COUNT
+                and user_count >= request.app.state.USER_COUNT
+            ):
+                raise HTTPException(
+                    403,
+                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+                )
+
             # If the user does not exist, check if signups are enabled
             # If the user does not exist, check if signups are enabled
             if auth_manager_config.ENABLE_OAUTH_SIGNUP:
             if auth_manager_config.ENABLE_OAUTH_SIGNUP:
                 # Check if an existing user with the same email already exists
                 # Check if an existing user with the same email already exists
@@ -334,6 +350,7 @@ class OAuthManager:
 
 
                 if auth_manager_config.WEBHOOK_URL:
                 if auth_manager_config.WEBHOOK_URL:
                     post_webhook(
                     post_webhook(
+                        WEBUI_NAME,
                         auth_manager_config.WEBHOOK_URL,
                         auth_manager_config.WEBHOOK_URL,
                         WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                         WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                         {
                         {
@@ -380,6 +397,3 @@ class OAuthManager:
         # Redirect back to the frontend with the JWT token
         # Redirect back to the frontend with the JWT token
         redirect_url = f"{request.base_url}auth#token={jwt_token}"
         redirect_url = f"{request.base_url}auth#token={jwt_token}"
         return RedirectResponse(url=redirect_url, headers=response.headers)
         return RedirectResponse(url=redirect_url, headers=response.headers)
-
-
-oauth_manager = OAuthManager()

+ 3 - 3
backend/open_webui/utils/webhook.py

@@ -2,14 +2,14 @@ import json
 import logging
 import logging
 
 
 import requests
 import requests
-from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
+from open_webui.config import WEBUI_FAVICON_URL
 from open_webui.env import SRC_LOG_LEVELS, VERSION
 from open_webui.env import SRC_LOG_LEVELS, VERSION
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
 log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
 
 
 
 
-def post_webhook(url: str, message: str, event_data: dict) -> bool:
+def post_webhook(name: str, url: str, message: str, event_data: dict) -> bool:
     try:
     try:
         log.debug(f"post_webhook: {url}, {message}, {event_data}")
         log.debug(f"post_webhook: {url}, {message}, {event_data}")
         payload = {}
         payload = {}
@@ -39,7 +39,7 @@ def post_webhook(url: str, message: str, event_data: dict) -> bool:
                 "sections": [
                 "sections": [
                     {
                     {
                         "activityTitle": message,
                         "activityTitle": message,
-                        "activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
+                        "activitySubtitle": f"{name} ({VERSION}) - {action}",
                         "activityImage": WEBUI_FAVICON_URL,
                         "activityImage": WEBUI_FAVICON_URL,
                         "facts": facts,
                         "facts": facts,
                         "markdown": True,
                         "markdown": True,