auths.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. from fastapi import Response, Request
  2. from fastapi import Depends, FastAPI, HTTPException, status
  3. from datetime import datetime, timedelta
  4. from typing import List, Union
  5. from fastapi import APIRouter, status
  6. from pydantic import BaseModel
  7. import time
  8. import uuid
  9. import re
  10. from apps.web.models.auths import (
  11. SigninForm,
  12. SignupForm,
  13. UpdateProfileForm,
  14. UpdatePasswordForm,
  15. UserResponse,
  16. SigninResponse,
  17. Auths,
  18. )
  19. from apps.web.models.users import Users
  20. from utils.utils import (
  21. get_password_hash,
  22. get_current_user,
  23. get_admin_user,
  24. create_token,
  25. )
  26. from utils.misc import parse_duration, validate_email_format
  27. from utils.webhook import post_webhook
  28. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  29. from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  30. router = APIRouter()
  31. ############################
  32. # GetSessionUser
  33. ############################
  34. @router.get("/", response_model=UserResponse)
  35. async def get_session_user(user=Depends(get_current_user)):
  36. return {
  37. "id": user.id,
  38. "email": user.email,
  39. "name": user.name,
  40. "role": user.role,
  41. "profile_image_url": user.profile_image_url,
  42. }
  43. ############################
  44. # Update Profile
  45. ############################
  46. @router.post("/update/profile", response_model=UserResponse)
  47. async def update_profile(
  48. form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  49. ):
  50. if session_user:
  51. user = Users.update_user_by_id(
  52. session_user.id,
  53. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  54. )
  55. if user:
  56. return user
  57. else:
  58. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  59. else:
  60. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  61. ############################
  62. # Update Password
  63. ############################
  64. @router.post("/update/password", response_model=bool)
  65. async def update_password(
  66. form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  67. ):
  68. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  69. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  70. if session_user:
  71. user = Auths.authenticate_user(session_user.email, form_data.password)
  72. if user:
  73. hashed = get_password_hash(form_data.new_password)
  74. return Auths.update_user_password_by_id(user.id, hashed)
  75. else:
  76. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  77. else:
  78. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  79. ############################
  80. # SignIn
  81. ############################
  82. @router.post("/signin", response_model=SigninResponse)
  83. async def signin(request: Request, form_data: SigninForm):
  84. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  85. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  86. raise HTTPException(400,
  87. detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  88. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower(
  89. )
  90. user = Auths.authenticate_user_by_trusted_header(trusted_email)
  91. else:
  92. user = Auths.authenticate_user(form_data.email.lower(),
  93. form_data.password)
  94. if user:
  95. token = create_token(
  96. data={"id": user.id},
  97. expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
  98. )
  99. return {
  100. "token": token,
  101. "token_type": "Bearer",
  102. "id": user.id,
  103. "email": user.email,
  104. "name": user.name,
  105. "role": user.role,
  106. "profile_image_url": user.profile_image_url,
  107. }
  108. else:
  109. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  110. ############################
  111. # SignUp
  112. ############################
  113. @router.post("/signup", response_model=SigninResponse)
  114. async def signup(request: Request, form_data: SignupForm):
  115. if not request.app.state.ENABLE_SIGNUP:
  116. raise HTTPException(
  117. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  118. )
  119. if not validate_email_format(form_data.email.lower()):
  120. raise HTTPException(
  121. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  122. )
  123. if Users.get_user_by_email(form_data.email.lower()):
  124. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  125. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  126. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  127. raise HTTPException(400,
  128. detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  129. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower(
  130. )
  131. if trusted_email != form_data.email:
  132. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_MISMATCH)
  133. # TODO: Yolo hack to assign a password
  134. form_data.password = str(uuid.uuid4())
  135. try:
  136. role = (
  137. "admin"
  138. if Users.get_num_users() == 0
  139. else request.app.state.DEFAULT_USER_ROLE
  140. )
  141. hashed = get_password_hash(form_data.password)
  142. user = Auths.insert_new_auth(
  143. form_data.email.lower(), hashed, form_data.name, role
  144. )
  145. if user:
  146. token = create_token(
  147. data={"id": user.id},
  148. expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
  149. )
  150. # response.set_cookie(key='token', value=token, httponly=True)
  151. if request.app.state.WEBHOOK_URL:
  152. post_webhook(
  153. request.app.state.WEBHOOK_URL,
  154. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  155. {
  156. "action": "signup",
  157. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  158. "user": user.model_dump_json(exclude_none=True),
  159. },
  160. )
  161. return {
  162. "token": token,
  163. "token_type": "Bearer",
  164. "id": user.id,
  165. "email": user.email,
  166. "name": user.name,
  167. "role": user.role,
  168. "profile_image_url": user.profile_image_url,
  169. }
  170. else:
  171. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  172. except Exception as err:
  173. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  174. ############################
  175. # ToggleSignUp
  176. ############################
  177. @router.get("/signup/enabled", response_model=bool)
  178. async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
  179. return request.app.state.ENABLE_SIGNUP
  180. @router.get("/signup/enabled/toggle", response_model=bool)
  181. async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
  182. request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
  183. return request.app.state.ENABLE_SIGNUP
  184. ############################
  185. # Default User Role
  186. ############################
  187. @router.get("/signup/user/role")
  188. async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
  189. return request.app.state.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel):
    # Request body for POST /signup/user/role. Any string is accepted here;
    # the handler only applies it when it is "pending", "user" or "admin".
    role: str
  192. @router.post("/signup/user/role")
  193. async def update_default_user_role(
  194. request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
  195. ):
  196. if form_data.role in ["pending", "user", "admin"]:
  197. request.app.state.DEFAULT_USER_ROLE = form_data.role
  198. return request.app.state.DEFAULT_USER_ROLE
  199. ############################
  200. # JWT Expiration
  201. ############################
  202. @router.get("/token/expires")
  203. async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
  204. return request.app.state.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
    # Request body for POST /token/expires/update. `duration` is expected to be
    # "-1", "0" or "<number><unit>" (unit one of ms, s, m, h, d, w); the handler
    # checks it against a regex and ignores values that do not match.
    duration: str
  207. @router.post("/token/expires/update")
  208. async def update_token_expires_duration(
  209. request: Request,
  210. form_data: UpdateJWTExpiresDurationForm,
  211. user=Depends(get_admin_user),
  212. ):
  213. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  214. # Check if the input string matches the pattern
  215. if re.match(pattern, form_data.duration):
  216. request.app.state.JWT_EXPIRES_IN = form_data.duration
  217. return request.app.state.JWT_EXPIRES_IN
  218. else:
  219. return request.app.state.JWT_EXPIRES_IN