users.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. from fastapi import Response, Request
  2. from fastapi import Depends, FastAPI, HTTPException, status
  3. from datetime import datetime, timedelta
  4. from typing import List, Union, Optional
  5. from fastapi import APIRouter
  6. from pydantic import BaseModel
  7. import time
  8. import uuid
  9. import logging
  10. from apps.webui.internal.db import get_db
  11. from apps.webui.models.users import (
  12. UserModel,
  13. UserUpdateForm,
  14. UserRoleUpdateForm,
  15. UserSettings,
  16. Users,
  17. )
  18. from apps.webui.models.auths import Auths
  19. from apps.webui.models.chats import Chats
  20. from utils.utils import (
  21. get_verified_user,
  22. get_password_hash,
  23. get_current_user,
  24. get_admin_user,
  25. )
  26. from constants import ERROR_MESSAGES
  27. from config import SRC_LOG_LEVELS
  28. log = logging.getLogger(__name__)
  29. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  30. router = APIRouter()
  31. ############################
  32. # GetUsers
  33. ############################
  34. @router.get("/", response_model=List[UserModel])
  35. async def get_users(
  36. skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
  37. ):
  38. return Users.get_users(db, skip, limit)
  39. ############################
  40. # User Permissions
  41. ############################
  42. @router.get("/permissions/user")
  43. async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
  44. return request.app.state.config.USER_PERMISSIONS
  45. @router.post("/permissions/user")
  46. async def update_user_permissions(
  47. request: Request, form_data: dict, user=Depends(get_admin_user)
  48. ):
  49. request.app.state.config.USER_PERMISSIONS = form_data
  50. return request.app.state.config.USER_PERMISSIONS
  51. ############################
  52. # UpdateUserRole
  53. ############################
  54. @router.post("/update/role", response_model=Optional[UserModel])
  55. async def update_user_role(
  56. form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
  57. ):
  58. if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
  59. return Users.update_user_role_by_id(db, form_data.id, form_data.role)
  60. raise HTTPException(
  61. status_code=status.HTTP_403_FORBIDDEN,
  62. detail=ERROR_MESSAGES.ACTION_PROHIBITED,
  63. )
  64. ############################
  65. # GetUserSettingsBySessionUser
  66. ############################
  67. @router.get("/user/settings", response_model=Optional[UserSettings])
  68. async def get_user_settings_by_session_user(
  69. user=Depends(get_verified_user), db=Depends(get_db)
  70. ):
  71. user = Users.get_user_by_id(db, user.id)
  72. if user:
  73. return user.settings
  74. else:
  75. raise HTTPException(
  76. status_code=status.HTTP_400_BAD_REQUEST,
  77. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  78. )
  79. ############################
  80. # UpdateUserSettingsBySessionUser
  81. ############################
  82. @router.post("/user/settings/update", response_model=UserSettings)
  83. async def update_user_settings_by_session_user(
  84. form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
  85. ):
  86. user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
  87. if user:
  88. return user.settings
  89. else:
  90. raise HTTPException(
  91. status_code=status.HTTP_400_BAD_REQUEST,
  92. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  93. )
  94. ############################
  95. # GetUserInfoBySessionUser
  96. ############################
  97. @router.get("/user/info", response_model=Optional[dict])
  98. async def get_user_info_by_session_user(
  99. user=Depends(get_verified_user), db=Depends(get_db)
  100. ):
  101. user = Users.get_user_by_id(db, user.id)
  102. if user:
  103. return user.info
  104. else:
  105. raise HTTPException(
  106. status_code=status.HTTP_400_BAD_REQUEST,
  107. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  108. )
  109. ############################
  110. # UpdateUserInfoBySessionUser
  111. ############################
  112. @router.post("/user/info/update", response_model=Optional[dict])
  113. async def update_user_info_by_session_user(
  114. form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
  115. ):
  116. user = Users.get_user_by_id(db, user.id)
  117. if user:
  118. if user.info is None:
  119. user.info = {}
  120. user = Users.update_user_by_id(
  121. db, user.id, {"info": {**user.info, **form_data}}
  122. )
  123. if user:
  124. return user.info
  125. else:
  126. raise HTTPException(
  127. status_code=status.HTTP_400_BAD_REQUEST,
  128. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  129. )
  130. else:
  131. raise HTTPException(
  132. status_code=status.HTTP_400_BAD_REQUEST,
  133. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  134. )
  135. ############################
  136. # GetUserById
  137. ############################
  138. class UserResponse(BaseModel):
  139. name: str
  140. profile_image_url: str
  141. @router.get("/{user_id}", response_model=UserResponse)
  142. async def get_user_by_id(
  143. user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
  144. ):
  145. # Check if user_id is a shared chat
  146. # If it is, get the user_id from the chat
  147. if user_id.startswith("shared-"):
  148. chat_id = user_id.replace("shared-", "")
  149. chat = Chats.get_chat_by_id(db, chat_id)
  150. if chat:
  151. user_id = chat.user_id
  152. else:
  153. raise HTTPException(
  154. status_code=status.HTTP_400_BAD_REQUEST,
  155. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  156. )
  157. user = Users.get_user_by_id(db, user_id)
  158. if user:
  159. return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
  160. else:
  161. raise HTTPException(
  162. status_code=status.HTTP_400_BAD_REQUEST,
  163. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  164. )
  165. ############################
  166. # UpdateUserById
  167. ############################
  168. @router.post("/{user_id}/update", response_model=Optional[UserModel])
  169. async def update_user_by_id(
  170. user_id: str,
  171. form_data: UserUpdateForm,
  172. session_user=Depends(get_admin_user),
  173. db=Depends(get_db),
  174. ):
  175. user = Users.get_user_by_id(db, user_id)
  176. if user:
  177. if form_data.email.lower() != user.email:
  178. email_user = Users.get_user_by_email(db, form_data.email.lower())
  179. if email_user:
  180. raise HTTPException(
  181. status_code=status.HTTP_400_BAD_REQUEST,
  182. detail=ERROR_MESSAGES.EMAIL_TAKEN,
  183. )
  184. if form_data.password:
  185. hashed = get_password_hash(form_data.password)
  186. log.debug(f"hashed: {hashed}")
  187. Auths.update_user_password_by_id(db, user_id, hashed)
  188. Auths.update_email_by_id(db, user_id, form_data.email.lower())
  189. updated_user = Users.update_user_by_id(
  190. db,
  191. user_id,
  192. {
  193. "name": form_data.name,
  194. "email": form_data.email.lower(),
  195. "profile_image_url": form_data.profile_image_url,
  196. },
  197. )
  198. if updated_user:
  199. return updated_user
  200. raise HTTPException(
  201. status_code=status.HTTP_400_BAD_REQUEST,
  202. detail=ERROR_MESSAGES.DEFAULT(),
  203. )
  204. raise HTTPException(
  205. status_code=status.HTTP_400_BAD_REQUEST,
  206. detail=ERROR_MESSAGES.USER_NOT_FOUND,
  207. )
  208. ############################
  209. # DeleteUserById
  210. ############################
  211. @router.delete("/{user_id}", response_model=bool)
  212. async def delete_user_by_id(
  213. user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
  214. ):
  215. if user.id != user_id:
  216. result = Auths.delete_auth_by_id(db, user_id)
  217. if result:
  218. return True
  219. raise HTTPException(
  220. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  221. detail=ERROR_MESSAGES.DELETE_USER_ERROR,
  222. )
  223. raise HTTPException(
  224. status_code=status.HTTP_403_FORBIDDEN,
  225. detail=ERROR_MESSAGES.ACTION_PROHIBITED,
  226. )