chats.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from pydantic import BaseModel, ConfigDict
  2. from typing import List, Union, Optional
  3. import json
  4. import uuid
  5. import time
  6. from sqlalchemy import Column, String, BigInteger, Boolean, Text
  7. from apps.webui.internal.db import Base, Session
  8. ####################
  9. # Chat DB Schema
  10. ####################
  11. class Chat(Base):
  12. __tablename__ = "chat"
  13. id = Column(String, primary_key=True)
  14. user_id = Column(String)
  15. title = Column(Text)
  16. chat = Column(Text) # Save Chat JSON as Text
  17. created_at = Column(BigInteger)
  18. updated_at = Column(BigInteger)
  19. share_id = Column(Text, unique=True, nullable=True)
  20. archived = Column(Boolean, default=False)
  21. class ChatModel(BaseModel):
  22. model_config = ConfigDict(from_attributes=True)
  23. id: str
  24. user_id: str
  25. title: str
  26. chat: str
  27. created_at: int # timestamp in epoch
  28. updated_at: int # timestamp in epoch
  29. share_id: Optional[str] = None
  30. archived: bool = False
  31. ####################
  32. # Forms
  33. ####################
  34. class ChatForm(BaseModel):
  35. chat: dict
  36. class ChatTitleForm(BaseModel):
  37. title: str
  38. class ChatResponse(BaseModel):
  39. id: str
  40. user_id: str
  41. title: str
  42. chat: dict
  43. updated_at: int # timestamp in epoch
  44. created_at: int # timestamp in epoch
  45. share_id: Optional[str] = None # id of the chat to be shared
  46. archived: bool
  47. class ChatTitleIdResponse(BaseModel):
  48. id: str
  49. title: str
  50. updated_at: int
  51. created_at: int
  52. class ChatTable:
  53. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  54. id = str(uuid.uuid4())
  55. chat = ChatModel(
  56. **{
  57. "id": id,
  58. "user_id": user_id,
  59. "title": (
  60. form_data.chat["title"] if "title" in form_data.chat else "New Chat"
  61. ),
  62. "chat": json.dumps(form_data.chat),
  63. "created_at": int(time.time()),
  64. "updated_at": int(time.time()),
  65. }
  66. )
  67. result = Chat(**chat.model_dump())
  68. Session.add(result)
  69. Session.commit()
  70. Session.refresh(result)
  71. return ChatModel.model_validate(result) if result else None
  72. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  73. try:
  74. chat_obj = Session.get(Chat, id)
  75. chat_obj.chat = json.dumps(chat)
  76. chat_obj.title = chat["title"] if "title" in chat else "New Chat"
  77. chat_obj.updated_at = int(time.time())
  78. Session.commit()
  79. Session.refresh(chat_obj)
  80. return ChatModel.model_validate(chat_obj)
  81. except Exception as e:
  82. return None
  83. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  84. # Get the existing chat to share
  85. chat = Session.get(Chat, chat_id)
  86. # Check if the chat is already shared
  87. if chat.share_id:
  88. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  89. # Create a new chat with the same data, but with a new ID
  90. shared_chat = ChatModel(
  91. **{
  92. "id": str(uuid.uuid4()),
  93. "user_id": f"shared-{chat_id}",
  94. "title": chat.title,
  95. "chat": chat.chat,
  96. "created_at": chat.created_at,
  97. "updated_at": int(time.time()),
  98. }
  99. )
  100. shared_result = Chat(**shared_chat.model_dump())
  101. Session.add(shared_result)
  102. Session.commit()
  103. Session.refresh(shared_result)
  104. # Update the original chat with the share_id
  105. result = (
  106. Session.query(Chat)
  107. .filter_by(id=chat_id)
  108. .update({"share_id": shared_chat.id})
  109. )
  110. return shared_chat if (shared_result and result) else None
  111. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  112. try:
  113. print("update_shared_chat_by_id")
  114. chat = Session.get(Chat, chat_id)
  115. print(chat)
  116. chat.title = chat.title
  117. chat.chat = chat.chat
  118. Session.commit()
  119. Session.refresh(chat)
  120. return self.get_chat_by_id(chat.share_id)
  121. except:
  122. return None
  123. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  124. try:
  125. Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
  126. return True
  127. except:
  128. return False
  129. def update_chat_share_id_by_id(
  130. self, id: str, share_id: Optional[str]
  131. ) -> Optional[ChatModel]:
  132. try:
  133. chat = Session.get(Chat, id)
  134. chat.share_id = share_id
  135. Session.commit()
  136. Session.refresh(chat)
  137. return ChatModel.model_validate(chat)
  138. except:
  139. return None
  140. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  141. try:
  142. chat = Session.get(Chat, id)
  143. chat.archived = not chat.archived
  144. Session.commit()
  145. Session.refresh(chat)
  146. return ChatModel.model_validate(chat)
  147. except:
  148. return None
  149. def archive_all_chats_by_user_id(self, user_id: str) -> bool:
  150. try:
  151. Session.query(Chat).filter_by(user_id=user_id).update({"archived": True})
  152. return True
  153. except:
  154. return False
  155. def get_archived_chat_list_by_user_id(
  156. self, user_id: str, skip: int = 0, limit: int = 50
  157. ) -> List[ChatModel]:
  158. all_chats = (
  159. Session.query(Chat)
  160. .filter_by(user_id=user_id, archived=True)
  161. .order_by(Chat.updated_at.desc())
  162. # .limit(limit).offset(skip)
  163. .all()
  164. )
  165. return [ChatModel.model_validate(chat) for chat in all_chats]
  166. def get_chat_list_by_user_id(
  167. self,
  168. user_id: str,
  169. include_archived: bool = False,
  170. skip: int = 0,
  171. limit: int = 50,
  172. ) -> List[ChatModel]:
  173. query = Session.query(Chat).filter_by(user_id=user_id)
  174. if not include_archived:
  175. query = query.filter_by(archived=False)
  176. all_chats = (
  177. query.order_by(Chat.updated_at.desc())
  178. # .limit(limit).offset(skip)
  179. .all()
  180. )
  181. return [ChatModel.model_validate(chat) for chat in all_chats]
  182. def get_chat_list_by_chat_ids(
  183. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  184. ) -> List[ChatModel]:
  185. all_chats = (
  186. Session.query(Chat)
  187. .filter(Chat.id.in_(chat_ids))
  188. .filter_by(archived=False)
  189. .order_by(Chat.updated_at.desc())
  190. .all()
  191. )
  192. return [ChatModel.model_validate(chat) for chat in all_chats]
  193. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  194. try:
  195. chat = Session.get(Chat, id)
  196. return ChatModel.model_validate(chat)
  197. except:
  198. return None
  199. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  200. try:
  201. chat = Session.query(Chat).filter_by(share_id=id).first()
  202. if chat:
  203. return self.get_chat_by_id(id)
  204. else:
  205. return None
  206. except Exception as e:
  207. return None
  208. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  209. try:
  210. chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first()
  211. return ChatModel.model_validate(chat)
  212. except:
  213. return None
  214. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  215. all_chats = (
  216. Session.query(Chat)
  217. # .limit(limit).offset(skip)
  218. .order_by(Chat.updated_at.desc())
  219. )
  220. return [ChatModel.model_validate(chat) for chat in all_chats]
  221. def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  222. all_chats = (
  223. Session.query(Chat)
  224. .filter_by(user_id=user_id)
  225. .order_by(Chat.updated_at.desc())
  226. )
  227. return [ChatModel.model_validate(chat) for chat in all_chats]
  228. def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  229. all_chats = (
  230. Session.query(Chat)
  231. .filter_by(user_id=user_id, archived=True)
  232. .order_by(Chat.updated_at.desc())
  233. )
  234. return [ChatModel.model_validate(chat) for chat in all_chats]
  235. def delete_chat_by_id(self, id: str) -> bool:
  236. try:
  237. Session.query(Chat).filter_by(id=id).delete()
  238. return True and self.delete_shared_chat_by_chat_id(id)
  239. except:
  240. return False
  241. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  242. try:
  243. Session.query(Chat).filter_by(id=id, user_id=user_id).delete()
  244. return True and self.delete_shared_chat_by_chat_id(id)
  245. except:
  246. return False
  247. def delete_chats_by_user_id(self, user_id: str) -> bool:
  248. try:
  249. self.delete_shared_chats_by_user_id(user_id)
  250. Session.query(Chat).filter_by(user_id=user_id).delete()
  251. return True
  252. except:
  253. return False
  254. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  255. try:
  256. chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
  257. shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
  258. Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
  259. return True
  260. except:
  261. return False
  262. Chats = ChatTable()