auths.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. import logging
  2. from fastapi import Request, UploadFile, File
  3. from fastapi import Depends, HTTPException, status
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import re
  7. import uuid
  8. import csv
  9. from apps.webui.models.auths import (
  10. SigninForm,
  11. SignupForm,
  12. AddUserForm,
  13. UpdateProfileForm,
  14. UpdatePasswordForm,
  15. UserResponse,
  16. SigninResponse,
  17. Auths,
  18. ApiKey,
  19. )
  20. from apps.webui.models.users import Users
  21. from utils.utils import (
  22. get_password_hash,
  23. get_current_user,
  24. get_admin_user,
  25. create_token,
  26. create_api_key,
  27. )
  28. from utils.misc import parse_duration, validate_email_format
  29. from utils.webhook import post_webhook
  30. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  31. from config import (
  32. WEBUI_AUTH,
  33. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  34. )
# Router on which every authentication/account endpoint in this module is
# registered. NOTE(review): presumably mounted by the application under an
# auth prefix elsewhere — confirm.
router = APIRouter()
  36. ############################
  37. # GetSessionUser
  38. ############################
  39. @router.get("/", response_model=UserResponse)
  40. async def get_session_user(user=Depends(get_current_user)):
  41. return {
  42. "id": user.id,
  43. "email": user.email,
  44. "name": user.name,
  45. "role": user.role,
  46. "profile_image_url": user.profile_image_url,
  47. }
  48. ############################
  49. # Update Profile
  50. ############################
  51. @router.post("/update/profile", response_model=UserResponse)
  52. async def update_profile(
  53. form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  54. ):
  55. if session_user:
  56. user = Users.update_user_by_id(
  57. session_user.id,
  58. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  59. )
  60. if user:
  61. return user
  62. else:
  63. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  64. else:
  65. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  66. ############################
  67. # Update Password
  68. ############################
  69. @router.post("/update/password", response_model=bool)
  70. async def update_password(
  71. form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  72. ):
  73. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  74. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  75. if session_user:
  76. user = Auths.authenticate_user(session_user.email, form_data.password)
  77. if user:
  78. hashed = get_password_hash(form_data.new_password)
  79. return Auths.update_user_password_by_id(user.id, hashed)
  80. else:
  81. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  82. else:
  83. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  84. ############################
  85. # SignIn
  86. ############################
  87. @router.post("/signin", response_model=SigninResponse)
  88. async def signin(request: Request, form_data: SigninForm):
  89. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  90. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  91. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  92. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
  93. if not Users.get_user_by_email(trusted_email.lower()):
  94. await signup(
  95. request,
  96. SignupForm(
  97. email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
  98. ),
  99. )
  100. user = Auths.authenticate_user_by_trusted_header(trusted_email)
  101. elif WEBUI_AUTH == False:
  102. admin_email = "admin@localhost"
  103. admin_password = "admin"
  104. if Users.get_user_by_email(admin_email.lower()):
  105. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  106. else:
  107. if Users.get_num_users() != 0:
  108. raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
  109. await signup(
  110. request,
  111. SignupForm(email=admin_email, password=admin_password, name="User"),
  112. )
  113. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  114. else:
  115. user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
  116. if user:
  117. token = create_token(
  118. data={"id": user.id},
  119. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  120. )
  121. return {
  122. "token": token,
  123. "token_type": "Bearer",
  124. "id": user.id,
  125. "email": user.email,
  126. "name": user.name,
  127. "role": user.role,
  128. "profile_image_url": user.profile_image_url,
  129. }
  130. else:
  131. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  132. ############################
  133. # SignUp
  134. ############################
  135. @router.post("/signup", response_model=SigninResponse)
  136. async def signup(request: Request, form_data: SignupForm):
  137. if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
  138. raise HTTPException(
  139. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  140. )
  141. if not validate_email_format(form_data.email.lower()):
  142. raise HTTPException(
  143. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  144. )
  145. if Users.get_user_by_email(form_data.email.lower()):
  146. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  147. try:
  148. role = (
  149. "admin"
  150. if Users.get_num_users() == 0
  151. else request.app.state.config.DEFAULT_USER_ROLE
  152. )
  153. hashed = get_password_hash(form_data.password)
  154. user = Auths.insert_new_auth(
  155. form_data.email.lower(),
  156. hashed,
  157. form_data.name,
  158. form_data.profile_image_url,
  159. role,
  160. )
  161. if user:
  162. token = create_token(
  163. data={"id": user.id},
  164. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  165. )
  166. # response.set_cookie(key='token', value=token, httponly=True)
  167. if request.app.state.config.WEBHOOK_URL:
  168. post_webhook(
  169. request.app.state.config.WEBHOOK_URL,
  170. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  171. {
  172. "action": "signup",
  173. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  174. "user": user.model_dump_json(exclude_none=True),
  175. },
  176. )
  177. return {
  178. "token": token,
  179. "token_type": "Bearer",
  180. "id": user.id,
  181. "email": user.email,
  182. "name": user.name,
  183. "role": user.role,
  184. "profile_image_url": user.profile_image_url,
  185. }
  186. else:
  187. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  188. except Exception as err:
  189. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  190. ############################
  191. # AddUser
  192. ############################
  193. @router.post("/add", response_model=SigninResponse)
  194. async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
  195. if not validate_email_format(form_data.email.lower()):
  196. raise HTTPException(
  197. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  198. )
  199. if Users.get_user_by_email(form_data.email.lower()):
  200. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  201. try:
  202. print(form_data)
  203. hashed = get_password_hash(form_data.password)
  204. user = Auths.insert_new_auth(
  205. form_data.email.lower(),
  206. hashed,
  207. form_data.name,
  208. form_data.profile_image_url,
  209. form_data.role,
  210. )
  211. if user:
  212. token = create_token(data={"id": user.id})
  213. return {
  214. "token": token,
  215. "token_type": "Bearer",
  216. "id": user.id,
  217. "email": user.email,
  218. "name": user.name,
  219. "role": user.role,
  220. "profile_image_url": user.profile_image_url,
  221. }
  222. else:
  223. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  224. except Exception as err:
  225. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  226. ############################
  227. # ToggleSignUp
  228. ############################
  229. @router.get("/signup/enabled", response_model=bool)
  230. async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
  231. return request.app.state.config.ENABLE_SIGNUP
  232. @router.get("/signup/enabled/toggle", response_model=bool)
  233. async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
  234. request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
  235. return request.app.state.config.ENABLE_SIGNUP
  236. ############################
  237. # Default User Role
  238. ############################
  239. @router.get("/signup/user/role")
  240. async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
  241. return request.app.state.config.DEFAULT_USER_ROLE
  242. class UpdateRoleForm(BaseModel):
  243. role: str
  244. @router.post("/signup/user/role")
  245. async def update_default_user_role(
  246. request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
  247. ):
  248. if form_data.role in ["pending", "user", "admin"]:
  249. request.app.state.config.DEFAULT_USER_ROLE = form_data.role
  250. return request.app.state.config.DEFAULT_USER_ROLE
  251. ############################
  252. # JWT Expiration
  253. ############################
  254. @router.get("/token/expires")
  255. async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
  256. return request.app.state.config.JWT_EXPIRES_IN
  257. class UpdateJWTExpiresDurationForm(BaseModel):
  258. duration: str
  259. @router.post("/token/expires/update")
  260. async def update_token_expires_duration(
  261. request: Request,
  262. form_data: UpdateJWTExpiresDurationForm,
  263. user=Depends(get_admin_user),
  264. ):
  265. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  266. # Check if the input string matches the pattern
  267. if re.match(pattern, form_data.duration):
  268. request.app.state.config.JWT_EXPIRES_IN = form_data.duration
  269. return request.app.state.config.JWT_EXPIRES_IN
  270. else:
  271. return request.app.state.config.JWT_EXPIRES_IN
  272. ############################
  273. # API Key
  274. ############################
  275. # create api key
  276. @router.post("/api_key", response_model=ApiKey)
  277. async def create_api_key_(user=Depends(get_current_user)):
  278. api_key = create_api_key()
  279. success = Users.update_user_api_key_by_id(user.id, api_key)
  280. if success:
  281. return {
  282. "api_key": api_key,
  283. }
  284. else:
  285. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
  286. # delete api key
  287. @router.delete("/api_key", response_model=bool)
  288. async def delete_api_key(user=Depends(get_current_user)):
  289. success = Users.update_user_api_key_by_id(user.id, None)
  290. return success
  291. # get api key
  292. @router.get("/api_key", response_model=ApiKey)
  293. async def get_api_key(user=Depends(get_current_user)):
  294. api_key = Users.get_user_api_key_by_id(user.id)
  295. if api_key:
  296. return {
  297. "api_key": api_key,
  298. }
  299. else:
  300. raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)