utils.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  2. from fastapi import HTTPException, status, Depends, Request
  3. from sqlalchemy.orm import Session
  4. from apps.webui.models.users import Users
  5. from pydantic import BaseModel
  6. from typing import Union, Optional
  7. from constants import ERROR_MESSAGES
  8. from passlib.context import CryptContext
  9. from datetime import datetime, timedelta
  10. import requests
  11. import jwt
  12. import uuid
  13. import logging
  14. import config
  15. logging.getLogger("passlib").setLevel(logging.ERROR)
  16. SESSION_SECRET = config.WEBUI_SECRET_KEY
  17. ALGORITHM = "HS256"
  18. ##############
  19. # Auth Utils
  20. ##############
  21. bearer_security = HTTPBearer(auto_error=False)
  22. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  23. def verify_password(plain_password, hashed_password):
  24. return (
  25. pwd_context.verify(plain_password, hashed_password) if hashed_password else None
  26. )
  27. def get_password_hash(password):
  28. return pwd_context.hash(password)
  29. def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
  30. payload = data.copy()
  31. if expires_delta:
  32. expire = datetime.utcnow() + expires_delta
  33. payload.update({"exp": expire})
  34. encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  35. return encoded_jwt
  36. def decode_token(token: str) -> Optional[dict]:
  37. try:
  38. decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
  39. return decoded
  40. except Exception as e:
  41. return None
  42. def extract_token_from_auth_header(auth_header: str):
  43. return auth_header[len("Bearer ") :]
  44. def create_api_key():
  45. key = str(uuid.uuid4()).replace("-", "")
  46. return f"sk-{key}"
  47. def get_http_authorization_cred(auth_header: str):
  48. try:
  49. scheme, credentials = auth_header.split(" ")
  50. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  51. except:
  52. raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
  53. def get_current_user(
  54. request: Request,
  55. auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  56. ):
  57. token = None
  58. if auth_token is not None:
  59. token = auth_token.credentials
  60. if token is None and "token" in request.cookies:
  61. token = request.cookies.get("token")
  62. if token is None:
  63. raise HTTPException(status_code=403, detail="Not authenticated")
  64. # auth by api key
  65. if token.startswith("sk-"):
  66. return get_current_user_by_api_key(token)
  67. # auth by jwt token
  68. data = decode_token(token)
  69. if data != None and "id" in data:
  70. user = Users.get_user_by_id(data["id"])
  71. if user is None:
  72. raise HTTPException(
  73. status_code=status.HTTP_401_UNAUTHORIZED,
  74. detail=ERROR_MESSAGES.INVALID_TOKEN,
  75. )
  76. else:
  77. Users.update_user_last_active_by_id(user.id)
  78. return user
  79. else:
  80. raise HTTPException(
  81. status_code=status.HTTP_401_UNAUTHORIZED,
  82. detail=ERROR_MESSAGES.UNAUTHORIZED,
  83. )
  84. def get_current_user_by_api_key(db: Session, api_key: str):
  85. user = Users.get_user_by_api_key(db, api_key)
  86. if user is None:
  87. raise HTTPException(
  88. status_code=status.HTTP_401_UNAUTHORIZED,
  89. detail=ERROR_MESSAGES.INVALID_TOKEN,
  90. )
  91. else:
  92. Users.update_user_last_active_by_id(db, user.id)
  93. return user
  94. def get_verified_user(user=Depends(get_current_user)):
  95. if user.role not in {"user", "admin"}:
  96. raise HTTPException(
  97. status_code=status.HTTP_401_UNAUTHORIZED,
  98. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  99. )
  100. return user
  101. def get_admin_user(user=Depends(get_current_user)):
  102. if user.role != "admin":
  103. raise HTTPException(
  104. status_code=status.HTTP_401_UNAUTHORIZED,
  105. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  106. )
  107. return user