auth.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import logging
  2. import uuid
  3. import jwt
  4. import base64
  5. import hmac
  6. import hashlib
  7. import requests
  8. import os
  9. from datetime import UTC, datetime, timedelta
  10. from typing import Optional, Union, List, Dict
  11. from open_webui.models.users import Users
  12. from open_webui.constants import ERROR_MESSAGES
  13. from open_webui.env import (
  14. WEBUI_SECRET_KEY,
  15. TRUSTED_SIGNATURE_KEY,
  16. STATIC_DIR,
  17. SRC_LOG_LEVELS,
  18. )
  19. from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
  20. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  21. from passlib.context import CryptContext
  22. logging.getLogger("passlib").setLevel(logging.ERROR)
  23. log = logging.getLogger(__name__)
  24. log.setLevel(SRC_LOG_LEVELS["OAUTH"])
  25. SESSION_SECRET = WEBUI_SECRET_KEY
  26. ALGORITHM = "HS256"
  27. ##############
  28. # Auth Utils
  29. ##############
  30. def verify_signature(payload: str, signature: str) -> bool:
  31. """
  32. Verifies the HMAC signature of the received payload.
  33. """
  34. try:
  35. expected_signature = base64.b64encode(
  36. hmac.new(TRUSTED_SIGNATURE_KEY, payload.encode(), hashlib.sha256).digest()
  37. ).decode()
  38. # Compare securely to prevent timing attacks
  39. return hmac.compare_digest(expected_signature, signature)
  40. except Exception:
  41. return False
  42. def override_static(path: str, content: str):
  43. # Ensure path is safe
  44. if "/" in path or ".." in path:
  45. log.error(f"Invalid path: {path}")
  46. return
  47. file_path = os.path.join(STATIC_DIR, path)
  48. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  49. with open(file_path, "wb") as f:
  50. f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
  51. def get_license_data(app, key):
  52. if key:
  53. try:
  54. res = requests.post(
  55. "https://api.openwebui.com/api/v1/license",
  56. json={"key": key, "version": "1"},
  57. timeout=5,
  58. )
  59. if getattr(res, "ok", False):
  60. payload = getattr(res, "json", lambda: {})()
  61. for k, v in payload.items():
  62. if k == "resources":
  63. for p, c in v.items():
  64. globals().get("override_static", lambda a, b: None)(p, c)
  65. elif k == "count":
  66. setattr(app.state, "USER_COUNT", v)
  67. elif k == "name":
  68. setattr(app.state, "WEBUI_NAME", v)
  69. elif k == "metadata":
  70. setattr(app.state, "LICENSE_METADATA", v)
  71. return True
  72. else:
  73. log.error(
  74. f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
  75. )
  76. except Exception as ex:
  77. log.exception(f"License: Uncaught Exception: {ex}")
  78. return False
  79. bearer_security = HTTPBearer(auto_error=False)
  80. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  81. def verify_password(plain_password, hashed_password):
  82. return (
  83. pwd_context.verify(plain_password, hashed_password) if hashed_password else None
  84. )
  85. def get_password_hash(password):
  86. return pwd_context.hash(password)
  87. def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
  88. payload = data.copy()
  89. if expires_delta:
  90. expire = datetime.now(UTC) + expires_delta
  91. payload.update({"exp": expire})
  92. encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  93. return encoded_jwt
  94. def decode_token(token: str) -> Optional[dict]:
  95. try:
  96. decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
  97. return decoded
  98. except Exception:
  99. return None
  100. def extract_token_from_auth_header(auth_header: str):
  101. return auth_header[len("Bearer ") :]
  102. def create_api_key():
  103. key = str(uuid.uuid4()).replace("-", "")
  104. return f"sk-{key}"
  105. def get_http_authorization_cred(auth_header: str):
  106. try:
  107. scheme, credentials = auth_header.split(" ")
  108. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  109. except Exception:
  110. raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
  111. def get_current_user(
  112. request: Request,
  113. background_tasks: BackgroundTasks,
  114. auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  115. ):
  116. token = None
  117. if auth_token is not None:
  118. token = auth_token.credentials
  119. if token is None and "token" in request.cookies:
  120. token = request.cookies.get("token")
  121. if token is None:
  122. raise HTTPException(status_code=403, detail="Not authenticated")
  123. # auth by api key
  124. if token.startswith("sk-"):
  125. if not request.state.enable_api_key:
  126. raise HTTPException(
  127. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
  128. )
  129. if request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
  130. allowed_paths = [
  131. path.strip()
  132. for path in str(
  133. request.app.state.config.API_KEY_ALLOWED_ENDPOINTS
  134. ).split(",")
  135. ]
  136. if request.url.path not in allowed_paths:
  137. raise HTTPException(
  138. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
  139. )
  140. return get_current_user_by_api_key(token)
  141. # auth by jwt token
  142. try:
  143. data = decode_token(token)
  144. except Exception as e:
  145. raise HTTPException(
  146. status_code=status.HTTP_401_UNAUTHORIZED,
  147. detail="Invalid token",
  148. )
  149. if data is not None and "id" in data:
  150. user = Users.get_user_by_id(data["id"])
  151. if user is None:
  152. raise HTTPException(
  153. status_code=status.HTTP_401_UNAUTHORIZED,
  154. detail=ERROR_MESSAGES.INVALID_TOKEN,
  155. )
  156. else:
  157. # Refresh the user's last active timestamp asynchronously
  158. # to prevent blocking the request
  159. if background_tasks:
  160. background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
  161. return user
  162. else:
  163. raise HTTPException(
  164. status_code=status.HTTP_401_UNAUTHORIZED,
  165. detail=ERROR_MESSAGES.UNAUTHORIZED,
  166. )
  167. def get_current_user_by_api_key(api_key: str):
  168. user = Users.get_user_by_api_key(api_key)
  169. if user is None:
  170. raise HTTPException(
  171. status_code=status.HTTP_401_UNAUTHORIZED,
  172. detail=ERROR_MESSAGES.INVALID_TOKEN,
  173. )
  174. else:
  175. Users.update_user_last_active_by_id(user.id)
  176. return user
  177. def get_verified_user(user=Depends(get_current_user)):
  178. if user.role not in {"user", "admin"}:
  179. raise HTTPException(
  180. status_code=status.HTTP_401_UNAUTHORIZED,
  181. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  182. )
  183. return user
  184. def get_admin_user(user=Depends(get_current_user)):
  185. if user.role != "admin":
  186. raise HTTPException(
  187. status_code=status.HTTP_401_UNAUTHORIZED,
  188. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  189. )
  190. return user