chats.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from pydantic import BaseModel, ConfigDict
  2. from typing import List, Union, Optional
  3. import json
  4. import uuid
  5. import time
  6. from sqlalchemy import Column, String, BigInteger, Boolean
  7. from sqlalchemy.orm import Session
  8. from apps.webui.internal.db import Base, get_session
  9. ####################
  10. # Chat DB Schema
  11. ####################
  class Chat(Base):
      """SQLAlchemy mapping for the ``chat`` table (one row per conversation)."""

      __tablename__ = "chat"

      id = Column(String(), primary_key=True)
      user_id = Column(String())
      title = Column(String())
      # Full conversation payload, stored as JSON text (written via json.dumps).
      chat = Column(String())
      created_at = Column(BigInteger())  # Unix epoch, whole seconds
      updated_at = Column(BigInteger())  # Unix epoch, whole seconds
      # Id of the row holding this chat's shared copy; NULL when not shared.
      share_id = Column(String(), unique=True, nullable=True)
      archived = Column(Boolean(), default=False)
  class ChatModel(BaseModel):
      # Pydantic counterpart of a `chat` row. from_attributes=True lets
      # ChatModel.model_validate() read ORM objects (e.g. Chat rows) as well
      # as plain dicts.
      model_config = ConfigDict(from_attributes=True)

      id: str
      user_id: str
      title: str
      chat: str  # conversation payload as JSON text
      created_at: int  # Unix epoch, whole seconds
      updated_at: int  # Unix epoch, whole seconds
      share_id: Union[str, None] = None  # id of this chat's shared copy, if any
      archived: bool = False
  32. ####################
  33. # Forms
  34. ####################
  class ChatForm(BaseModel):
      # Request body carrying a whole chat payload as a dict; an optional
      # "title" key in it is used as the chat title when stored.
      chat: dict
  class ChatTitleForm(BaseModel):
      # Request body carrying only a chat title.
      title: str
  class ChatResponse(BaseModel):
      # Chat with its payload decoded to a dict (ChatModel keeps it as JSON
      # text). Field order is kept stable because it drives serialization order.
      id: str
      user_id: str
      title: str
      chat: dict
      updated_at: int  # Unix epoch, whole seconds
      created_at: int  # Unix epoch, whole seconds
      # Id of the shared (public) copy of this chat; None when it is not shared.
      share_id: Union[str, None] = None
      archived: bool
  class ChatTitleIdResponse(BaseModel):
      # Lightweight chat summary: identity, title and timestamps, no payload.
      id: str
      title: str
      updated_at: int  # Unix epoch, whole seconds
      created_at: int  # Unix epoch, whole seconds
  class ChatTable:
      """Data-access helpers for the ``chat`` table.

      Each public method opens its own session via ``get_session()`` and
      returns pydantic ``ChatModel`` objects rather than live ORM rows.
      Writes are committed explicitly inside each method instead of relying
      on ``get_session()`` to commit on exit (its behavior is not visible
      here). Extra commits are harmless if it does commit.

      Sharing a chat stores a copy of it as a separate row whose ``user_id``
      is ``"shared-<original chat id>"``. The original row's ``share_id``
      then holds the id of that copy.
      """

      @staticmethod
      def _now() -> int:
          """Return the current time as whole Unix epoch seconds."""
          return int(time.time())

      @staticmethod
      def _shared_owner(chat_id: str) -> str:
          """Return the synthetic ``user_id`` that owns the shared copy of *chat_id*."""
          return f"shared-{chat_id}"

      @staticmethod
      def _log_failure(action: str) -> None:
          """Log the exception currently being handled, tagged with *action*.

          The best-effort methods below still return ``None``/``False`` on
          failure. This keeps the cause visible instead of silently dropping it.
          """
          import logging

          logging.getLogger(__name__).exception("ChatTable.%s failed", action)

      def insert_new_chat(
          self, user_id: str, form_data: ChatForm
      ) -> Optional[ChatModel]:
          """Create and persist a new chat owned by *user_id*.

          The title comes from ``form_data.chat["title"]`` when present and
          otherwise defaults to ``"New Chat"``. The whole payload is stored
          JSON-encoded. Returns the stored chat.
          """
          now = self._now()  # one timestamp so created_at == updated_at
          chat = ChatModel(
              id=str(uuid.uuid4()),
              user_id=user_id,
              title=form_data.chat.get("title", "New Chat"),
              chat=json.dumps(form_data.chat),
              created_at=now,
              updated_at=now,
          )
          with get_session() as db:
              row = Chat(**chat.model_dump())
              db.add(row)
              db.commit()
              db.refresh(row)
              return ChatModel.model_validate(row)

      def update_chat_by_id(
          self, id: str, chat: dict
      ) -> Optional[ChatModel]:
          """Replace the payload and title of chat *id*.

          Returns the updated chat, or ``None`` if the chat does not exist or
          the update fails.
          """
          try:
              with get_session() as db:
                  row = db.get(Chat, id)
                  if row is None:
                      return None
                  row.chat = json.dumps(chat)
                  row.title = chat.get("title", "New Chat")
                  row.updated_at = self._now()
                  db.commit()
                  db.refresh(row)
                  return ChatModel.model_validate(row)
          except Exception:
              self._log_failure("update_chat_by_id")
              return None

      def insert_shared_chat_by_chat_id(
          self, chat_id: str
      ) -> Optional[ChatModel]:
          """Share chat *chat_id* by storing a copy of it under a synthetic owner.

          If the chat is already shared, its existing shared copy is returned
          instead of creating another. Returns ``None`` if the original chat
          does not exist or the ``share_id`` link could not be written.
          """
          with get_session() as db:
              original = db.get(Chat, chat_id)
              if original is None:
                  return None

              if original.share_id:
                  # The copy was created under "shared-<chat_id>", so it must be
                  # looked up under that owner (not a literal "shared").
                  return self.get_chat_by_id_and_user_id(
                      original.share_id, self._shared_owner(chat_id)
                  )

              shared_chat = ChatModel(
                  id=str(uuid.uuid4()),
                  user_id=self._shared_owner(chat_id),
                  title=original.title,
                  chat=original.chat,
                  created_at=original.created_at,
                  updated_at=self._now(),
              )
              shared_row = Chat(**shared_chat.model_dump())
              db.add(shared_row)
              db.commit()
              db.refresh(shared_row)

              # Point the original chat at its new shared copy and persist it.
              updated = (
                  db.query(Chat)
                  .filter_by(id=chat_id)
                  .update({"share_id": shared_chat.id})
              )
              db.commit()
              return shared_chat if updated else None

      def update_shared_chat_by_chat_id(
          self, chat_id: str
      ) -> Optional[ChatModel]:
          """Refresh the shared copy of chat *chat_id* from the original.

          Copies the original's current title and payload onto the row named
          by its ``share_id`` and bumps that row's ``updated_at``. Returns the
          updated shared copy, or ``None`` if the chat does not exist, is not
          shared, or the update fails.
          """
          try:
              with get_session() as db:
                  original = db.get(Chat, chat_id)
                  if original is None or not original.share_id:
                      return None
                  shared = db.get(Chat, original.share_id)
                  if shared is None:
                      return None
                  # Write onto the *shared* row so the public copy reflects
                  # edits made to the original since it was first shared.
                  shared.title = original.title
                  shared.chat = original.chat
                  shared.updated_at = self._now()
                  db.commit()
                  db.refresh(shared)
                  return ChatModel.model_validate(shared)
          except Exception:
              self._log_failure("update_shared_chat_by_chat_id")
              return None

      def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
          """Delete the shared copy of chat *chat_id*, if there is one.

          Returns ``True`` on success (including when nothing matched) and
          ``False`` if the delete fails.
          """
          try:
              with get_session() as db:
                  db.query(Chat).filter_by(
                      user_id=self._shared_owner(chat_id)
                  ).delete()
                  db.commit()
              return True
          except Exception:
              self._log_failure("delete_shared_chat_by_chat_id")
              return False

      def update_chat_share_id_by_id(
          self, id: str, share_id: Optional[str]
      ) -> Optional[ChatModel]:
          """Set the ``share_id`` of chat *id*, or clear it when *share_id* is ``None``.

          Returns the updated chat, or ``None`` if it does not exist or the
          update fails.
          """
          try:
              with get_session() as db:
                  row = db.get(Chat, id)
                  if row is None:
                      return None
                  row.share_id = share_id
                  db.commit()
                  db.refresh(row)
                  # Return a pydantic copy, as annotated, rather than the ORM
                  # row, whose lifetime is tied to this session.
                  return ChatModel.model_validate(row)
          except Exception:
              self._log_failure("update_chat_share_id_by_id")
              return None

      def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
          """Flip the ``archived`` flag of chat *id* and return the updated chat.

          Returns ``None`` if the chat does not exist or the update fails.
          """
          try:
              with get_session() as db:
                  row = db.get(Chat, id)
                  if row is None:
                      return None
                  # Read, write and commit through one session so the value
                  # returned is the one that was actually stored.
                  row.archived = not row.archived
                  db.commit()
                  db.refresh(row)
                  return ChatModel.model_validate(row)
          except Exception:
              self._log_failure("toggle_chat_archive_by_id")
              return None

      def archive_all_chats_by_user_id(self, user_id: str) -> bool:
          """Mark every chat owned by *user_id* as archived; ``False`` on failure."""
          try:
              with get_session() as db:
                  db.query(Chat).filter_by(user_id=user_id).update(
                      {"archived": True}
                  )
                  db.commit()
              return True
          except Exception:
              self._log_failure("archive_all_chats_by_user_id")
              return False

      def get_archived_chat_list_by_user_id(
          self, user_id: str, skip: int = 0, limit: int = 50
      ) -> List[ChatModel]:
          """Return *user_id*'s archived chats, most recently updated first.

          NOTE: ``skip`` and ``limit`` are accepted but not applied, so all
          matching chats are returned.
          """
          with get_session() as db:
              rows = (
                  db.query(Chat)
                  .filter_by(user_id=user_id, archived=True)
                  .order_by(Chat.updated_at.desc())
                  .all()
              )
              return [ChatModel.model_validate(row) for row in rows]

      def get_chat_list_by_user_id(
          self,
          user_id: str,
          include_archived: bool = False,
          skip: int = 0,
          limit: int = 50,
      ) -> List[ChatModel]:
          """Return *user_id*'s chats, most recently updated first.

          Archived chats are excluded unless *include_archived* is true.
          NOTE: ``skip`` and ``limit`` are accepted but not applied.
          """
          with get_session() as db:
              query = db.query(Chat).filter_by(user_id=user_id)
              if not include_archived:
                  query = query.filter_by(archived=False)
              rows = query.order_by(Chat.updated_at.desc()).all()
              return [ChatModel.model_validate(row) for row in rows]

      def get_chat_list_by_chat_ids(
          self, chat_ids: List[str], skip: int = 0, limit: int = 50
      ) -> List[ChatModel]:
          """Return the non-archived chats among *chat_ids*, most recent first.

          NOTE: ``skip`` and ``limit`` are accepted but not applied.
          """
          if not chat_ids:
              return []  # an empty IN (...) can match nothing; skip the query
          with get_session() as db:
              rows = (
                  db.query(Chat)
                  .filter(Chat.id.in_(chat_ids))
                  .filter_by(archived=False)
                  .order_by(Chat.updated_at.desc())
                  .all()
              )
              return [ChatModel.model_validate(row) for row in rows]

      def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
          """Return chat *id*, or ``None`` if it does not exist or cannot be read."""
          try:
              with get_session() as db:
                  row = db.get(Chat, id)
                  return ChatModel.model_validate(row) if row is not None else None
          except Exception:
              self._log_failure("get_chat_by_id")
              return None

      def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
          """Return the shared copy whose id is *id*, while it is still shared.

          A shared link may have been revoked, so the copy is only returned
          if some chat still references it via ``share_id``. Otherwise the
          result is ``None``.
          """
          try:
              with get_session() as db:
                  still_shared = (
                      db.query(Chat).filter_by(share_id=id).first() is not None
                  )
              return self.get_chat_by_id(id) if still_shared else None
          except Exception:
              self._log_failure("get_chat_by_share_id")
              return None

      def get_chat_by_id_and_user_id(
          self, id: str, user_id: str
      ) -> Optional[ChatModel]:
          """Return chat *id* if it is owned by *user_id*, else ``None``."""
          try:
              with get_session() as db:
                  row = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                  return ChatModel.model_validate(row) if row is not None else None
          except Exception:
              self._log_failure("get_chat_by_id_and_user_id")
              return None

      def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
          """Return every chat, most recently updated first.

          NOTE: ``skip`` and ``limit`` are accepted but not applied.
          """
          with get_session() as db:
              rows = db.query(Chat).order_by(Chat.updated_at.desc()).all()
              return [ChatModel.model_validate(row) for row in rows]

      def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
          """Return all of *user_id*'s chats (archived included), newest first."""
          with get_session() as db:
              rows = (
                  db.query(Chat)
                  .filter_by(user_id=user_id)
                  .order_by(Chat.updated_at.desc())
                  .all()
              )
              return [ChatModel.model_validate(row) for row in rows]

      def get_archived_chats_by_user_id(
          self, user_id: str
      ) -> List[ChatModel]:
          """Return *user_id*'s archived chats, most recently updated first."""
          with get_session() as db:
              rows = (
                  db.query(Chat)
                  .filter_by(user_id=user_id, archived=True)
                  .order_by(Chat.updated_at.desc())
                  .all()
              )
              return [ChatModel.model_validate(row) for row in rows]

      def delete_chat_by_id(self, id: str) -> bool:
          """Delete chat *id* and then its shared copy, if any.

          Returns ``True`` only if both deletes succeed.
          """
          try:
              with get_session() as db:
                  db.query(Chat).filter_by(id=id).delete()
                  db.commit()
              # Called after this session is finished so the helper's own
              # session does not contend with uncommitted work here.
              return self.delete_shared_chat_by_chat_id(id)
          except Exception:
              self._log_failure("delete_chat_by_id")
              return False

      def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
          """Delete chat *id* if it is owned by *user_id*, plus its shared copy.

          The shared copy is removed only when the ownership-checked delete
          matched a row, so a caller cannot delete the share of a chat it does
          not own. Returns ``True`` when the delete succeeds or nothing
          matched, and ``False`` on failure.
          """
          try:
              with get_session() as db:
                  deleted = (
                      db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                  )
                  db.commit()
              if not deleted:
                  return True
              return self.delete_shared_chat_by_chat_id(id)
          except Exception:
              self._log_failure("delete_chat_by_id_and_user_id")
              return False

      def delete_chats_by_user_id(self, user_id: str) -> bool:
          """Delete every chat owned by *user_id*, after removing their shared copies."""
          try:
              # Shared copies are found through the user's chats, so they must
              # go first. The helper's return value is ignored, so a failure
              # there does not stop the main delete.
              self.delete_shared_chats_by_user_id(user_id)
              with get_session() as db:
                  db.query(Chat).filter_by(user_id=user_id).delete()
                  db.commit()
              return True
          except Exception:
              self._log_failure("delete_chats_by_user_id")
              return False

      def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
          """Delete the shared copies of every chat owned by *user_id*."""
          try:
              with get_session() as db:
                  # Only the ids are needed; avoid loading every chat payload.
                  owned_ids = [
                      chat_id
                      for (chat_id,) in db.query(Chat.id).filter(
                          Chat.user_id == user_id
                      )
                  ]
                  if owned_ids:
                      shared_owners = [self._shared_owner(cid) for cid in owned_ids]
                      db.query(Chat).filter(
                          Chat.user_id.in_(shared_owners)
                      ).delete()
                      db.commit()
              return True
          except Exception:
              self._log_failure("delete_shared_chats_by_user_id")
              return False
  # Module-level ChatTable instance, exposed under the name `Chats`.
  Chats = ChatTable()