auth.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import logging
  2. import uuid
  3. import jwt
  4. import base64
  5. import hmac
  6. import hashlib
  7. from datetime import UTC, datetime, timedelta
  8. from typing import Optional, Union, List, Dict
  9. from open_webui.models.users import Users
  10. from open_webui.constants import ERROR_MESSAGES
  11. from open_webui.env import WEBUI_SECRET_KEY, TRUSTED_SIGNATURE_KEY
  12. from fastapi import Depends, HTTPException, Request, Response, status
  13. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  14. from passlib.context import CryptContext
  # Only ERROR and above from passlib's logger is let through.
  logging.getLogger("passlib").setLevel(logging.ERROR)

  # Signing key and algorithm used by create_token() / decode_token().
  SESSION_SECRET = WEBUI_SECRET_KEY
  ALGORITHM = "HS256"
  18. ##############
  19. # Auth Utils
  20. ##############
  def verify_signature(payload: str, signature: str) -> bool:
      """
      Verifies the HMAC signature of the received payload.

      The expected value is the base64-encoded HMAC-SHA256 digest of the
      encoded payload, keyed with ``TRUSTED_SIGNATURE_KEY``.

      Args:
          payload: The text that was signed.
          signature: The base64-encoded signature supplied with the payload.

      Returns:
          True when the supplied signature matches the expected one; False
          when it does not, or when the inputs cannot be processed (the
          check fails closed instead of raising).
      """
      try:
          # hmac.new() only accepts a bytes-like key. A key configured as
          # text would otherwise raise TypeError and, because errors are
          # mapped to False, silently reject every signature.
          key = TRUSTED_SIGNATURE_KEY
          if isinstance(key, str):
              key = key.encode()

          digest = hmac.new(key, payload.encode(), hashlib.sha256).digest()
          expected_signature = base64.b64encode(digest)

          # Compare securely (constant time, as bytes) to prevent timing attacks
          return hmac.compare_digest(expected_signature, signature.encode())
      except (AttributeError, TypeError, ValueError):
          # Non-string inputs, an unusable key, or un-encodable text.
          return False
  # Bearer-token dependency used by get_current_user(), which handles a
  # missing credential (None) itself -- hence auto_error=False.
  bearer_security = HTTPBearer(auto_error=False)

  # passlib context (bcrypt scheme) shared by verify_password() and
  # get_password_hash().
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  def verify_password(plain_password, hashed_password):
      """
      Check a plain-text password against a stored hash.

      Returns:
          The result of ``pwd_context.verify`` when a hash is supplied, or
          None when ``hashed_password`` is empty/None (nothing to compare).
      """
      if not hashed_password:
          return None
      return pwd_context.verify(plain_password, hashed_password)
  def get_password_hash(password):
      """
      Hash ``password`` with the module-level passlib context and return the
      resulting hash.
      """
      hashed_password = pwd_context.hash(password)
      return hashed_password
  def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
      """
      Encode ``data`` as a JWT signed with ``SESSION_SECRET`` / ``ALGORITHM``.

      Args:
          data: Claims to embed; the dict is copied, not modified.
          expires_delta: When truthy, an ``exp`` claim of "now (UTC) plus this
              delta" is added. When omitted, no ``exp`` claim is set.

      Returns:
          The encoded token.
      """
      claims = data.copy()

      # NOTE(review): this is a truthiness test, so timedelta(0) also yields a
      # token without an "exp" claim -- confirm that is intended.
      if expires_delta:
          claims["exp"] = datetime.now(UTC) + expires_delta

      return jwt.encode(claims, SESSION_SECRET, algorithm=ALGORITHM)
  def decode_token(token: str) -> Optional[dict]:
      """
      Decode a JWT using ``SESSION_SECRET`` and ``ALGORITHM``.

      Returns:
          The decoded claims, or None if decoding raises for any reason.
      """
      try:
          return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
      except Exception:
          # Any failure is reported to the caller as "no claims".
          return None
  def extract_token_from_auth_header(auth_header: str):
      """
      Return everything after the leading ``"Bearer "`` portion of the header.

      The prefix is removed purely by length; its content is not checked.
      """
      prefix_length = len("Bearer ")
      return auth_header[prefix_length:]
  def create_api_key():
      """
      Generate a new API key.

      The key is a credential, so it is drawn from the ``secrets`` module,
      which is intended for security-sensitive tokens, rather than derived
      from a UUID.

      Returns:
          ``"sk-"`` followed by 32 lowercase hexadecimal characters (the same
          shape as the previous UUID-based keys).
      """
      # Local import keeps this block self-contained.
      import secrets

      # 16 random bytes -> 32 hex characters.
      return f"sk-{secrets.token_hex(16)}"
  def get_http_authorization_cred(auth_header: str):
      """
      Build ``HTTPAuthorizationCredentials`` from a raw Authorization header.

      The header is split on single spaces and must produce exactly two
      parts: the scheme and the credentials.

      Raises:
          ValueError: with ``ERROR_MESSAGES.INVALID_TOKEN`` when the header
              cannot be split that way or the credentials object cannot be
              built.
      """
      try:
          parts = auth_header.split(" ")
          # Unpacking enforces "exactly two parts".
          scheme, credentials = parts
          return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
      except Exception:
          raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
  def get_current_user(
      request: Request,
      auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  ):
      """
      Resolve the user making the request from an API key or a JWT.

      The token is taken from the bearer credentials when present, otherwise
      from the ``token`` cookie. Tokens starting with ``"sk-"`` are treated as
      API keys; everything else is decoded as a JWT whose ``id`` claim names
      the user.

      Raises:
          HTTPException: 403 when no token is supplied, when API keys are
              disabled, or when endpoint restrictions are on and the request
              path is not in the allowed list; 401 when the JWT has no usable
              ``id`` claim or no matching user exists.

      Returns:
          The user record; its last-active marker is updated on success.
      """
      # Bearer credentials take precedence; the cookie is the fallback.
      token = auth_token.credentials if auth_token is not None else None
      if token is None:
          token = request.cookies.get("token")

      if token is None:
          raise HTTPException(status_code=403, detail="Not authenticated")

      # auth by api key
      if token.startswith("sk-"):
          if not request.state.enable_api_key:
              raise HTTPException(
                  status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
              )

          config = request.app.state.config
          if config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS:
              # Allowed endpoints are stored as one comma-separated string.
              raw_endpoints = str(config.API_KEY_ALLOWED_ENDPOINTS)
              allowed_paths = [entry.strip() for entry in raw_endpoints.split(",")]

              if request.url.path not in allowed_paths:
                  raise HTTPException(
                      status.HTTP_403_FORBIDDEN,
                      detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED,
                  )

          return get_current_user_by_api_key(token)

      # auth by jwt token
      try:
          claims = decode_token(token)
      except Exception:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Invalid token",
          )

      if claims is None or "id" not in claims:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      user = Users.get_user_by_id(claims["id"])
      if user is None:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.INVALID_TOKEN,
          )

      Users.update_user_last_active_by_id(user.id)
      return user
  def get_current_user_by_api_key(api_key: str):
      """
      Look up the user that owns ``api_key``.

      Raises:
          HTTPException: 401 with ``ERROR_MESSAGES.INVALID_TOKEN`` when no user
              is found for the key.

      Returns:
          The matching user; its last-active marker is updated first.
      """
      user = Users.get_user_by_api_key(api_key)
      if user is not None:
          Users.update_user_last_active_by_id(user.id)
          return user

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.INVALID_TOKEN,
      )
  def get_verified_user(user=Depends(get_current_user)):
      """
      Dependency that admits only users whose role is ``"user"`` or ``"admin"``.

      Raises:
          HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
              other role.

      Returns:
          The authenticated user, unchanged.
      """
      if user.role in {"user", "admin"}:
          return user

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )
  def get_admin_user(user=Depends(get_current_user)):
      """
      Dependency that admits only users whose role is ``"admin"``.

      Raises:
          HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
              other role.

      Returns:
          The authenticated admin user, unchanged.
      """
      if user.role == "admin":
          return user

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )