auths.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import logging
  2. from fastapi import Request, UploadFile, File
  3. from fastapi import Depends, HTTPException, status
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import re
  7. import uuid
  8. import csv
  9. from apps.webui.models.auths import (
  10. SigninForm,
  11. SignupForm,
  12. AddUserForm,
  13. UpdateProfileForm,
  14. UpdatePasswordForm,
  15. UserResponse,
  16. SigninResponse,
  17. Auths,
  18. ApiKey,
  19. )
  20. from apps.webui.models.users import Users
  21. from utils.utils import (
  22. get_password_hash,
  23. get_current_user,
  24. get_admin_user,
  25. create_token,
  26. create_api_key,
  27. )
  28. from utils.misc import parse_duration, validate_email_format
  29. from utils.webhook import post_webhook
  30. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  31. from config import (
  32. WEBUI_AUTH,
  33. WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
  34. WEBUI_AUTH_TRUSTED_NAME_HEADER,
  35. )
# Router on which every auth endpoint defined in this module is registered.
router = APIRouter()
  37. ############################
  38. # GetSessionUser
  39. ############################
  40. @router.get("/", response_model=UserResponse)
  41. async def get_session_user(user=Depends(get_current_user)):
  42. return {
  43. "id": user.id,
  44. "email": user.email,
  45. "name": user.name,
  46. "role": user.role,
  47. "profile_image_url": user.profile_image_url,
  48. }
  49. ############################
  50. # Update Profile
  51. ############################
  52. @router.post("/update/profile", response_model=UserResponse)
  53. async def update_profile(
  54. form_data: UpdateProfileForm, session_user=Depends(get_current_user)
  55. ):
  56. if session_user:
  57. user = Users.update_user_by_id(
  58. session_user.id,
  59. {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
  60. )
  61. if user:
  62. return user
  63. else:
  64. raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
  65. else:
  66. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  67. ############################
  68. # Update Password
  69. ############################
  70. @router.post("/update/password", response_model=bool)
  71. async def update_password(
  72. form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
  73. ):
  74. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  75. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
  76. if session_user:
  77. user = Auths.authenticate_user(session_user.email, form_data.password)
  78. if user:
  79. hashed = get_password_hash(form_data.new_password)
  80. return Auths.update_user_password_by_id(user.id, hashed)
  81. else:
  82. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
  83. else:
  84. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  85. ############################
  86. # SignIn
  87. ############################
  88. @router.post("/signin", response_model=SigninResponse)
  89. async def signin(request: Request, form_data: SigninForm):
  90. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
  91. if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
  92. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
  93. trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
  94. trusted_name = trusted_email
  95. if WEBUI_AUTH_TRUSTED_NAME_HEADER:
  96. trusted_name = request.headers.get(
  97. WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
  98. )
  99. if not Users.get_user_by_email(trusted_email.lower()):
  100. await signup(
  101. request,
  102. SignupForm(
  103. email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
  104. ),
  105. )
  106. user = Auths.authenticate_user_by_trusted_header(trusted_email)
  107. elif WEBUI_AUTH == False:
  108. admin_email = "admin@localhost"
  109. admin_password = "admin"
  110. if Users.get_user_by_email(admin_email.lower()):
  111. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  112. else:
  113. if Users.get_num_users() != 0:
  114. raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
  115. await signup(
  116. request,
  117. SignupForm(email=admin_email, password=admin_password, name="User"),
  118. )
  119. user = Auths.authenticate_user(admin_email.lower(), admin_password)
  120. else:
  121. user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
  122. if user:
  123. token = create_token(
  124. data={"id": user.id},
  125. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  126. )
  127. return {
  128. "token": token,
  129. "token_type": "Bearer",
  130. "id": user.id,
  131. "email": user.email,
  132. "name": user.name,
  133. "role": user.role,
  134. "profile_image_url": user.profile_image_url,
  135. }
  136. else:
  137. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  138. ############################
  139. # SignUp
  140. ############################
  141. @router.post("/signup", response_model=SigninResponse)
  142. async def signup(request: Request, form_data: SignupForm):
  143. if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
  144. raise HTTPException(
  145. status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
  146. )
  147. if not validate_email_format(form_data.email.lower()):
  148. raise HTTPException(
  149. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  150. )
  151. if Users.get_user_by_email(form_data.email.lower()):
  152. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  153. try:
  154. role = (
  155. "admin"
  156. if Users.get_num_users() == 0
  157. else request.app.state.config.DEFAULT_USER_ROLE
  158. )
  159. hashed = get_password_hash(form_data.password)
  160. user = Auths.insert_new_auth(
  161. form_data.email.lower(),
  162. hashed,
  163. form_data.name,
  164. form_data.profile_image_url,
  165. role,
  166. )
  167. if user:
  168. token = create_token(
  169. data={"id": user.id},
  170. expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
  171. )
  172. # response.set_cookie(key='token', value=token, httponly=True)
  173. if request.app.state.config.WEBHOOK_URL:
  174. post_webhook(
  175. request.app.state.config.WEBHOOK_URL,
  176. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  177. {
  178. "action": "signup",
  179. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  180. "user": user.model_dump_json(exclude_none=True),
  181. },
  182. )
  183. return {
  184. "token": token,
  185. "token_type": "Bearer",
  186. "id": user.id,
  187. "email": user.email,
  188. "name": user.name,
  189. "role": user.role,
  190. "profile_image_url": user.profile_image_url,
  191. }
  192. else:
  193. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  194. except Exception as err:
  195. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  196. ############################
  197. # AddUser
  198. ############################
  199. @router.post("/add", response_model=SigninResponse)
  200. async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
  201. if not validate_email_format(form_data.email.lower()):
  202. raise HTTPException(
  203. status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
  204. )
  205. if Users.get_user_by_email(form_data.email.lower()):
  206. raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
  207. try:
  208. print(form_data)
  209. hashed = get_password_hash(form_data.password)
  210. user = Auths.insert_new_auth(
  211. form_data.email.lower(),
  212. hashed,
  213. form_data.name,
  214. form_data.profile_image_url,
  215. form_data.role,
  216. )
  217. if user:
  218. token = create_token(data={"id": user.id})
  219. return {
  220. "token": token,
  221. "token_type": "Bearer",
  222. "id": user.id,
  223. "email": user.email,
  224. "name": user.name,
  225. "role": user.role,
  226. "profile_image_url": user.profile_image_url,
  227. }
  228. else:
  229. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
  230. except Exception as err:
  231. raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
  232. ############################
  233. # GetAdminDetails
  234. ############################
  235. @router.get("/admin/details")
  236. async def get_admin_details(request: Request, user=Depends(get_current_user)):
  237. if request.app.state.config.SHOW_ADMIN_DETAILS:
  238. admin_email = request.app.state.config.ADMIN_EMAIL
  239. admin_name = None
  240. print(admin_email, admin_name)
  241. if admin_email:
  242. admin = Users.get_user_by_email(admin_email)
  243. if admin:
  244. admin_name = admin.name
  245. else:
  246. admin = Users.get_first_user()
  247. if admin:
  248. admin_email = admin.email
  249. admin_name = admin.name
  250. return {
  251. "name": admin_name,
  252. "email": admin_email,
  253. }
  254. else:
  255. raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
############################
# Admin Config
############################
  259. @router.get("/admin/config")
  260. async def get_admin_config(request: Request, user=Depends(get_admin_user)):
  261. return {
  262. "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
  263. "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
  264. "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
  265. "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
  266. "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
  267. }
class AdminConfig(BaseModel):
    """Request body accepted by the POST ``/admin/config`` endpoint."""

    # Gates the /admin/details endpoint.
    SHOW_ADMIN_DETAILS: bool
    # Checked by /signup together with WEBUI_AUTH before registering accounts.
    ENABLE_SIGNUP: bool
    # Role given to accounts created through /signup after the first one;
    # the update endpoint only applies "pending", "user" or "admin".
    DEFAULT_USER_ROLE: str
    # Duration string passed to parse_duration when tokens are created.
    JWT_EXPIRES_IN: str
    # Stored on the app config; its consumers are not visible in this module.
    ENABLE_COMMUNITY_SHARING: bool
  274. @router.post("/admin/config")
  275. async def update_admin_config(
  276. request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
  277. ):
  278. request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
  279. request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
  280. if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
  281. request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
  282. pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
  283. # Check if the input string matches the pattern
  284. if re.match(pattern, form_data.JWT_EXPIRES_IN):
  285. request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
  286. request.app.state.config.ENABLE_COMMUNITY_SHARING = (
  287. form_data.ENABLE_COMMUNITY_SHARING
  288. )
  289. return {
  290. "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
  291. "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
  292. "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
  293. "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
  294. "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
  295. }
  296. ############################
  297. # API Key
  298. ############################
  299. # create api key
  300. @router.post("/api_key", response_model=ApiKey)
  301. async def create_api_key_(user=Depends(get_current_user)):
  302. api_key = create_api_key()
  303. success = Users.update_user_api_key_by_id(user.id, api_key)
  304. if success:
  305. return {
  306. "api_key": api_key,
  307. }
  308. else:
  309. raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
  310. # delete api key
  311. @router.delete("/api_key", response_model=bool)
  312. async def delete_api_key(user=Depends(get_current_user)):
  313. success = Users.update_user_api_key_by_id(user.id, None)
  314. return success
  315. # get api key
  316. @router.get("/api_key", response_model=ApiKey)
  317. async def get_api_key(user=Depends(get_current_user)):
  318. api_key = Users.get_user_api_key_by_id(user.id)
  319. if api_key:
  320. return {
  321. "api_key": api_key,
  322. }
  323. else:
  324. raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)