utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  2. from fastapi import HTTPException, status, Depends, Request
  3. from apps.webui.models.users import Users
  4. from typing import Union, Optional
  5. from constants import ERROR_MESSAGES
  6. from passlib.context import CryptContext
  7. from datetime import datetime, timedelta, UTC
  8. import jwt
  9. import uuid
  10. import logging
  11. from env import WEBUI_SECRET_KEY
  # Only let passlib log records at ERROR level or above through.
  logging.getLogger("passlib").setLevel(logging.ERROR)

  # JWT signing configuration: HMAC-SHA256 keyed with the WebUI secret.
  ALGORITHM = "HS256"
  SESSION_SECRET = WEBUI_SECRET_KEY

  ##############
  # Auth Utils
  ##############

  # auto_error=False: this dependency does not itself reject requests that lack
  # bearer credentials; get_current_user checks for None and falls back to the
  # "token" cookie.
  bearer_security = HTTPBearer(auto_error=False)

  # Password hashing context using bcrypt; deprecated="auto" marks any scheme
  # other than the default one as deprecated.
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  def verify_password(plain_password, hashed_password):
      """Check ``plain_password`` against a stored ``hashed_password``.

      Returns:
          The result of ``pwd_context.verify`` when a hash is present, or
          ``None`` (falsy) when ``hashed_password`` is empty or ``None``, in
          which case the hasher is not consulted at all.
      """
      if not hashed_password:
          return None
      return pwd_context.verify(plain_password, hashed_password)
  def get_password_hash(password):
      """Return a hash of ``password`` from the module-level bcrypt ``pwd_context``."""
      hashed = pwd_context.hash(password)
      return hashed
  def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
      """Encode ``data`` as a JWT signed with ``SESSION_SECRET`` using ``ALGORITHM``.

      Args:
          data: Claims to embed. The dict is copied first, so the caller's
              object is not mutated.
          expires_delta: Token lifetime. When given (including a zero or
              negative delta), an ``exp`` claim of ``now(UTC) + expires_delta``
              is added. When ``None``, no ``exp`` claim is added by this
              function.

      Returns:
          The encoded JWT.
      """
      payload = data.copy()
      # Compare against None explicitly: ``timedelta(0)`` is falsy, so a plain
      # truthiness test would silently mint a token with no expiry at all.
      if expires_delta is not None:
          payload["exp"] = datetime.now(UTC) + expires_delta
      return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  def decode_token(token: str) -> Optional[dict]:
      """Decode and verify a JWT signed with ``SESSION_SECRET``.

      Returns:
          The decoded claims, or ``None`` if ``jwt.decode`` raises for any
          reason (e.g. bad signature or malformed token).
      """
      try:
          claims = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
      except Exception:
          # Every decoding failure is reported to callers as "no valid token".
          return None
      return claims
  def extract_token_from_auth_header(auth_header: str):
      """Drop the leading ``"Bearer "`` scheme from an Authorization header value.

      The prefix is removed by length only and is not validated: the first
      seven characters are always discarded.
      """
      bearer_prefix_length = len("Bearer ")
      return auth_header[bearer_prefix_length:]
  def create_api_key():
      """Generate a new API key: ``"sk-"`` followed by 32 lowercase hex digits.

      The hex digits are those of a fresh random (version 4) UUID, without dashes.
      """
      return "sk-" + uuid.uuid4().hex
  def get_http_authorization_cred(auth_header: str):
      """Parse a ``"<scheme> <credentials>"`` header value.

      Returns:
          An ``HTTPAuthorizationCredentials`` built from the two parts.

      Raises:
          ValueError: with ``ERROR_MESSAGES.INVALID_TOKEN`` if the value does
              not split on a single space into exactly two parts, or if any
              other step of the parsing fails.
      """
      try:
          scheme, credentials = auth_header.split(" ")
          parsed = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
      except Exception:
          raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
      return parsed
  def get_current_user(
      request: Request,
      auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  ):
      """Resolve the authenticated user for a request (FastAPI dependency).

      The token comes from the bearer credentials when present, otherwise from
      the ``token`` cookie. Tokens starting with ``"sk-"`` are handed to
      ``get_current_user_by_api_key``. Any other token is decoded as a JWT, and
      its ``id`` claim is looked up with ``Users.get_user_by_id``. On success,
      ``Users.update_user_last_active_by_id`` is called and the user is returned.

      Raises:
          HTTPException: 403 when no token is supplied at all; 401 when the JWT
              is undecodable or lacks ``id``, or when no user matches it (the
              API-key path raises its own 401).
      """
      token = auth_token.credentials if auth_token is not None else None
      if token is None:
          # Fall back to the session cookie; .get yields None when it is absent.
          token = request.cookies.get("token")
      if token is None:
          raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authenticated")

      # API-key authentication.
      if token.startswith("sk-"):
          return get_current_user_by_api_key(token)

      # JWT authentication.
      claims = decode_token(token)
      if claims is None or "id" not in claims:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      user = Users.get_user_by_id(claims["id"])
      if user is None:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.INVALID_TOKEN,
          )

      Users.update_user_last_active_by_id(user.id)
      return user
  def get_current_user_by_api_key(api_key: str):
      """Return the user that owns ``api_key``.

      On success, ``Users.update_user_last_active_by_id`` is called for that user.

      Raises:
          HTTPException: 401 with ``ERROR_MESSAGES.INVALID_TOKEN`` when
              ``Users.get_user_by_api_key`` finds no user.
      """
      owner = Users.get_user_by_api_key(api_key)
      if owner is not None:
          Users.update_user_last_active_by_id(owner.id)
          return owner
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.INVALID_TOKEN,
      )
  def get_verified_user(user=Depends(get_current_user)):
      """Return the current user if their role is ``"user"`` or ``"admin"``.

      Raises:
          HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
              other role.
      """
      # NOTE(review): 403 would arguably fit "authenticated but not allowed"
      # better; kept at 401 because clients may depend on the current code.
      if user.role in {"user", "admin"}:
          return user
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )
  def get_admin_user(user=Depends(get_current_user)):
      """Return the current user only if their role is ``"admin"``.

      Raises:
          HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
              non-admin role.
      """
      if user.role == "admin":
          return user
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )