auths.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. import logging
  2. from fastapi import Request, UploadFile, File
  3. from fastapi import Depends, HTTPException, status
  4. from fastapi.responses import Response
  5. from fastapi import APIRouter
  6. from pydantic import BaseModel
  7. import re
  8. import uuid
  9. import csv
  10. from apps.webui.models.auths import (
  11. SigninForm,
  12. SignupForm,
  13. AddUserForm,
  14. UpdateProfileForm,
  15. UpdatePasswordForm,
  16. UserResponse,
  17. SigninResponse,
  18. Auths,
  19. ApiKey,
  20. )
  21. from apps.webui.models.users import Users
  22. from utils.utils import (
  23. get_password_hash,
  24. get_current_user,
  25. get_admin_user,
  26. create_token,
  27. create_api_key,
  28. )
  29. from utils.misc import parse_duration, validate_email_format
  30. from utils.webhook import post_webhook
  31. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  32. from config import (
  33. WEBUI_AUTH,
  34. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  35. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  36. )
# Router collecting the authentication/session endpoints defined below.
# NOTE(review): the URL prefix is applied wherever this router is included
# into the app — not visible in this module.
router: APIRouter = APIRouter()
  38. ############################
  39. # GetSessionUser
  40. ############################
  41. @router.get("/", response_model=UserResponse)
  42. async def get_session_user(
  43. request: Request, response: Response, user=Depends(get_current_user)
  44. ):
  45. token = create_token(
  46. data={"id": user.id},
  47. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  48. )
  49. # Set the cookie token
  50. response.set_cookie(
  51. key="token",
  52. value=token,
  53. httponly=True, # Ensures the cookie is not accessible via JavaScript
  54. )
  55. return {
  56. "id": user.id,
  57. "email": user.email,
  58. "name": user.name,
  59. "role": user.role,
  60. "profile_image_url": user.profile_image_url,
  61. }
  62. ############################
  63. # Update Profile
  64. ############################
  65. @router.post("/update/profile", response_model=UserResponse)
  66. async def update_profile(
  67. form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  68. ):
  69. if session_user:
  70. user = Users.update_user_by_id(
  71. session_user.id,
  72. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  73. )
  74. if user:
  75. return user
  76. else:
  77. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  78. else:
  79. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  80. ############################
  81. # Update Password
  82. ############################
  83. @router.post("/update/password", response_model=bool)
  84. async def update_password(
  85. form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  86. ):
  87. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  88. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  89. if session_user:
  90. user = Auths.authenticate_user(session_user.email, form_data.password)
  91. if user:
  92. hashed = get_password_hash(form_data.new_password)
  93. return Auths.update_user_password_by_id(user.id, hashed)
  94. else:
  95. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  96. else:
  97. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  98. ############################
  99. # SignIn
  100. ############################
  101. @router.post("/signin", response_model=SigninResponse)
  102. async def signin(request: Request, response: Response, form_data: SigninForm):
  103. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  104. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  105. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  106. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
  107. trusted_name = trusted_email
  108. if WEBUI_AUTH_TRUSTED_NAME_HEADER:
  109. trusted_name = request.headers.get(
  110. WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
  111. )
  112. if not Users.get_user_by_email(trusted_email.lower()):
  113. await signup(
  114. request,
  115. response,
  116. SignupForm(
  117. email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
  118. ),
  119. )
  120. user = Auths.authenticate_user_by_trusted_header(trusted_email)
  121. elif WEBUI_AUTH == False:
  122. admin_email = "admin@localhost"
  123. admin_password = "admin"
  124. if Users.get_user_by_email(admin_email.lower()):
  125. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  126. else:
  127. if Users.get_num_users() != 0:
  128. raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
  129. await signup(
  130. request,
  131. response,
  132. SignupForm(email=admin_email, password=admin_password, name="User"),
  133. )
  134. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  135. else:
  136. user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
  137. if user:
  138. token = create_token(
  139. data={"id": user.id},
  140. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  141. )
  142. # Set the cookie token
  143. response.set_cookie(
  144. key="token",
  145. value=token,
  146. httponly=True, # Ensures the cookie is not accessible via JavaScript
  147. )
  148. return {
  149. "token": token,
  150. "token_type": "Bearer",
  151. "id": user.id,
  152. "email": user.email,
  153. "name": user.name,
  154. "role": user.role,
  155. "profile_image_url": user.profile_image_url,
  156. }
  157. else:
  158. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  159. ############################
  160. # SignUp
  161. ############################
  162. @router.post("/signup", response_model=SigninResponse)
  163. async def signup(request: Request, response: Response, form_data: SignupForm):
  164. if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
  165. raise HTTPException(
  166. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  167. )
  168. if not validate_email_format(form_data.email.lower()):
  169. raise HTTPException(
  170. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  171. )
  172. if Users.get_user_by_email(form_data.email.lower()):
  173. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  174. try:
  175. role = (
  176. "admin"
  177. if Users.get_num_users() == 0
  178. else request.app.state.config.DEFAULT_USER_ROLE
  179. )
  180. hashed = get_password_hash(form_data.password)
  181. user = Auths.insert_new_auth(
  182. form_data.email.lower(),
  183. hashed,
  184. form_data.name,
  185. form_data.profile_image_url,
  186. role,
  187. )
  188. if user:
  189. token = create_token(
  190. data={"id": user.id},
  191. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  192. )
  193. # response.set_cookie(key='token', value=token, httponly=True)
  194. # Set the cookie token
  195. response.set_cookie(
  196. key="token",
  197. value=token,
  198. httponly=True, # Ensures the cookie is not accessible via JavaScript
  199. )
  200. if request.app.state.config.WEBHOOK_URL:
  201. post_webhook(
  202. request.app.state.config.WEBHOOK_URL,
  203. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  204. {
  205. "action": "signup",
  206. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  207. "user": user.model_dump_json(exclude_none=True),
  208. },
  209. )
  210. return {
  211. "token": token,
  212. "token_type": "Bearer",
  213. "id": user.id,
  214. "email": user.email,
  215. "name": user.name,
  216. "role": user.role,
  217. "profile_image_url": user.profile_image_url,
  218. }
  219. else:
  220. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  221. except Exception as err:
  222. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  223. ############################
  224. # AddUser
  225. ############################
  226. @router.post("/add", response_model=SigninResponse)
  227. async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
  228. if not validate_email_format(form_data.email.lower()):
  229. raise HTTPException(
  230. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  231. )
  232. if Users.get_user_by_email(form_data.email.lower()):
  233. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  234. try:
  235. print(form_data)
  236. hashed = get_password_hash(form_data.password)
  237. user = Auths.insert_new_auth(
  238. form_data.email.lower(),
  239. hashed,
  240. form_data.name,
  241. form_data.profile_image_url,
  242. form_data.role,
  243. )
  244. if user:
  245. token = create_token(data={"id": user.id})
  246. return {
  247. "token": token,
  248. "token_type": "Bearer",
  249. "id": user.id,
  250. "email": user.email,
  251. "name": user.name,
  252. "role": user.role,
  253. "profile_image_url": user.profile_image_url,
  254. }
  255. else:
  256. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  257. except Exception as err:
  258. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  259. ############################
  260. # GetAdminDetails
  261. ############################
  262. @router.get("/admin/details")
  263. async def get_admin_details(request: Request, user=Depends(get_current_user)):
  264. if request.app.state.config.SHOW_ADMIN_DETAILS:
  265. admin_email = request.app.state.config.ADMIN_EMAIL
  266. admin_name = None
  267. print(admin_email, admin_name)
  268. if admin_email:
  269. admin = Users.get_user_by_email(admin_email)
  270. if admin:
  271. admin_name = admin.name
  272. else:
  273. admin = Users.get_first_user()
  274. if admin:
  275. admin_email = admin.email
  276. admin_name = admin.name
  277. return {
  278. "name": admin_name,
  279. "email": admin_email,
  280. }
  281. else:
  282. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  283. ############################
  284. # ToggleSignUp
  285. ############################
  286. @router.get("/admin/config")
  287. async def get_admin_config(request: Request, user=Depends(get_admin_user)):
  288. return {
  289. "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
  290. "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
  291. "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
  292. "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
  293. "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
  294. "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
  295. }
  296. class AdminConfig(BaseModel):
  297. SHOW_ADMIN_DETAILS: bool
  298. ENABLE_SIGNUP: bool
  299. DEFAULT_USER_ROLE: str
  300. JWT_EXPIRES_IN: str
  301. ENABLE_COMMUNITY_SHARING: bool
  302. ENABLE_MESSAGE_RATING: bool
  303. @router.post("/admin/config")
  304. async def update_admin_config(
  305. request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
  306. ):
  307. request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
  308. request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
  309. if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
  310. request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
  311. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  312. # Check if the input string matches the pattern
  313. if re.match(pattern, form_data.JWT_EXPIRES_IN):
  314. request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
  315. request.app.state.config.ENABLE_COMMUNITY_SHARING = (
  316. form_data.ENABLE_COMMUNITY_SHARING
  317. )
  318. request.app.state.config.ENABLE_MESSAGE_RATING = form_data.ENABLE_MESSAGE_RATING
  319. return {
  320. "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
  321. "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
  322. "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
  323. "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
  324. "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
  325. "ENABLE_MESSAGE_RATING": request.app.state.config.ENABLE_MESSAGE_RATING,
  326. }
  327. ############################
  328. # API Key
  329. ############################
  330. # create api key
  331. @router.post("/api_key", response_model=ApiKey)
  332. async def create_api_key_(user=Depends(get_current_user)):
  333. api_key = create_api_key()
  334. success = Users.update_user_api_key_by_id(user.id, api_key)
  335. if success:
  336. return {
  337. "api_key": api_key,
  338. }
  339. else:
  340. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
  341. # delete api key
  342. @router.delete("/api_key", response_model=bool)
  343. async def delete_api_key(user=Depends(get_current_user)):
  344. success = Users.update_user_api_key_by_id(user.id, None)
  345. return success
  346. # get api key
  347. @router.get("/api_key", response_model=ApiKey)
  348. async def get_api_key(user=Depends(get_current_user)):
  349. api_key = Users.get_user_api_key_by_id(user.id)
  350. if api_key:
  351. return {
  352. "api_key": api_key,
  353. }
  354. else:
  355. raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)