chats.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. from pydantic import BaseModel, ConfigDict
  2. from typing import List, Union, Optional
  3. import json
  4. import uuid
  5. import time
  6. from sqlalchemy import Column, String, BigInteger, Boolean, Text
  7. from apps.webui.internal.db import Base, get_db
  8. ####################
  9. # Chat DB Schema
  10. ####################
class Chat(Base):
    """SQLAlchemy ORM model for the ``chat`` table (one row per chat)."""

    __tablename__ = "chat"

    id = Column(String, primary_key=True)
    # Owning user id. Shared copies created by ChatTable are owned by the
    # pseudo user id "shared-<original chat id>".
    user_id = Column(String)
    title = Column(Text)
    chat = Column(Text)  # Full chat payload, stored as JSON-encoded text
    created_at = Column(BigInteger)  # Unix epoch seconds (int(time.time()))
    updated_at = Column(BigInteger)  # Unix epoch seconds (int(time.time()))
    # On an original chat: id of its shared copy, or NULL when not shared.
    share_id = Column(Text, unique=True, nullable=True)
    archived = Column(Boolean, default=False)
class ChatModel(BaseModel):
    """Pydantic mirror of a ``Chat`` row.

    ``from_attributes=True`` lets ``ChatModel.model_validate(orm_row)`` build
    the model directly from a SQLAlchemy ``Chat`` instance.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str  # JSON-encoded chat payload, exactly as stored in the table
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of this chat's shared copy, if any
    archived: bool = False
  31. ####################
  32. # Forms
  33. ####################
class ChatForm(BaseModel):
    """Request body carrying a chat payload.

    ``chat`` is stored JSON-encoded; its optional ``"title"`` key becomes the
    chat title.
    """

    chat: dict
class ChatTitleForm(BaseModel):
    """Request body carrying only a chat title."""

    title: str
class ChatResponse(BaseModel):
    """Chat as returned to clients, with ``chat`` as a decoded dict rather than JSON text."""

    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of this chat's shared copy, if it has been shared
    archived: bool
class ChatTitleIdResponse(BaseModel):
    """Lightweight chat summary (id, title, timestamps) without the chat payload."""

    id: str
    title: str
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
  52. class ChatTable:
  53. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  54. with get_db() as db:
  55. id = str(uuid.uuid4())
  56. chat = ChatModel(
  57. **{
  58. "id": id,
  59. "user_id": user_id,
  60. "title": (
  61. form_data.chat["title"]
  62. if "title" in form_data.chat
  63. else "New Chat"
  64. ),
  65. "chat": json.dumps(form_data.chat),
  66. "created_at": int(time.time()),
  67. "updated_at": int(time.time()),
  68. }
  69. )
  70. result = Chat(**chat.model_dump())
  71. db.add(result)
  72. db.commit()
  73. db.refresh(result)
  74. return ChatModel.model_validate(result) if result else None
  75. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  76. try:
  77. with get_db() as db:
  78. chat_obj = db.get(Chat, id)
  79. chat_obj.chat = json.dumps(chat)
  80. chat_obj.title = chat["title"] if "title" in chat else "New Chat"
  81. chat_obj.updated_at = int(time.time())
  82. db.commit()
  83. db.refresh(chat_obj)
  84. return ChatModel.model_validate(chat_obj)
  85. except Exception as e:
  86. return None
  87. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  88. with get_db() as db:
  89. # Get the existing chat to share
  90. chat = db.get(Chat, chat_id)
  91. # Check if the chat is already shared
  92. if chat.share_id:
  93. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  94. # Create a new chat with the same data, but with a new ID
  95. shared_chat = ChatModel(
  96. **{
  97. "id": str(uuid.uuid4()),
  98. "user_id": f"shared-{chat_id}",
  99. "title": chat.title,
  100. "chat": chat.chat,
  101. "created_at": chat.created_at,
  102. "updated_at": int(time.time()),
  103. }
  104. )
  105. shared_result = Chat(**shared_chat.model_dump())
  106. db.add(shared_result)
  107. db.commit()
  108. db.refresh(shared_result)
  109. # Update the original chat with the share_id
  110. result = (
  111. db.query(Chat)
  112. .filter_by(id=chat_id)
  113. .update({"share_id": shared_chat.id})
  114. )
  115. return shared_chat if (shared_result and result) else None
  116. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  117. try:
  118. with get_db() as db:
  119. print("update_shared_chat_by_id")
  120. chat = db.get(Chat, chat_id)
  121. print(chat)
  122. chat.title = chat.title
  123. chat.chat = chat.chat
  124. db.commit()
  125. db.refresh(chat)
  126. return self.get_chat_by_id(chat.share_id)
  127. except:
  128. return None
  129. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  130. try:
  131. with get_db() as db:
  132. db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
  133. return True
  134. except:
  135. return False
  136. def update_chat_share_id_by_id(
  137. self, id: str, share_id: Optional[str]
  138. ) -> Optional[ChatModel]:
  139. try:
  140. with get_db() as db:
  141. chat = db.get(Chat, id)
  142. chat.share_id = share_id
  143. db.commit()
  144. db.refresh(chat)
  145. return ChatModel.model_validate(chat)
  146. except:
  147. return None
  148. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  149. try:
  150. with get_db() as db:
  151. chat = db.get(Chat, id)
  152. chat.archived = not chat.archived
  153. db.commit()
  154. db.refresh(chat)
  155. return ChatModel.model_validate(chat)
  156. except:
  157. return None
  158. def archive_all_chats_by_user_id(self, user_id: str) -> bool:
  159. try:
  160. with get_db() as db:
  161. db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
  162. return True
  163. except:
  164. return False
  165. def get_archived_chat_list_by_user_id(
  166. self, user_id: str, skip: int = 0, limit: int = 50
  167. ) -> List[ChatModel]:
  168. with get_db() as db:
  169. all_chats = (
  170. db.query(Chat)
  171. .filter_by(user_id=user_id, archived=True)
  172. .order_by(Chat.updated_at.desc())
  173. # .limit(limit).offset(skip)
  174. .all()
  175. )
  176. return [ChatModel.model_validate(chat) for chat in all_chats]
  177. def get_chat_list_by_user_id(
  178. self,
  179. user_id: str,
  180. include_archived: bool = False,
  181. skip: int = 0,
  182. limit: int = 50,
  183. ) -> List[ChatModel]:
  184. with get_db() as db:
  185. query = db.query(Chat).filter_by(user_id=user_id)
  186. if not include_archived:
  187. query = query.filter_by(archived=False)
  188. all_chats = (
  189. query.order_by(Chat.updated_at.desc())
  190. # .limit(limit).offset(skip)
  191. .all()
  192. )
  193. return [ChatModel.model_validate(chat) for chat in all_chats]
  194. def get_chat_list_by_chat_ids(
  195. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  196. ) -> List[ChatModel]:
  197. with get_db() as db:
  198. all_chats = (
  199. db.query(Chat)
  200. .filter(Chat.id.in_(chat_ids))
  201. .filter_by(archived=False)
  202. .order_by(Chat.updated_at.desc())
  203. .all()
  204. )
  205. return [ChatModel.model_validate(chat) for chat in all_chats]
  206. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  207. try:
  208. with get_db() as db:
  209. chat = db.get(Chat, id)
  210. return ChatModel.model_validate(chat)
  211. except:
  212. return None
  213. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  214. try:
  215. with get_db() as db:
  216. chat = db.query(Chat).filter_by(share_id=id).first()
  217. if chat:
  218. return self.get_chat_by_id(id)
  219. else:
  220. return None
  221. except Exception as e:
  222. return None
  223. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  224. try:
  225. with get_db() as db:
  226. chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
  227. return ChatModel.model_validate(chat)
  228. except:
  229. return None
  230. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  231. with get_db() as db:
  232. all_chats = (
  233. db.query(Chat)
  234. # .limit(limit).offset(skip)
  235. .order_by(Chat.updated_at.desc())
  236. )
  237. return [ChatModel.model_validate(chat) for chat in all_chats]
  238. def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  239. with get_db() as db:
  240. all_chats = (
  241. db.query(Chat)
  242. .filter_by(user_id=user_id)
  243. .order_by(Chat.updated_at.desc())
  244. )
  245. return [ChatModel.model_validate(chat) for chat in all_chats]
  246. def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  247. with get_db() as db:
  248. all_chats = (
  249. db.query(Chat)
  250. .filter_by(user_id=user_id, archived=True)
  251. .order_by(Chat.updated_at.desc())
  252. )
  253. return [ChatModel.model_validate(chat) for chat in all_chats]
  254. def delete_chat_by_id(self, id: str) -> bool:
  255. try:
  256. with get_db() as db:
  257. db.query(Chat).filter_by(id=id).delete()
  258. return True and self.delete_shared_chat_by_chat_id(id)
  259. except:
  260. return False
  261. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  262. try:
  263. with get_db() as db:
  264. db.query(Chat).filter_by(id=id, user_id=user_id).delete()
  265. return True and self.delete_shared_chat_by_chat_id(id)
  266. except:
  267. return False
  268. def delete_chats_by_user_id(self, user_id: str) -> bool:
  269. try:
  270. with get_db() as db:
  271. self.delete_shared_chats_by_user_id(user_id)
  272. db.query(Chat).filter_by(user_id=user_id).delete()
  273. return True
  274. except:
  275. return False
  276. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  277. try:
  278. with get_db() as db:
  279. chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
  280. shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
  281. db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
  282. return True
  283. except:
  284. return False
# Module-level ChatTable instance exposing the chat data-access helpers.
Chats: ChatTable = ChatTable()