auths.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. import logging
  2. from fastapi import Request, UploadFile, File
  3. from fastapi import Depends, HTTPException, status
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import re
  7. import uuid
  8. import csv
  9. from apps.web.models.auths import (
  10. SigninForm,
  11. SignupForm,
  12. AddUserForm,
  13. UpdateProfileForm,
  14. UpdatePasswordForm,
  15. UserResponse,
  16. SigninResponse,
  17. Auths,
  18. ApiKey,
  19. )
  20. from apps.web.models.users import Users
  21. from utils.utils import (
  22. get_password_hash,
  23. get_current_user,
  24. get_admin_user,
  25. create_token,
  26. create_api_key,
  27. )
  28. from utils.misc import parse_duration, validate_email_format
  29. from utils.webhook import post_webhook
  30. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  31. from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, config_get, config_set
  # Router on which every endpoint in this module is registered via @router.<method>.
  router = APIRouter()
  33. ############################
  34. # GetSessionUser
  35. ############################
  @router.get("/", response_model=UserResponse)
  async def get_session_user(user=Depends(get_current_user)):
      """Return the profile fields of the user resolved by ``get_current_user``.

      The response contains exactly ``id``, ``email``, ``name``, ``role`` and
      ``profile_image_url``, copied from the resolved user object.
      """
      exposed_fields = ("id", "email", "name", "role", "profile_image_url")
      return {field: getattr(user, field) for field in exposed_fields}
  45. ############################
  46. # Update Profile
  47. ############################
  @router.post("/update/profile", response_model=UserResponse)
  async def update_profile(
      form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  ):
      """Update the name and profile image URL of the session user.

      Returns:
          Whatever ``Users.update_user_by_id`` returns when it is truthy.

      Raises:
          HTTPException: 400 with ``ERROR_MESSAGES.INVALID_CRED`` when no session
              user was resolved; 400 with ``ERROR_MESSAGES.DEFAULT()`` when the
              update returns a falsy value.
      """
      if not session_user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

      target_id = session_user.id
      changes = {
          "profile_image_url": form_data.profile_image_url,
          "name": form_data.name,
      }
      updated = Users.update_user_by_id(target_id, changes)
      if not updated:
          raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
      return updated
  63. ############################
  64. # Update Password
  65. ############################
  @router.post("/update/password", response_model=bool)
  async def update_password(
      form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  ):
      """Replace the session user's password after re-checking the current one.

      Returns:
          The result of ``Auths.update_user_password_by_id``.

      Raises:
          HTTPException: 400 with ``ACTION_PROHIBITED`` when a trusted e-mail
              header is configured; 400 with ``INVALID_CRED`` when no session
              user was resolved; 400 with ``INVALID_PASSWORD`` when
              authentication with the supplied current password fails.
      """
      # Password changes are refused outright when a trusted header is configured.
      if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
          raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)

      if not session_user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

      # The caller must prove knowledge of the current password first.
      verified = Auths.authenticate_user(session_user.email, form_data.password)
      if not verified:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)

      new_hash = get_password_hash(form_data.new_password)
      return Auths.update_user_password_by_id(verified.id, new_hash)
  81. ############################
  82. # SignIn
  83. ############################
  @router.post("/signin", response_model=SigninResponse)
  async def signin(request: Request, form_data: SigninForm):
      """Authenticate a user and return a bearer token with the user's profile.

      The authentication path is chosen in this order:

      1. ``WEBUI_AUTH_TRUSTED_EMAIL_HEADER`` is set: the e-mail is read from that
         request header; an unknown e-mail is first registered through
         ``signup`` with a random UUID password.
      2. ``WEBUI_AUTH == False``: the fixed ``admin@localhost`` account is used,
         and created through ``signup`` if it does not exist yet and there are
         no other users.
      3. Otherwise: the e-mail and password from ``form_data`` are checked.

      Raises:
          HTTPException: 400 with ``INVALID_TRUSTED_HEADER`` when the configured
              header is missing from the request; 400 with ``EXISTING_USERS``
              when auth is disabled, the admin account is missing and other
              users exist; 400 with ``INVALID_CRED`` when authentication yields
              no user.
      """
      if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
          if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
              raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)

          trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
          already_registered = Users.get_user_by_email(trusted_email.lower())
          if not already_registered:
              # Unknown e-mail: register it with a throwaway random password.
              provisioning_form = SignupForm(
                  email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
              )
              await signup(request, provisioning_form)
          user = Auths.authenticate_user_by_trusted_header(trusted_email)
      elif WEBUI_AUTH == False:
          admin_email = "admin@localhost"
          admin_password = "admin"

          if not Users.get_user_by_email(admin_email.lower()):
              # The built-in admin may only be created on an otherwise empty system.
              if Users.get_num_users() != 0:
                  raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
              await signup(
                  request,
                  SignupForm(email=admin_email, password=admin_password, name="User"),
              )

          user = Auths.authenticate_user(admin_email.lower(), admin_password)
      else:
          user = Auths.authenticate_user(form_data.email.lower(), form_data.password)

      if not user:
          raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

      lifetime = parse_duration(config_get(request.app.state.JWT_EXPIRES_IN))
      token = create_token(data={"id": user.id}, expires_delta=lifetime)

      return {
          "token": token,
          "token_type": "Bearer",
          "id": user.id,
          "email": user.email,
          "name": user.name,
          "role": user.role,
          "profile_image_url": user.profile_image_url,
      }
  129. ############################
  130. # SignUp
  131. ############################
  @router.post("/signup", response_model=SigninResponse)
  async def signup(request: Request, form_data: SignupForm):
      """Register a new account and return a bearer token with its profile.

      The role is ``"admin"`` when ``Users.get_num_users()`` is 0 at the time of
      the call, otherwise the configured ``DEFAULT_USER_ROLE``. When a
      ``WEBHOOK_URL`` is configured, a signup notification is posted to it.

      Raises:
          HTTPException: 403 with ``ACCESS_PROHIBITED`` when signup is disabled
              while ``WEBUI_AUTH`` is truthy; 400 with ``INVALID_EMAIL_FORMAT``
              or ``EMAIL_TAKEN`` for a rejected e-mail; 500 with
              ``CREATE_USER_ERROR`` when the insert returns a falsy value;
              500 with ``ERROR_MESSAGES.DEFAULT(err)`` for any other failure.
      """
      if not config_get(request.app.state.ENABLE_SIGNUP) and WEBUI_AUTH:
          raise HTTPException(
              status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
          )

      if not validate_email_format(form_data.email.lower()):
          raise HTTPException(
              status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
          )

      if Users.get_user_by_email(form_data.email.lower()):
          raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

      try:
          role = (
              "admin"
              if Users.get_num_users() == 0
              else config_get(request.app.state.DEFAULT_USER_ROLE)
          )
          hashed = get_password_hash(form_data.password)
          user = Auths.insert_new_auth(
              form_data.email.lower(),
              hashed,
              form_data.name,
              form_data.profile_image_url,
              role,
          )
          if not user:
              raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

          token = create_token(
              data={"id": user.id},
              expires_delta=parse_duration(
                  config_get(request.app.state.JWT_EXPIRES_IN)
              ),
          )
          # response.set_cookie(key='token', value=token, httponly=True)

          webhook_url = config_get(request.app.state.WEBHOOK_URL)
          if webhook_url:
              signup_message = WEBHOOK_MESSAGES.USER_SIGNUP(user.name)
              post_webhook(
                  webhook_url,
                  signup_message,
                  {
                      "action": "signup",
                      "message": signup_message,
                      "user": user.model_dump_json(exclude_none=True),
                  },
              )

          return {
              "token": token,
              "token_type": "Bearer",
              "id": user.id,
              "email": user.email,
              "name": user.name,
              "role": user.role,
              "profile_image_url": user.profile_image_url,
          }
      except HTTPException:
          # Already a deliberate API error (e.g. CREATE_USER_ERROR above); the
          # generic handler below would otherwise replace its detail message.
          raise
      except Exception as err:
          raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) from err
  189. ############################
  190. # AddUser
  191. ############################
  @router.post("/add", response_model=SigninResponse)
  async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
      """Create an account with an explicitly chosen role.

      Access is gated by the ``get_admin_user`` dependency. The returned token
      is created without an explicit ``expires_delta``.

      Raises:
          HTTPException: 400 with ``INVALID_EMAIL_FORMAT`` or ``EMAIL_TAKEN`` for
              a rejected e-mail; 500 with ``CREATE_USER_ERROR`` when the insert
              returns a falsy value; 500 with ``ERROR_MESSAGES.DEFAULT(err)``
              for any other failure.
      """
      if not validate_email_format(form_data.email.lower()):
          raise HTTPException(
              status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
          )

      if Users.get_user_by_email(form_data.email.lower()):
          raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

      try:
          # NOTE: the form is deliberately not printed or logged here; it carries
          # the new account's plaintext password.
          hashed = get_password_hash(form_data.password)
          # Separate name so the admin resolved by the dependency is not shadowed.
          new_user = Auths.insert_new_auth(
              form_data.email.lower(),
              hashed,
              form_data.name,
              form_data.profile_image_url,
              form_data.role,
          )
          if not new_user:
              raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)

          token = create_token(data={"id": new_user.id})
          return {
              "token": token,
              "token_type": "Bearer",
              "id": new_user.id,
              "email": new_user.email,
              "name": new_user.name,
              "role": new_user.role,
              "profile_image_url": new_user.profile_image_url,
          }
      except HTTPException:
          # Keep deliberate API errors (e.g. CREATE_USER_ERROR) intact instead of
          # letting the generic handler below overwrite their detail message.
          raise
      except Exception as err:
          raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) from err
  225. ############################
  226. # ToggleSignUp
  227. ############################
  @router.get("/signup/enabled", response_model=bool)
  async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
      """Return the current value of the ``ENABLE_SIGNUP`` setting.

      Access is gated by the ``get_admin_user`` dependency.
      """
      signup_setting = request.app.state.ENABLE_SIGNUP
      return config_get(signup_setting)
  @router.get("/signup/enabled/toggle", response_model=bool)
  async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
      """Invert the ``ENABLE_SIGNUP`` setting and return the value read back.

      Access is gated by the ``get_admin_user`` dependency.
      """
      app_state = request.app.state
      currently_enabled = config_get(app_state.ENABLE_SIGNUP)
      config_set(app_state.ENABLE_SIGNUP, not currently_enabled)
      # Read the setting back rather than assuming the write took effect.
      return config_get(app_state.ENABLE_SIGNUP)
  237. ############################
  238. # Default User Role
  239. ############################
  @router.get("/signup/user/role")
  async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
      """Return the current value of the ``DEFAULT_USER_ROLE`` setting.

      Access is gated by the ``get_admin_user`` dependency.
      """
      role_setting = request.app.state.DEFAULT_USER_ROLE
      return config_get(role_setting)
  class UpdateRoleForm(BaseModel):
      """Request body for changing the default role."""

      role: str


  @router.post("/signup/user/role")
  async def update_default_user_role(
      request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
  ):
      """Set ``DEFAULT_USER_ROLE`` when the requested role is a recognised one.

      Only ``"pending"``, ``"user"`` and ``"admin"`` are accepted. Any other
      value is ignored without an error, and the setting is left as it was.

      Returns:
          The value of ``DEFAULT_USER_ROLE`` after the (possible) update.
      """
      accepted_roles = ["pending", "user", "admin"]
      requested_role = form_data.role
      if requested_role in accepted_roles:
          config_set(request.app.state.DEFAULT_USER_ROLE, requested_role)

      return config_get(request.app.state.DEFAULT_USER_ROLE)
  252. ############################
  253. # JWT Expiration
  254. ############################
  255. @router.get("/token/expires")
  256. async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
  257. return config_get(request.app.state.JWT_EXPIRES_IN)
  class UpdateJWTExpiresDurationForm(BaseModel):
      """Request body for changing the JWT expiry setting."""

      duration: str


  @router.post("/token/expires/update")
  async def update_token_expires_duration(
      request: Request,
      form_data: UpdateJWTExpiresDurationForm,
      user=Depends(get_admin_user),
  ):
      """Set ``JWT_EXPIRES_IN`` when the submitted duration is well-formed.

      Accepted values are ``-1``, ``0``, or an optionally negative decimal
      number followed by one of the units ``ms``, ``s``, ``m``, ``h``, ``d``,
      ``w``. Anything else is ignored without an error.

      Returns:
          The value of ``JWT_EXPIRES_IN`` after the (possible) update.
      """
      pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"

      # fullmatch instead of match: "$" also matches just before a trailing
      # newline, so match() would accept and store a value such as "1h\n".
      if re.fullmatch(pattern, form_data.duration):
          config_set(request.app.state.JWT_EXPIRES_IN, form_data.duration)

      return config_get(request.app.state.JWT_EXPIRES_IN)
  273. ############################
  274. # API Key
  275. ############################
  276. # create api key
  @router.post("/api_key", response_model=ApiKey)
  async def create_api_key_(user=Depends(get_current_user)):
      """Generate a new API key and store it on the current user.

      Returns:
          ``{"api_key": <new key>}`` when the store call reports success.

      Raises:
          HTTPException: 500 with ``CREATE_API_KEY_ERROR`` when the store call
              returns a falsy value.
      """
      new_key = create_api_key()
      stored = Users.update_user_api_key_by_id(user.id, new_key)
      if not stored:
          raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
      return {"api_key": new_key}
  287. # delete api key
  @router.delete("/api_key", response_model=bool)
  async def delete_api_key(user=Depends(get_current_user)):
      """Clear the current user's API key by storing ``None`` in its place.

      Returns:
          The result of ``Users.update_user_api_key_by_id``.
      """
      return Users.update_user_api_key_by_id(user.id, None)
  292. # get api key
  @router.get("/api_key", response_model=ApiKey)
  async def get_api_key(user=Depends(get_current_user)):
      """Return the API key stored for the current user.

      Raises:
          HTTPException: 404 with ``API_KEY_NOT_FOUND`` when the lookup returns
              a falsy value.
      """
      stored_key = Users.get_user_api_key_by_id(user.id)
      if not stored_key:
          raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
      return {"api_key": stored_key}