utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  2. from fastapi import HTTPException, status, Depends, Request
  3. from sqlalchemy.orm import Session
  4. from apps.webui.internal.db import get_db
  5. from apps.webui.models.users import Users
  6. from pydantic import BaseModel
  7. from typing import Union, Optional
  8. from constants import ERROR_MESSAGES
  9. from passlib.context import CryptContext
  10. from datetime import datetime, timedelta
  11. import requests
  12. import jwt
  13. import uuid
  14. import logging
  15. import config
# Raise passlib's logger threshold so only ERROR and above are emitted.
logging.getLogger("passlib").setLevel(logging.ERROR)

# Key and algorithm used by create_token / decode_token to sign and verify JWTs.
SESSION_SECRET = config.WEBUI_SECRET_KEY
ALGORITHM = "HS256"

##############
# Auth Utils
##############

# auto_error=False: get_current_user checks for a missing credential itself
# (it tests `auth_token is not None`) instead of relying on the dependency to raise.
bearer_security = HTTPBearer(auto_error=False)

# Password hashing context used by verify_password / get_password_hash.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  24. def verify_password(plain_password, hashed_password):
  25. return (
  26. pwd_context.verify(plain_password, hashed_password) if hashed_password else None
  27. )
  28. def get_password_hash(password):
  29. return pwd_context.hash(password)
  30. def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
  31. payload = data.copy()
  32. if expires_delta:
  33. expire = datetime.utcnow() + expires_delta
  34. payload.update({"exp": expire})
  35. encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
  36. return encoded_jwt
  37. def decode_token(token: str) -> Optional[dict]:
  38. try:
  39. decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
  40. return decoded
  41. except Exception as e:
  42. return None
  43. def extract_token_from_auth_header(auth_header: str):
  44. return auth_header[len("Bearer ") :]
  45. def create_api_key():
  46. key = str(uuid.uuid4()).replace("-", "")
  47. return f"sk-{key}"
  48. def get_http_authorization_cred(auth_header: str):
  49. try:
  50. scheme, credentials = auth_header.split(" ")
  51. return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
  52. except:
  53. raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
  54. def get_current_user(
  55. request: Request,
  56. auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
  57. db=Depends(get_db),
  58. ):
  59. token = None
  60. if auth_token is not None:
  61. token = auth_token.credentials
  62. if token is None and "token" in request.cookies:
  63. token = request.cookies.get("token")
  64. if token is None:
  65. raise HTTPException(status_code=403, detail="Not authenticated")
  66. # auth by api key
  67. if token.startswith("sk-"):
  68. return get_current_user_by_api_key(db, token)
  69. # auth by jwt token
  70. data = decode_token(token)
  71. if data != None and "id" in data:
  72. user = Users.get_user_by_id(db, data["id"])
  73. if user is None:
  74. raise HTTPException(
  75. status_code=status.HTTP_401_UNAUTHORIZED,
  76. detail=ERROR_MESSAGES.INVALID_TOKEN,
  77. )
  78. else:
  79. Users.update_user_last_active_by_id(db, user.id)
  80. return user
  81. else:
  82. raise HTTPException(
  83. status_code=status.HTTP_401_UNAUTHORIZED,
  84. detail=ERROR_MESSAGES.UNAUTHORIZED,
  85. )
  86. def get_current_user_by_api_key(db: Session, api_key: str):
  87. user = Users.get_user_by_api_key(db, api_key)
  88. if user is None:
  89. raise HTTPException(
  90. status_code=status.HTTP_401_UNAUTHORIZED,
  91. detail=ERROR_MESSAGES.INVALID_TOKEN,
  92. )
  93. else:
  94. Users.update_user_last_active_by_id(db, user.id)
  95. return user
  96. def get_verified_user(user=Depends(get_current_user)):
  97. if user.role not in {"user", "admin"}:
  98. raise HTTPException(
  99. status_code=status.HTTP_401_UNAUTHORIZED,
  100. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  101. )
  102. return user
  103. def get_admin_user(user=Depends(get_current_user)):
  104. if user.role != "admin":
  105. raise HTTPException(
  106. status_code=status.HTTP_401_UNAUTHORIZED,
  107. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  108. )
  109. return user