auths.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. from fastapi import Request
  2. from fastapi import Depends, HTTPException, status
  3. from fastapi import APIRouter
  4. from pydantic import BaseModel
  5. import re
  6. import uuid
  7. from apps.web.models.auths import (
  8. SigninForm,
  9. SignupForm,
  10. UpdateProfileForm,
  11. UpdatePasswordForm,
  12. UserResponse,
  13. SigninResponse,
  14. Auths,
  15. ApiKey,
  16. )
  17. from apps.web.models.users import Users
  18. from utils.utils import (
  19. get_password_hash,
  20. get_current_user,
  21. get_admin_user,
  22. create_token,
  23. create_api_key,
  24. )
  25. from utils.misc import parse_duration, validate_email_format
  26. from utils.webhook import post_webhook
  27. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  28. from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
  29. router = APIRouter()
  30. ############################
  31. # GetSessionUser
  32. ############################
  33. @router.get("/", response_model=UserResponse)
  34. async def get_session_user(user=Depends(get_current_user)):
  35. return {
  36. "id": user.id,
  37. "email": user.email,
  38. "name": user.name,
  39. "role": user.role,
  40. "profile_image_url": user.profile_image_url,
  41. }
  42. ############################
  43. # Update Profile
  44. ############################
  45. @router.post("/update/profile", response_model=UserResponse)
  46. async def update_profile(
  47. form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  48. ):
  49. if session_user:
  50. user = Users.update_user_by_id(
  51. session_user.id,
  52. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  53. )
  54. if user:
  55. return user
  56. else:
  57. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  58. else:
  59. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  60. ############################
  61. # Update Password
  62. ############################
  63. @router.post("/update/password", response_model=bool)
  64. async def update_password(
  65. form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  66. ):
  67. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  68. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  69. if session_user:
  70. user = Auths.authenticate_user(session_user.email, form_data.password)
  71. if user:
  72. hashed = get_password_hash(form_data.new_password)
  73. return Auths.update_user_password_by_id(user.id, hashed)
  74. else:
  75. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  76. else:
  77. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  78. ############################
  79. # SignIn
  80. ############################
  81. @router.post("/signin", response_model=SigninResponse)
  82. async def signin(request: Request, form_data: SigninForm):
  83. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  84. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  85. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  86. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
  87. if not Users.get_user_by_email(trusted_email.lower()):
  88. await signup(
  89. request,
  90. SignupForm(
  91. email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
  92. ),
  93. )
  94. user = Auths.authenticate_user_by_trusted_header(trusted_email)
  95. else:
  96. user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
  97. if user:
  98. token = create_token(
  99. data={"id": user.id},
  100. expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
  101. )
  102. return {
  103. "token": token,
  104. "token_type": "Bearer",
  105. "id": user.id,
  106. "email": user.email,
  107. "name": user.name,
  108. "role": user.role,
  109. "profile_image_url": user.profile_image_url,
  110. }
  111. else:
  112. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  113. ############################
  114. # SignUp
  115. ############################
  116. @router.post("/signup", response_model=SigninResponse)
  117. async def signup(request: Request, form_data: SignupForm):
  118. if not request.app.state.ENABLE_SIGNUP:
  119. raise HTTPException(
  120. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  121. )
  122. if not validate_email_format(form_data.email.lower()):
  123. raise HTTPException(
  124. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  125. )
  126. if Users.get_user_by_email(form_data.email.lower()):
  127. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  128. try:
  129. role = (
  130. "admin"
  131. if Users.get_num_users() == 0
  132. else request.app.state.DEFAULT_USER_ROLE
  133. )
  134. hashed = get_password_hash(form_data.password)
  135. user = Auths.insert_new_auth(
  136. form_data.email.lower(),
  137. hashed,
  138. form_data.name,
  139. form_data.profile_image_url,
  140. role,
  141. )
  142. if user:
  143. token = create_token(
  144. data={"id": user.id},
  145. expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
  146. )
  147. # response.set_cookie(key='token', value=token, httponly=True)
  148. if request.app.state.WEBHOOK_URL:
  149. post_webhook(
  150. request.app.state.WEBHOOK_URL,
  151. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  152. {
  153. "action": "signup",
  154. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  155. "user": user.model_dump_json(exclude_none=True),
  156. },
  157. )
  158. return {
  159. "token": token,
  160. "token_type": "Bearer",
  161. "id": user.id,
  162. "email": user.email,
  163. "name": user.name,
  164. "role": user.role,
  165. "profile_image_url": user.profile_image_url,
  166. }
  167. else:
  168. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  169. except Exception as err:
  170. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  171. ############################
  172. # ToggleSignUp
  173. ############################
  174. @router.get("/signup/enabled", response_model=bool)
  175. async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
  176. return request.app.state.ENABLE_SIGNUP
  177. @router.get("/signup/enabled/toggle", response_model=bool)
  178. async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
  179. request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
  180. return request.app.state.ENABLE_SIGNUP
  181. ############################
  182. # Default User Role
  183. ############################
  184. @router.get("/signup/user/role")
  185. async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
  186. return request.app.state.DEFAULT_USER_ROLE
  187. class UpdateRoleForm(BaseModel):
  188. role: str
  189. @router.post("/signup/user/role")
  190. async def update_default_user_role(
  191. request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
  192. ):
  193. if form_data.role in ["pending", "user", "admin"]:
  194. request.app.state.DEFAULT_USER_ROLE = form_data.role
  195. return request.app.state.DEFAULT_USER_ROLE
  196. ############################
  197. # JWT Expiration
  198. ############################
  199. @router.get("/token/expires")
  200. async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
  201. return request.app.state.JWT_EXPIRES_IN
  202. class UpdateJWTExpiresDurationForm(BaseModel):
  203. duration: str
  204. @router.post("/token/expires/update")
  205. async def update_token_expires_duration(
  206. request: Request,
  207. form_data: UpdateJWTExpiresDurationForm,
  208. user=Depends(get_admin_user),
  209. ):
  210. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  211. # Check if the input string matches the pattern
  212. if re.match(pattern, form_data.duration):
  213. request.app.state.JWT_EXPIRES_IN = form_data.duration
  214. return request.app.state.JWT_EXPIRES_IN
  215. else:
  216. return request.app.state.JWT_EXPIRES_IN
  217. ############################
  218. # API Key
  219. ############################
  220. # create api key
  221. @router.post("/api_key", response_model=ApiKey)
  222. async def create_api_key_(user=Depends(get_current_user)):
  223. api_key = create_api_key()
  224. success = Users.update_user_api_key_by_id(user.id, api_key)
  225. if success:
  226. return {
  227. "api_key": api_key,
  228. }
  229. else:
  230. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
  231. # delete api key
  232. @router.delete("/api_key", response_model=bool)
  233. async def delete_api_key(user=Depends(get_current_user)):
  234. success = Users.update_user_api_key_by_id(user.id, None)
  235. return success
  236. # get api key
  237. @router.get("/api_key", response_model=ApiKey)
  238. async def get_api_key(user=Depends(get_current_user)):
  239. api_key = Users.get_user_api_key_by_id(user.id)
  240. if api_key:
  241. return {
  242. "api_key": api_key,
  243. }
  244. else:
  245. raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)