chats.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. import json
  2. import logging
  3. from typing import Optional
  4. from apps.webui.models.chats import ChatForm, ChatResponse, Chats, ChatTitleIdResponse
  5. from apps.webui.models.tags import ChatIdTagForm, ChatIdTagModel, TagModel, Tags
  6. from config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
  7. from constants import ERROR_MESSAGES
  8. from env import SRC_LOG_LEVELS
  9. from fastapi import APIRouter, Depends, HTTPException, Request, status
  10. from pydantic import BaseModel
  11. from utils.utils import get_admin_user, get_verified_user
  # Router that every chat endpoint in this module registers itself on.
  router = APIRouter()

  # Module logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(name=__name__)
  log.setLevel(SRC_LOG_LEVELS["MODELS"])
  ############################
  # GetChatList
  ############################
  @router.get("/", response_model=list[ChatTitleIdResponse])
  @router.get("/list", response_model=list[ChatTitleIdResponse])
  async def get_session_user_chat_list(
      user=Depends(get_verified_user), page: Optional[int] = None
  ):
      """Return title/id summaries of the current user's chats.

      Args:
          user: The requesting user, resolved by ``get_verified_user``.
          page: Optional 1-based page number. When omitted the unpaginated
              list is returned; otherwise pages hold 60 entries each. Values
              below 1 are treated as page 1.

      Returns:
          Whatever ``Chats.get_chat_title_id_list_by_user_id`` returns for
          the requested slice.
      """
      if page is None:
          return Chats.get_chat_title_id_list_by_user_id(user.id)

      limit = 60
      # Clamp to the first page: page=0 or a negative page would otherwise
      # yield a negative offset to hand to the data layer.
      skip = (max(page, 1) - 1) * limit
      return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit)
  ############################
  # DeleteAllChats
  ############################
  @router.delete("/", response_model=bool)
  async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
      """Delete every chat owned by the requesting user.

      Accounts with role ``"user"`` must have ``USER_PERMISSIONS["chat"]["deletion"]``
      enabled; the permission is not consulted for any other role.

      Raises:
          HTTPException: 401 when a ``"user"`` account lacks the deletion permission.

      Returns:
          The value returned by ``Chats.delete_chats_by_user_id``.
      """
      if user.role == "user":
          permissions = request.app.state.config.USER_PERMISSIONS
          if not permissions["chat"]["deletion"]:
              raise HTTPException(
                  status_code=status.HTTP_401_UNAUTHORIZED,
                  detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
              )

      return Chats.delete_chats_by_user_id(user.id)
  ############################
  # GetUserChatList
  ############################
  @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse])
  async def get_user_chat_list_by_user_id(
      user_id: str,
      user=Depends(get_admin_user),
      skip: int = 0,
      limit: int = 50,
  ):
      """List another user's chats, archived ones included.

      Guarded by the ``get_admin_user`` dependency and additionally by the
      ``ENABLE_ADMIN_CHAT_ACCESS`` flag.

      Raises:
          HTTPException: 401 when ``ENABLE_ADMIN_CHAT_ACCESS`` is disabled.
      """
      if ENABLE_ADMIN_CHAT_ACCESS:
          return Chats.get_chat_list_by_user_id(
              user_id, include_archived=True, skip=skip, limit=limit
          )

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )
  ############################
  # CreateNewChat
  ############################
  @router.post("/new", response_model=Optional[ChatResponse])
  async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
      """Create a chat owned by the requesting user.

      The stored ``chat`` field is JSON-decoded before it is returned. Any
      exception raised while inserting, decoding or building the response is
      logged and reported to the client as a 400.
      """
      try:
          new_chat = Chats.insert_new_chat(user.id, form_data)
          payload = new_chat.model_dump()
          payload["chat"] = json.loads(new_chat.chat)
          return ChatResponse(**payload)
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
          )
  ############################
  # GetChats
  ############################
  @router.get("/all", response_model=list[ChatResponse])
  async def get_user_chats(user=Depends(get_verified_user)):
      """Return all chats of the requesting user with ``chat`` JSON-decoded."""

      def to_response(chat):
          # Replace the serialized ``chat`` string with its decoded JSON value.
          return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})

      return [to_response(chat) for chat in Chats.get_chats_by_user_id(user.id)]
  ############################
  # GetArchivedChats
  ############################
  @router.get("/all/archived", response_model=list[ChatResponse])
  async def get_user_archived_chats(user=Depends(get_verified_user)):
      """Return the requesting user's archived chats with ``chat`` JSON-decoded."""
      archived = Chats.get_archived_chats_by_user_id(user.id)
      responses = []
      for chat in archived:
          fields = chat.model_dump()
          fields["chat"] = json.loads(chat.chat)
          responses.append(ChatResponse(**fields))
      return responses
  ############################
  # GetAllChatsInDB
  ############################
  @router.get("/all/db", response_model=list[ChatResponse])
  async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
      """Export every chat returned by ``Chats.get_chats`` (all users).

      Guarded by ``get_admin_user`` and the ``ENABLE_ADMIN_EXPORT`` flag.

      Raises:
          HTTPException: 401 when ``ENABLE_ADMIN_EXPORT`` is disabled.
      """
      if ENABLE_ADMIN_EXPORT:
          return list(
              map(
                  lambda chat: ChatResponse(
                      **{**chat.model_dump(), "chat": json.loads(chat.chat)}
                  ),
                  Chats.get_chats(),
              )
          )

      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
      )
  ############################
  # GetArchivedChatList
  ############################
  @router.get("/archived", response_model=list[ChatTitleIdResponse])
  async def get_archived_session_user_chat_list(
      user=Depends(get_verified_user), skip: int = 0, limit: int = 50
  ):
      """Return a page of title/id summaries of the user's archived chats."""
      archived_page = Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
      return archived_page
  ############################
  # ArchiveAllChats
  ############################
  @router.post("/archive/all", response_model=bool)
  async def archive_all_chats(user=Depends(get_verified_user)):
      """Archive every chat of the requesting user.

      Returns:
          The value returned by ``Chats.archive_all_chats_by_user_id``.
      """
      archived = Chats.archive_all_chats_by_user_id(user.id)
      return archived
  ############################
  # GetSharedChatById
  ############################
  @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
  async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
      """Resolve a shared chat and return it with ``chat`` JSON-decoded.

      Admins, while ``ENABLE_ADMIN_CHAT_ACCESS`` is enabled, look ``share_id``
      up with ``Chats.get_chat_by_id``; every other caller goes through
      ``Chats.get_chat_by_share_id``.

      Raises:
          HTTPException: 401 (``NOT_FOUND`` detail) for pending users or when
              no chat matches.
      """
      if user.role == "pending":
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
          )

      # if/else rather than if/elif: any role other than a privileged admin
      # falls back to the share-id lookup, so `chat` is always bound (an
      # unexpected role can no longer trigger an UnboundLocalError -> 500).
      if user.role == "admin" and ENABLE_ADMIN_CHAT_ACCESS:
          chat = Chats.get_chat_by_id(share_id)
      else:
          chat = Chats.get_chat_by_share_id(share_id)

      if not chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
          )
      return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
  ############################
  # GetChatsByTags
  ############################
  class TagNameForm(BaseModel):
      # Request body for POST /tags: the tag name to filter by ...
      name: str
      # ... plus pagination values forwarded to Chats.get_chat_list_by_chat_ids.
      skip: Optional[int] = 0
      limit: Optional[int] = 50


  @router.post("/tags", response_model=list[ChatTitleIdResponse])
  async def get_user_chat_list_by_tag_name(
      form_data: TagNameForm, user=Depends(get_verified_user)
  ):
      """Return the requesting user's chats that carry the tag ``form_data.name``.

      If the first page comes back empty, the tag is deleted for this user
      (presumably to clean up tags whose chats are gone — confirm intent).
      """
      chat_ids = [
          chat_id_tag.chat_id
          for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
              form_data.name, user.id
          )
      ]
      chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)

      # Only an empty *first* page (with a non-zero page size) says anything
      # about whether the tag still has chats. An empty later page just means
      # the client paged past the end and must not delete the tag.
      is_first_page = not form_data.skip
      if not chats and is_first_page and form_data.limit != 0:
          Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
      return chats
  ############################
  # GetAllTags
  ############################
  @router.get("/tags/all", response_model=list[TagModel])
  async def get_all_tags(user=Depends(get_verified_user)):
      """Return every tag belonging to the requesting user.

      Any exception from the tag lookup is logged and reported as a 400.
      """
      try:
          return Tags.get_tags_by_user_id(user.id)
      except Exception as err:
          log.exception(err)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
          )
  ############################
  # GetChatById
  ############################
  @router.get("/{id}", response_model=Optional[ChatResponse])
  async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
      """Return one of the requesting user's chats with ``chat`` JSON-decoded.

      Raises:
          HTTPException: 401 (``NOT_FOUND`` detail) when the user has no chat
              with this id.
      """
      owned_chat = Chats.get_chat_by_id_and_user_id(id, user.id)
      if not owned_chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
          )

      decoded = json.loads(owned_chat.chat)
      return ChatResponse(**{**owned_chat.model_dump(), "chat": decoded})
  ############################
  # UpdateChatById
  ############################
  @router.post("/{id}", response_model=Optional[ChatResponse])
  async def update_chat_by_id(
      id: str, form_data: ChatForm, user=Depends(get_verified_user)
  ):
      """Merge ``form_data.chat`` into one of the user's chats and return it.

      The merge is shallow: top-level keys sent by the client replace the
      stored ones, and all other stored keys are kept.

      Raises:
          HTTPException: 401 when the user owns no chat with this id; 500 when
              the update returns a falsy result.
      """
      chat = Chats.get_chat_by_id_and_user_id(id, user.id)
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      updated_chat = {**json.loads(chat.chat), **form_data.chat}
      chat = Chats.update_chat_by_id(id, updated_chat)
      # Guard the model result before dereferencing it, the same way
      # share_chat_by_id guards its insert result.
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
  ############################
  # DeleteChatById
  ############################
  @router.delete("/{id}", response_model=bool)
  async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
      """Delete a chat.

      Admins delete through ``Chats.delete_chat_by_id`` with no owner filter.
      All other roles need ``USER_PERMISSIONS["chat"]["deletion"]`` and can
      only delete their own chat (``delete_chat_by_id_and_user_id``).

      Raises:
          HTTPException: 401 when a non-admin lacks the deletion permission.
      """
      if user.role != "admin":
          if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
              raise HTTPException(
                  status_code=status.HTTP_401_UNAUTHORIZED,
                  detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
              )
          return Chats.delete_chat_by_id_and_user_id(id, user.id)

      return Chats.delete_chat_by_id(id)
  ############################
  # CloneChat
  ############################
  @router.get("/{id}/clone", response_model=Optional[ChatResponse])
  async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
      """Copy one of the user's chats into a new chat titled "Clone of <title>".

      The copy records the source chat id (``originalChatId``) and the source's
      ``history.currentId`` (``branchPointMessageId``, ``None`` when absent).

      Raises:
          HTTPException: 401 when the user owns no chat with this id; 500 when
              inserting the clone returns a falsy result.
      """
      chat = Chats.get_chat_by_id_and_user_id(id, user.id)
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
          )

      chat_body = json.loads(chat.chat)
      # A chat without a history (or without a current message id) used to
      # raise KeyError here; clone it without a branch point instead.
      history = chat_body.get("history")
      branch_point = history.get("currentId") if isinstance(history, dict) else None
      updated_chat = {
          **chat_body,
          "originalChatId": chat.id,
          "branchPointMessageId": branch_point,
          "title": f"Clone of {chat.title}",
      }

      cloned = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
      if not cloned:
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return ChatResponse(**{**cloned.model_dump(), "chat": json.loads(cloned.chat)})
  ############################
  # ArchiveChat
  ############################
  @router.get("/{id}/archive", response_model=Optional[ChatResponse])
  async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
      """Toggle the archived state of one of the user's chats and return it.

      Raises:
          HTTPException: 401 when the user owns no chat with this id; 500 when
              the toggle returns a falsy result.
      """
      chat = Chats.get_chat_by_id_and_user_id(id, user.id)
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
          )

      chat = Chats.toggle_chat_archive_by_id(id)
      # Guard before dereferencing instead of failing with AttributeError.
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
  ############################
  # ShareChatById
  ############################
  @router.post("/{id}/share", response_model=Optional[ChatResponse])
  async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
      """Share one of the user's chats and return the shared chat.

      If the chat already has a ``share_id``, ``Chats.update_shared_chat_by_chat_id``
      is used; otherwise ``Chats.insert_shared_chat_by_chat_id``.

      Raises:
          HTTPException: 401 when the user owns no chat with this id; 500 when
              either the update or the insert returns a falsy result.
      """
      chat = Chats.get_chat_by_id_and_user_id(id, user.id)
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      if chat.share_id:
          shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
      else:
          shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)

      # One guard for both paths: previously only the insert result was checked.
      if not shared_chat:
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return ChatResponse(
          **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
      )
  ############################
  # DeletedSharedChatById
  ############################
  @router.delete("/{id}/share", response_model=Optional[bool])
  async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
      """Stop sharing one of the user's chats.

      Returns:
          ``False`` if the chat is not shared. Otherwise, the falsy result of
          ``Chats.delete_shared_chat_by_chat_id`` if that result is falsy, else
          whether clearing the chat's ``share_id`` returned a non-``None`` value.

      Raises:
          HTTPException: 401 when the user owns no chat with this id.
      """
      chat = Chats.get_chat_by_id_and_user_id(id, user.id)
      if not chat:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
          )

      if not chat.share_id:
          return False

      result = Chats.delete_shared_chat_by_chat_id(id)
      update_result = Chats.update_chat_share_id_by_id(id, None)
      return result and update_result is not None
  ############################
  # GetChatTagsById
  ############################
  @router.get("/{id}/tags", response_model=list[TagModel])
  async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
      """Return the requesting user's tags on chat ``id``.

      An empty list is a valid answer; only a ``None`` lookup result is an error.

      Raises:
          HTTPException: 401 (``NOT_FOUND`` detail) when the lookup returns ``None``.
      """
      tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
      if tags is None:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
          )
      return tags
  ############################
  # AddChatTagById
  ############################
  @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
  async def add_chat_tag_by_id(
      id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
  ):
      """Attach ``form_data.tag_name`` to a chat unless chat ``id`` already has it.

      NOTE(review): the duplicate check uses the path ``id``, while the tag is
      added through ``form_data`` alone. Confirm the form's chat id is expected
      to match the path.

      Raises:
          HTTPException: 401 (``DEFAULT`` detail) when the tag is already
              present; 401 (``NOT_FOUND`` detail) when adding the tag returns a
              falsy result.
      """
      tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
      # `tags` holds tag model objects, so compare against their names.
      # Testing the raw string against the models never matched and let
      # duplicates through. `or []` tolerates a None lookup result, which
      # get_chat_tags_by_id also treats as possible.
      # NOTE(review): relies on TagModel exposing a `name` field — confirm.
      existing_names = {tag.name for tag in tags or []}
      if form_data.tag_name in existing_names:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
          )

      tag = Tags.add_tag_to_chat(user.id, form_data)
      if not tag:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      return tag
  ############################
  # DeleteChatTagById
  ############################
  @router.delete("/{id}/tags", response_model=Optional[bool])
  async def delete_chat_tag_by_id(
      id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
  ):
      """Remove the tag ``form_data.tag_name`` from the user's chat ``id``.

      Raises:
          HTTPException: 401 (``NOT_FOUND`` detail) when the deletion returns a
              falsy result.
      """
      deleted = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
          form_data.tag_name, id, user.id
      )
      if not deleted:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
          )
      return deleted
  ############################
  # DeleteAllChatTagsById
  ############################
  @router.delete("/{id}/tags/all", response_model=Optional[bool])
  async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
      """Remove every tag the requesting user put on chat ``id``.

      Raises:
          HTTPException: 401 (``NOT_FOUND`` detail) when the deletion returns a
              falsy result.
      """
      deleted = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
      if not deleted:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
          )
      return deleted