auths.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. import logging
  2. from fastapi import Request, UploadFile, File
  3. from fastapi import Depends, HTTPException, status
  4. from fastapi.responses import Response
  5. from fastapi import APIRouter
  6. from pydantic import BaseModel
  7. import re
  8. import uuid
  9. import csv
  10. from apps.webui.internal.db import get_db
  11. from apps.webui.models.auths import (
  12. SigninForm,
  13. SignupForm,
  14. AddUserForm,
  15. UpdateProfileForm,
  16. UpdatePasswordForm,
  17. UserResponse,
  18. SigninResponse,
  19. Auths,
  20. ApiKey,
  21. )
  22. from apps.webui.models.users import Users
  23. from utils.utils import (
  24. get_password_hash,
  25. get_current_user,
  26. get_admin_user,
  27. create_token,
  28. create_api_key,
  29. )
  30. from utils.misc import parse_duration, validate_email_format
  31. from utils.webhook import post_webhook
  32. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  33. from config import (
  34. WEBUI_AUTH,
  35. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  36. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  37. )
  38. router = APIRouter()
  39. ############################
  40. # GetSessionUser
  41. ############################
  42. @router.get("/", response_model=UserResponse)
  43. async def get_session_user(
  44. request: Request, response: Response, user=Depends(get_current_user)
  45. ):
  46. token = create_token(
  47. data={"id": user.id},
  48. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  49. )
  50. # Set the cookie token
  51. response.set_cookie(
  52. key="token",
  53. value=token,
  54. httponly=True, # Ensures the cookie is not accessible via JavaScript
  55. )
  56. return {
  57. "id": user.id,
  58. "email": user.email,
  59. "name": user.name,
  60. "role": user.role,
  61. "profile_image_url": user.profile_image_url,
  62. }
  63. ############################
  64. # Update Profile
  65. ############################
  66. @router.post("/update/profile", response_model=UserResponse)
  67. async def update_profile(
  68. form_data: UpdateProfileForm,
  69. session_user=Depends(get_current_user),
  70. db=Depends(get_db),
  71. ):
  72. if session_user:
  73. user = Users.update_user_by_id(
  74. db,
  75. session_user.id,
  76. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  77. )
  78. if user:
  79. return user
  80. else:
  81. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  82. else:
  83. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  84. ############################
  85. # Update Password
  86. ############################
  87. @router.post("/update/password", response_model=bool)
  88. async def update_password(
  89. form_data: UpdatePasswordForm,
  90. session_user=Depends(get_current_user),
  91. db=Depends(get_db),
  92. ):
  93. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  94. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  95. if session_user:
  96. user = Auths.authenticate_user(db, session_user.email, form_data.password)
  97. if user:
  98. hashed = get_password_hash(form_data.new_password)
  99. return Auths.update_user_password_by_id(db, user.id, hashed)
  100. else:
  101. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  102. else:
  103. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  104. ############################
  105. # SignIn
  106. ############################
  107. @router.post("/signin", response_model=SigninResponse)
  108. async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)):
  109. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  110. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  111. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  112. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
  113. trusted_name = trusted_email
  114. if WEBUI_AUTH_TRUSTED_NAME_HEADER:
  115. trusted_name = request.headers.get(
  116. WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
  117. )
  118. if not Users.get_user_by_email(db, trusted_email.lower()):
  119. await signup(
  120. request,
  121. SignupForm(
  122. email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
  123. ),
  124. db,
  125. )
  126. user = Auths.authenticate_user_by_trusted_header(db, trusted_email)
  127. elif WEBUI_AUTH == False:
  128. admin_email = "admin@localhost"
  129. admin_password = "admin"
  130. if Users.get_user_by_email(db, admin_email.lower()):
  131. user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
  132. else:
  133. if Users.get_num_users(db) != 0:
  134. raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
  135. await signup(
  136. request,
  137. SignupForm(email=admin_email, password=admin_password, name="User"),
  138. db,
  139. )
  140. user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
  141. else:
  142. user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password)
  143. if user:
  144. token = create_token(
  145. data={"id": user.id},
  146. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  147. )
  148. # Set the cookie token
  149. response.set_cookie(
  150. key="token",
  151. value=token,
  152. httponly=True, # Ensures the cookie is not accessible via JavaScript
  153. )
  154. return {
  155. "token": token,
  156. "token_type": "Bearer",
  157. "id": user.id,
  158. "email": user.email,
  159. "name": user.name,
  160. "role": user.role,
  161. "profile_image_url": user.profile_image_url,
  162. }
  163. else:
  164. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  165. ############################
  166. # SignUp
  167. ############################
  168. @router.post("/signup", response_model=SigninResponse)
  169. async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)):
  170. if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
  171. raise HTTPException(
  172. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  173. )
  174. if not validate_email_format(form_data.email.lower()):
  175. raise HTTPException(
  176. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  177. )
  178. if Users.get_user_by_email(db, form_data.email.lower()):
  179. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  180. try:
  181. role = (
  182. "admin"
  183. if Users.get_num_users(db) == 0
  184. else request.app.state.config.DEFAULT_USER_ROLE
  185. )
  186. hashed = get_password_hash(form_data.password)
  187. user = Auths.insert_new_auth(
  188. db,
  189. form_data.email.lower(),
  190. hashed,
  191. form_data.name,
  192. form_data.profile_image_url,
  193. role,
  194. )
  195. if user:
  196. token = create_token(
  197. data={"id": user.id},
  198. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  199. )
  200. # response.set_cookie(key='token', value=token, httponly=True)
  201. # Set the cookie token
  202. response.set_cookie(
  203. key="token",
  204. value=token,
  205. httponly=True, # Ensures the cookie is not accessible via JavaScript
  206. )
  207. if request.app.state.config.WEBHOOK_URL:
  208. post_webhook(
  209. request.app.state.config.WEBHOOK_URL,
  210. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  211. {
  212. "action": "signup",
  213. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  214. "user": user.model_dump_json(exclude_none=True),
  215. },
  216. )
  217. return {
  218. "token": token,
  219. "token_type": "Bearer",
  220. "id": user.id,
  221. "email": user.email,
  222. "name": user.name,
  223. "role": user.role,
  224. "profile_image_url": user.profile_image_url,
  225. }
  226. else:
  227. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  228. except Exception as err:
  229. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  230. ############################
  231. # AddUser
  232. ############################
  233. @router.post("/add", response_model=SigninResponse)
  234. async def add_user(
  235. form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db)
  236. ):
  237. if not validate_email_format(form_data.email.lower()):
  238. raise HTTPException(
  239. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  240. )
  241. if Users.get_user_by_email(db, form_data.email.lower()):
  242. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  243. try:
  244. print(form_data)
  245. hashed = get_password_hash(form_data.password)
  246. user = Auths.insert_new_auth(
  247. db,
  248. form_data.email.lower(),
  249. hashed,
  250. form_data.name,
  251. form_data.profile_image_url,
  252. form_data.role,
  253. )
  254. if user:
  255. token = create_token(data={"id": user.id})
  256. return {
  257. "token": token,
  258. "token_type": "Bearer",
  259. "id": user.id,
  260. "email": user.email,
  261. "name": user.name,
  262. "role": user.role,
  263. "profile_image_url": user.profile_image_url,
  264. }
  265. else:
  266. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  267. except Exception as err:
  268. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  269. ############################
  270. # GetAdminDetails
  271. ############################
  272. @router.get("/admin/details")
  273. async def get_admin_details(
  274. request: Request, user=Depends(get_current_user), db=Depends(get_db)
  275. ):
  276. if request.app.state.config.SHOW_ADMIN_DETAILS:
  277. admin_email = request.app.state.config.ADMIN_EMAIL
  278. admin_name = None
  279. print(admin_email, admin_name)
  280. if admin_email:
  281. admin = Users.get_user_by_email(db, admin_email)
  282. if admin:
  283. admin_name = admin.name
  284. else:
  285. admin = Users.get_first_user(db)
  286. if admin:
  287. admin_email = admin.email
  288. admin_name = admin.name
  289. return {
  290. "name": admin_name,
  291. "email": admin_email,
  292. }
  293. else:
  294. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  295. ############################
  296. # ToggleSignUp
  297. ############################
  298. @router.get("/admin/config")
  299. async def get_admin_config(request: Request, user=Depends(get_admin_user)):
  300. return {
  301. "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
  302. "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
  303. "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
  304. "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
  305. "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
  306. }
  307. class AdminConfig(BaseModel):
  308. SHOW_ADMIN_DETAILS: bool
  309. ENABLE_SIGNUP: bool
  310. DEFAULT_USER_ROLE: str
  311. JWT_EXPIRES_IN: str
  312. ENABLE_COMMUNITY_SHARING: bool
  313. @router.post("/admin/config")
  314. async def update_admin_config(
  315. request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
  316. ):
  317. request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
  318. request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
  319. if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
  320. request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
  321. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  322. # Check if the input string matches the pattern
  323. if re.match(pattern, form_data.JWT_EXPIRES_IN):
  324. request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
  325. request.app.state.config.ENABLE_COMMUNITY_SHARING = (
  326. form_data.ENABLE_COMMUNITY_SHARING
  327. )
  328. return {
  329. "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
  330. "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
  331. "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
  332. "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
  333. "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
  334. }
  335. ############################
  336. # API Key
  337. ############################
  338. # create api key
  339. @router.post("/api_key", response_model=ApiKey)
  340. async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
  341. api_key = create_api_key()
  342. success = Users.update_user_api_key_by_id(db, user.id, api_key)
  343. if success:
  344. return {
  345. "api_key": api_key,
  346. }
  347. else:
  348. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
  349. # delete api key
  350. @router.delete("/api_key", response_model=bool)
  351. async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)):
  352. success = Users.update_user_api_key_by_id(db, user.id, None)
  353. return success
  354. # get api key
  355. @router.get("/api_key", response_model=ApiKey)
  356. async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)):
  357. api_key = Users.get_user_api_key_by_id(db, user.id)
  358. if api_key:
  359. return {
  360. "api_key": api_key,
  361. }
  362. else:
  363. raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)