# chats.py
import json
import logging
import time
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String

from apps.webui.internal.db import Base, Session
####################
# Chat DB Schema
####################


class Chat(Base):
    """SQLAlchemy ORM mapping for the ``chat`` table (one row per chat)."""

    __tablename__ = "chat"

    id = Column(String, primary_key=True)  # uuid4 string assigned by ChatTable
    # Owner id; shared copies of a chat use "shared-<original chat id>".
    user_id = Column(String)
    title = Column(String)
    chat = Column(String)  # Save Chat JSON as Text (json.dumps of the chat dict)
    created_at = Column(BigInteger)  # epoch seconds
    updated_at = Column(BigInteger)  # epoch seconds
    # Id of the shared copy of this chat; NULL when the chat is not shared.
    share_id = Column(String, unique=True, nullable=True)
    archived = Column(Boolean, default=False)
class ChatModel(BaseModel):
    """Pydantic view of a ``Chat`` row; ``chat`` is still the serialized JSON string."""

    # Lets ChatModel.model_validate() accept a Chat ORM object directly.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str  # JSON-encoded chat dict
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of the shared copy, if any
    archived: bool = False
####################
# Forms
####################


class ChatForm(BaseModel):
    """Input carrying a chat as a plain dict.

    ``ChatTable.insert_new_chat`` takes the chat title from the dict's optional
    ``"title"`` key and stores the whole dict serialized as JSON.
    """

    chat: dict
class ChatTitleForm(BaseModel):
    """Input holding only a chat title."""

    title: str
class ChatResponse(BaseModel):
    """Chat payload; unlike ``ChatModel.chat`` (a JSON string), ``chat`` is a dict."""

    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of this chat's shared copy, if it is shared
    archived: bool
class ChatTitleIdResponse(BaseModel):
    """Lightweight chat listing entry: id, title and epoch timestamps."""

    id: str
    title: str
    updated_at: int
    created_at: int
  52. class ChatTable:
  53. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  54. id = str(uuid.uuid4())
  55. chat = ChatModel(
  56. **{
  57. "id": id,
  58. "user_id": user_id,
  59. "title": (
  60. form_data.chat["title"]
  61. if "title" in form_data.chat
  62. else "New Chat"
  63. ),
  64. "chat": json.dumps(form_data.chat),
  65. "created_at": int(time.time()),
  66. "updated_at": int(time.time()),
  67. }
  68. )
  69. result = Chat(**chat.model_dump())
  70. Session.add(result)
  71. Session.commit()
  72. Session.refresh(result)
  73. return ChatModel.model_validate(result) if result else None
  74. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  75. try:
  76. chat_obj = Session.get(Chat, id)
  77. chat_obj.chat = json.dumps(chat)
  78. chat_obj.title = chat["title"] if "title" in chat else "New Chat"
  79. chat_obj.updated_at = int(time.time())
  80. Session.commit()
  81. Session.refresh(chat_obj)
  82. return ChatModel.model_validate(chat_obj)
  83. except Exception as e:
  84. return None
  85. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  86. # Get the existing chat to share
  87. chat = Session.get(Chat, chat_id)
  88. # Check if the chat is already shared
  89. if chat.share_id:
  90. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  91. # Create a new chat with the same data, but with a new ID
  92. shared_chat = ChatModel(
  93. **{
  94. "id": str(uuid.uuid4()),
  95. "user_id": f"shared-{chat_id}",
  96. "title": chat.title,
  97. "chat": chat.chat,
  98. "created_at": chat.created_at,
  99. "updated_at": int(time.time()),
  100. }
  101. )
  102. shared_result = Chat(**shared_chat.model_dump())
  103. Session.add(shared_result)
  104. Session.commit()
  105. Session.refresh(shared_result)
  106. # Update the original chat with the share_id
  107. result = (
  108. Session.query(Chat)
  109. .filter_by(id=chat_id)
  110. .update({"share_id": shared_chat.id})
  111. )
  112. return shared_chat if (shared_result and result) else None
  113. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  114. try:
  115. print("update_shared_chat_by_id")
  116. chat = Session.get(Chat, chat_id)
  117. print(chat)
  118. chat.title = chat.title
  119. chat.chat = chat.chat
  120. Session.commit()
  121. Session.refresh(chat)
  122. return self.get_chat_by_id(chat.share_id)
  123. except:
  124. return None
  125. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  126. try:
  127. Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
  128. return True
  129. except:
  130. return False
  131. def update_chat_share_id_by_id(
  132. self, id: str, share_id: Optional[str]
  133. ) -> Optional[ChatModel]:
  134. try:
  135. chat = Session.get(Chat, id)
  136. chat.share_id = share_id
  137. Session.commit()
  138. Session.refresh(chat)
  139. return ChatModel.model_validate(chat)
  140. except:
  141. return None
  142. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  143. try:
  144. chat = Session.get(Chat, id)
  145. chat.archived = not chat.archived
  146. Session.commit()
  147. Session.refresh(chat)
  148. return ChatModel.model_validate(chat)
  149. except:
  150. return None
  151. def archive_all_chats_by_user_id(self, user_id: str) -> bool:
  152. try:
  153. Session.query(Chat).filter_by(user_id=user_id).update({"archived": True})
  154. return True
  155. except:
  156. return False
  157. def get_archived_chat_list_by_user_id(
  158. self, user_id: str, skip: int = 0, limit: int = 50
  159. ) -> List[ChatModel]:
  160. all_chats = (
  161. Session.query(Chat)
  162. .filter_by(user_id=user_id, archived=True)
  163. .order_by(Chat.updated_at.desc())
  164. # .limit(limit).offset(skip)
  165. .all()
  166. )
  167. return [ChatModel.model_validate(chat) for chat in all_chats]
  168. def get_chat_list_by_user_id(
  169. self,
  170. user_id: str,
  171. include_archived: bool = False,
  172. skip: int = 0,
  173. limit: int = 50,
  174. ) -> List[ChatModel]:
  175. query = Session.query(Chat).filter_by(user_id=user_id)
  176. if not include_archived:
  177. query = query.filter_by(archived=False)
  178. all_chats = (
  179. query.order_by(Chat.updated_at.desc())
  180. # .limit(limit).offset(skip)
  181. .all()
  182. )
  183. return [ChatModel.model_validate(chat) for chat in all_chats]
  184. def get_chat_list_by_chat_ids(
  185. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  186. ) -> List[ChatModel]:
  187. all_chats = (
  188. Session.query(Chat)
  189. .filter(Chat.id.in_(chat_ids))
  190. .filter_by(archived=False)
  191. .order_by(Chat.updated_at.desc())
  192. .all()
  193. )
  194. return [ChatModel.model_validate(chat) for chat in all_chats]
  195. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  196. try:
  197. chat = Session.get(Chat, id)
  198. return ChatModel.model_validate(chat)
  199. except:
  200. return None
  201. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  202. try:
  203. chat = Session.query(Chat).filter_by(share_id=id).first()
  204. if chat:
  205. return self.get_chat_by_id(id)
  206. else:
  207. return None
  208. except Exception as e:
  209. return None
  210. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  211. try:
  212. chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first()
  213. return ChatModel.model_validate(chat)
  214. except:
  215. return None
  216. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  217. all_chats = (
  218. Session.query(Chat)
  219. # .limit(limit).offset(skip)
  220. .order_by(Chat.updated_at.desc())
  221. )
  222. return [ChatModel.model_validate(chat) for chat in all_chats]
  223. def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  224. all_chats = (
  225. Session.query(Chat)
  226. .filter_by(user_id=user_id)
  227. .order_by(Chat.updated_at.desc())
  228. )
  229. return [ChatModel.model_validate(chat) for chat in all_chats]
  230. def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  231. all_chats = (
  232. Session.query(Chat)
  233. .filter_by(user_id=user_id, archived=True)
  234. .order_by(Chat.updated_at.desc())
  235. )
  236. return [ChatModel.model_validate(chat) for chat in all_chats]
  237. def delete_chat_by_id(self, id: str) -> bool:
  238. try:
  239. Session.query(Chat).filter_by(id=id).delete()
  240. return True and self.delete_shared_chat_by_chat_id(id)
  241. except:
  242. return False
  243. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  244. try:
  245. Session.query(Chat).filter_by(id=id, user_id=user_id).delete()
  246. return True and self.delete_shared_chat_by_chat_id(id)
  247. except:
  248. return False
  249. def delete_chats_by_user_id(self, user_id: str) -> bool:
  250. try:
  251. self.delete_shared_chats_by_user_id(user_id)
  252. Session.query(Chat).filter_by(user_id=user_id).delete()
  253. return True
  254. except:
  255. return False
  256. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  257. try:
  258. chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
  259. shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
  260. Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
  261. return True
  262. except:
  263. return False
# Module-level ChatTable instance; ChatTable keeps no per-instance state,
# so a single shared instance is sufficient.
Chats = ChatTable()