chats.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. from pydantic import BaseModel, ConfigDict
  2. from typing import List, Union, Optional
  3. import json
  4. import uuid
  5. import time
  6. from sqlalchemy import Column, String, BigInteger, Boolean, Text
  7. from apps.webui.internal.db import Base, get_db
  8. ####################
  9. # Chat DB Schema
  10. ####################
  class Chat(Base):
      """SQLAlchemy model mapped to the ``chat`` table."""

      __tablename__ = "chat"

      id = Column(String, primary_key=True)
      user_id = Column(String)
      title = Column(Text)
      chat = Column(Text)  # Save Chat JSON as Text
      # Timestamps are written as int(time.time()) by ChatTable.
      created_at = Column(BigInteger)
      updated_at = Column(BigInteger)
      # Set to the id of the shared copy by ChatTable.insert_shared_chat_by_chat_id.
      share_id = Column(Text, unique=True, nullable=True)
      archived = Column(Boolean, default=False)
  class ChatModel(BaseModel):
      """Pydantic counterpart of a ``Chat`` row, with ``chat`` kept as JSON text."""

      # from_attributes=True is what allows ChatModel.model_validate(<Chat row>)
      # to read the fields from an ORM object instead of a dict.
      model_config = ConfigDict(from_attributes=True)

      id: str
      user_id: str
      title: str
      chat: str

      created_at: int  # timestamp in epoch
      updated_at: int  # timestamp in epoch

      share_id: Optional[str] = None
      archived: bool = False
  31. ####################
  32. # Forms
  33. ####################
  class ChatForm(BaseModel):
      """Form carrying a chat as a dict.

      ``ChatTable.insert_new_chat`` reads an optional ``"title"`` key from it and
      stores the whole dict as JSON.
      """

      chat: dict
  class ChatTitleForm(BaseModel):
      """Form carrying only a chat title."""

      title: str
  class ChatResponse(BaseModel):
      """Response shape for a chat; unlike ``ChatModel``, ``chat`` is a dict."""

      id: str
      user_id: str
      title: str
      chat: dict
      updated_at: int  # timestamp in epoch
      created_at: int  # timestamp in epoch
      share_id: Optional[str] = None  # id of the chat to be shared
      archived: bool
  class ChatTitleIdResponse(BaseModel):
      """Lightweight chat listing entry without the chat content.

      Built by ``ChatTable.get_chat_title_id_list_by_user_id``.
      """

      id: str
      title: str
      updated_at: int
      created_at: int
  52. class ChatTable:
  def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
      """Create and persist a new chat owned by ``user_id``.

      The title is taken from ``form_data.chat["title"]`` when that key is
      present and falls back to ``"New Chat"``. The whole chat dict is stored
      as JSON text.

      Args:
          user_id: Owner of the new chat.
          form_data: Form whose ``chat`` dict becomes the stored content.

      Returns:
          The stored chat as a ``ChatModel``.
      """
      # Read the clock once so created_at and updated_at agree on a new row.
      now = int(time.time())
      chat = ChatModel(
          id=str(uuid.uuid4()),
          user_id=user_id,
          title=form_data.chat.get("title", "New Chat"),
          chat=json.dumps(form_data.chat),
          created_at=now,
          updated_at=now,
      )

      with get_db() as db:
          row = Chat(**chat.model_dump())
          db.add(row)
          db.commit()
          db.refresh(row)
          return ChatModel.model_validate(row)
  def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
      """Replace the stored content of chat ``id`` with ``chat``.

      The title is refreshed from ``chat["title"]`` (``"New Chat"`` when the
      key is absent) and ``updated_at`` is set to the current time.

      Returns:
          The updated chat, or ``None`` when the chat does not exist or the
          update fails.
      """
      try:
          with get_db() as db:
              chat_obj = db.get(Chat, id)
              if chat_obj is None:
                  # Unknown id: report "not found" explicitly instead of
                  # relying on an AttributeError being swallowed below.
                  return None

              chat_obj.chat = json.dumps(chat)
              chat_obj.title = chat["title"] if "title" in chat else "New Chat"
              chat_obj.updated_at = int(time.time())
              db.commit()
              db.refresh(chat_obj)
              return ChatModel.model_validate(chat_obj)
      except Exception:
          return None
  def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
      """Create a shared copy of chat ``chat_id`` and link it via ``share_id``.

      The copy gets a fresh id, the user id ``f"shared-{chat_id}"``, and the
      original's title, content and ``created_at``. The original row's
      ``share_id`` is then set to the copy's id.

      Returns:
          The shared copy. If the chat is already shared, the existing copy is
          returned instead of creating another one. ``None`` when the chat does
          not exist, or when the original could not be linked to the copy.
      """
      with get_db() as db:
          # Get the existing chat to share
          chat = db.get(Chat, chat_id)
          if chat is None:
              return None

          # Shared copies are stored under this synthetic owner (it is also
          # what delete_shared_chat_by_chat_id filters on).
          shared_user_id = f"shared-{chat_id}"

          # Check if the chat is already shared
          if chat.share_id:
              return self.get_chat_by_id_and_user_id(chat.share_id, shared_user_id)

          # Create a new chat with the same data, but with a new ID
          shared_chat = ChatModel(
              id=str(uuid.uuid4()),
              user_id=shared_user_id,
              title=chat.title,
              chat=chat.chat,
              created_at=chat.created_at,
              updated_at=int(time.time()),
          )
          shared_result = Chat(**shared_chat.model_dump())
          db.add(shared_result)
          db.commit()
          db.refresh(shared_result)

          # Update the original chat with the share_id
          result = (
              db.query(Chat)
              .filter_by(id=chat_id)
              .update({"share_id": shared_chat.id})
          )
          db.commit()

          return shared_chat if (shared_result and result) else None
  def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
      """Refresh the shared copy of chat ``chat_id`` from the original.

      Copies the original's current title and content onto the row referenced
      by its ``share_id`` and bumps that row's ``updated_at``.

      Returns:
          The refreshed shared copy, or ``None`` when the chat does not exist,
          is not shared, its shared copy is missing, or the update fails.
      """
      try:
          with get_db() as db:
              chat = db.get(Chat, chat_id)
              if chat is None or not chat.share_id:
                  return None

              shared_chat = db.get(Chat, chat.share_id)
              if shared_chat is None:
                  return None

              # Write onto the shared copy; assigning the original's fields to
              # themselves would leave the copy stale.
              shared_chat.title = chat.title
              shared_chat.chat = chat.chat
              shared_chat.updated_at = int(time.time())
              db.commit()
              db.refresh(shared_chat)
              return ChatModel.model_validate(shared_chat)
      except Exception:
          return None
  def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
      """Delete the shared copy of chat ``chat_id``.

      Shared copies are the rows whose ``user_id`` is ``f"shared-{chat_id}"``.
      The original chat row is not modified.

      Returns:
          ``True`` when the delete ran without error (even if no row matched),
          ``False`` when it failed.
      """
      try:
          with get_db() as db:
              db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
              db.commit()
              return True
      except Exception:
          # Exception rather than a bare except, so KeyboardInterrupt and
          # SystemExit still propagate.
          return False
  def update_chat_share_id_by_id(
      self, id: str, share_id: Optional[str]
  ) -> Optional[ChatModel]:
      """Set (or clear, with ``None``) the ``share_id`` of chat ``id``.

      Returns:
          The updated chat, or ``None`` when the chat does not exist or the
          update fails.
      """
      try:
          with get_db() as db:
              chat = db.get(Chat, id)
              if chat is None:
                  return None

              chat.share_id = share_id
              db.commit()
              db.refresh(chat)
              return ChatModel.model_validate(chat)
      except Exception:
          return None
  def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
      """Flip the ``archived`` flag of chat ``id``.

      Returns:
          The updated chat, or ``None`` when the chat does not exist or the
          update fails.
      """
      try:
          with get_db() as db:
              chat = db.get(Chat, id)
              if chat is None:
                  return None

              chat.archived = not chat.archived
              db.commit()
              db.refresh(chat)
              return ChatModel.model_validate(chat)
      except Exception:
          return None
  def archive_all_chats_by_user_id(self, user_id: str) -> bool:
      """Mark every chat owned by ``user_id`` as archived.

      Returns:
          ``True`` when the bulk update ran without error, ``False`` otherwise.
      """
      try:
          with get_db() as db:
              db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
              db.commit()
              return True
      except Exception:
          # Exception rather than a bare except, so KeyboardInterrupt and
          # SystemExit still propagate.
          return False
  def get_archived_chat_list_by_user_id(
      self, user_id: str, skip: int = 0, limit: int = 50
  ) -> List[ChatModel]:
      """Return the archived chats of ``user_id``, most recently updated first.

      ``skip`` and ``limit`` are accepted but not applied to the query.
      """
      with get_db() as db:
          archived_only = db.query(Chat).filter_by(user_id=user_id, archived=True)
          newest_first = archived_only.order_by(Chat.updated_at.desc())
          rows = newest_first.all()
          return list(map(ChatModel.model_validate, rows))
  def get_chat_list_by_user_id(
      self,
      user_id: str,
      include_archived: bool = False,
      skip: int = 0,
      limit: int = 50,
  ) -> List[ChatModel]:
      """Return the chats of ``user_id``, most recently updated first.

      Archived chats are left out unless ``include_archived`` is true.
      ``skip`` and ``limit`` are accepted but not applied to the query.
      """
      with get_db() as db:
          criteria = {"user_id": user_id}
          if not include_archived:
              criteria["archived"] = False

          rows = (
              db.query(Chat)
              .filter_by(**criteria)
              .order_by(Chat.updated_at.desc())
              .all()
          )
          return list(map(ChatModel.model_validate, rows))
  def get_chat_title_id_list_by_user_id(
      self,
      user_id: str,
      include_archived: bool = False,
      skip: int = 0,
      limit: int = 50,
  ) -> List[ChatTitleIdResponse]:
      """Return id/title/timestamps of the chats of ``user_id``, newest first.

      Only the four listed columns are selected, so the chat content is not
      loaded. Archived chats are left out unless ``include_archived`` is true.
      ``skip`` and ``limit`` are accepted but not applied to the query.
      """
      field_names = ("id", "title", "updated_at", "created_at")

      with get_db() as db:
          query = db.query(Chat).filter_by(user_id=user_id)
          if not include_archived:
              query = query.filter_by(archived=False)

          # Limit the selected columns.
          selected = (Chat.id, Chat.title, Chat.updated_at, Chat.created_at)
          rows = (
              query.order_by(Chat.updated_at.desc())
              .with_entities(*selected)
              .all()
          )

          # Each result is a SQLAlchemy row of plain column values, not a Chat
          # object, so it is destructured into a dict before validation.
          return [
              ChatTitleIdResponse.model_validate(dict(zip(field_names, row)))
              for row in rows
          ]
  def get_chat_list_by_chat_ids(
      self, chat_ids: List[str], skip: int = 0, limit: int = 50
  ) -> List[ChatModel]:
      """Return the non-archived chats whose id is in ``chat_ids``, newest first.

      ``skip`` and ``limit`` are accepted but not applied to the query.
      """
      with get_db() as db:
          matching = db.query(Chat).filter(Chat.id.in_(chat_ids))
          active = matching.filter_by(archived=False)
          rows = active.order_by(Chat.updated_at.desc()).all()
          return list(map(ChatModel.model_validate, rows))
  def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
      """Return the chat with primary key ``id``.

      Returns:
          The chat, or ``None`` when it does not exist or the lookup fails.
      """
      try:
          with get_db() as db:
              chat = db.get(Chat, id)
              if chat is None:
                  # Explicit "not found" instead of depending on
                  # model_validate(None) raising.
                  return None
              return ChatModel.model_validate(chat)
      except Exception:
          return None
  def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
      """Return the chat with id ``id`` if some chat references it as ``share_id``.

      The ``share_id`` lookup only confirms that the share link exists; the
      chat that is returned is the one whose primary key equals ``id``.
      Returns ``None`` when no chat carries that ``share_id`` or on error.
      """
      try:
          with get_db() as db:
              referencing_chat = db.query(Chat).filter_by(share_id=id).first()
              if not referencing_chat:
                  return None
              return self.get_chat_by_id(id)
      except Exception:
          return None
  def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
      """Return chat ``id`` only if it is owned by ``user_id``.

      Returns:
          The chat, or ``None`` when no row matches both values or the lookup
          fails.
      """
      try:
          with get_db() as db:
              chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
              if chat is None:
                  # Explicit "not found" instead of depending on
                  # model_validate(None) raising.
                  return None
              return ChatModel.model_validate(chat)
      except Exception:
          return None
  def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
      """Return every chat row, most recently updated first.

      ``skip`` and ``limit`` are accepted but not applied to the query.
      """
      with get_db() as db:
          newest_first = db.query(Chat).order_by(Chat.updated_at.desc())
          models = []
          for row in newest_first:
              models.append(ChatModel.model_validate(row))
          return models
  def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
      """Return all chats of ``user_id`` (archived or not), newest first."""
      with get_db() as db:
          owned = db.query(Chat).filter_by(user_id=user_id)
          newest_first = owned.order_by(Chat.updated_at.desc())
          models = []
          for row in newest_first:
              models.append(ChatModel.model_validate(row))
          return models
  def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
      """Return the archived chats of ``user_id``, most recently updated first."""
      with get_db() as db:
          archived_only = db.query(Chat).filter_by(user_id=user_id, archived=True)
          newest_first = archived_only.order_by(Chat.updated_at.desc())
          models = []
          for row in newest_first:
              models.append(ChatModel.model_validate(row))
          return models
  def delete_chat_by_id(self, id: str) -> bool:
      """Delete chat ``id`` and then its shared copy, if any.

      Returns:
          The result of deleting the shared copy when the chat delete ran
          without error; ``False`` when the chat delete failed.
      """
      try:
          with get_db() as db:
              db.query(Chat).filter_by(id=id).delete()
              db.commit()
              return self.delete_shared_chat_by_chat_id(id)
      except Exception:
          # Exception rather than a bare except, so KeyboardInterrupt and
          # SystemExit still propagate.
          return False
  def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
      """Delete chat ``id`` if it is owned by ``user_id``, plus its shared copy.

      The shared copy is removed only when the owner-scoped delete actually
      matched a row. Otherwise a caller who does not own the chat could still
      remove its shared copy.

      Returns:
          ``True`` when nothing matched (nothing to delete), otherwise the
          result of deleting the shared copy. ``False`` when the delete failed.
      """
      try:
          with get_db() as db:
              deleted_rows = db.query(Chat).filter_by(id=id, user_id=user_id).delete()
              db.commit()
              if not deleted_rows:
                  # Not this user's chat (or no such chat): leave the shared
                  # copy untouched.
                  return True
              return self.delete_shared_chat_by_chat_id(id)
      except Exception:
          return False
  def delete_chats_by_user_id(self, user_id: str) -> bool:
      """Delete every chat owned by ``user_id`` together with their shared copies.

      Shared copies are removed first because they are located through the ids
      of the user's chats. If that step fails, the user's chats are left in
      place so the shared copies can still be found on a retry.

      Returns:
          ``True`` when both deletions ran without error, ``False`` otherwise.
      """
      try:
          if not self.delete_shared_chats_by_user_id(user_id):
              return False

          with get_db() as db:
              db.query(Chat).filter_by(user_id=user_id).delete()
              db.commit()
              return True
      except Exception:
          return False
  def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
      """Delete the shared copies of every chat owned by ``user_id``.

      Shared copies are the rows whose ``user_id`` is ``f"shared-{chat.id}"``
      for one of the user's chats. The user's own chats are not deleted.

      Returns:
          ``True`` when the deletion ran without error (including when the
          user has no chats), ``False`` otherwise.
      """
      try:
          with get_db() as db:
              # Only the ids are needed, so avoid loading the chat JSON.
              owned_ids = (
                  db.query(Chat)
                  .filter_by(user_id=user_id)
                  .with_entities(Chat.id)
                  .all()
              )
              if not owned_ids:
                  return True

              shared_user_ids = [f"shared-{row[0]}" for row in owned_ids]
              # No ORM objects are held in this session, so skip session
              # synchronisation for the bulk delete.
              db.query(Chat).filter(Chat.user_id.in_(shared_user_ids)).delete(
                  synchronize_session=False
              )
              db.commit()
              return True
      except Exception:
          return False
  # Module-level ChatTable instance; ChatTable keeps no per-instance state.
  Chats = ChatTable()