chats.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. from pydantic import BaseModel, ConfigDict
  2. from typing import List, Union, Optional
  3. import json
  4. import uuid
  5. import time
  6. from sqlalchemy import Column, String, BigInteger, Boolean
  7. from sqlalchemy.orm import Session
  8. from apps.webui.internal.db import Base
  9. ####################
  10. # Chat DB Schema
  11. ####################
  12. class Chat(Base):
  13. __tablename__ = "chat"
  14. id = Column(String, primary_key=True)
  15. user_id = Column(String)
  16. title = Column(String)
  17. chat = Column(String) # Save Chat JSON as Text
  18. created_at = Column(BigInteger)
  19. updated_at = Column(BigInteger)
  20. share_id = Column(String, unique=True, nullable=True)
  21. archived = Column(Boolean, default=False)
  22. class ChatModel(BaseModel):
  23. model_config = ConfigDict(from_attributes=True)
  24. id: str
  25. user_id: str
  26. title: str
  27. chat: str
  28. created_at: int # timestamp in epoch
  29. updated_at: int # timestamp in epoch
  30. share_id: Optional[str] = None
  31. archived: bool = False
  32. ####################
  33. # Forms
  34. ####################
  35. class ChatForm(BaseModel):
  36. chat: dict
  37. class ChatTitleForm(BaseModel):
  38. title: str
  39. class ChatResponse(BaseModel):
  40. id: str
  41. user_id: str
  42. title: str
  43. chat: dict
  44. updated_at: int # timestamp in epoch
  45. created_at: int # timestamp in epoch
  46. share_id: Optional[str] = None # id of the chat to be shared
  47. archived: bool
  48. class ChatTitleIdResponse(BaseModel):
  49. id: str
  50. title: str
  51. updated_at: int
  52. created_at: int
  53. class ChatTable:
  54. def insert_new_chat(
  55. self, db: Session, user_id: str, form_data: ChatForm
  56. ) -> Optional[ChatModel]:
  57. id = str(uuid.uuid4())
  58. chat = ChatModel(
  59. **{
  60. "id": id,
  61. "user_id": user_id,
  62. "title": (
  63. form_data.chat["title"] if "title" in form_data.chat else "New Chat"
  64. ),
  65. "chat": json.dumps(form_data.chat),
  66. "created_at": int(time.time()),
  67. "updated_at": int(time.time()),
  68. }
  69. )
  70. result = Chat(**chat.model_dump())
  71. db.add(result)
  72. db.commit()
  73. db.refresh(result)
  74. return ChatModel.model_validate(result) if result else None
  75. def update_chat_by_id(
  76. self, db: Session, id: str, chat: dict
  77. ) -> Optional[ChatModel]:
  78. try:
  79. db.query(Chat).filter_by(id=id).update(
  80. {
  81. "chat": json.dumps(chat),
  82. "title": chat["title"] if "title" in chat else "New Chat",
  83. "updated_at": int(time.time()),
  84. }
  85. )
  86. return self.get_chat_by_id(db, id)
  87. except:
  88. return None
  89. def insert_shared_chat_by_chat_id(
  90. self, db: Session, chat_id: str
  91. ) -> Optional[ChatModel]:
  92. # Get the existing chat to share
  93. chat = db.get(Chat, chat_id)
  94. # Check if the chat is already shared
  95. if chat.share_id:
  96. return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared")
  97. # Create a new chat with the same data, but with a new ID
  98. shared_chat = ChatModel(
  99. **{
  100. "id": str(uuid.uuid4()),
  101. "user_id": f"shared-{chat_id}",
  102. "title": chat.title,
  103. "chat": chat.chat,
  104. "created_at": chat.created_at,
  105. "updated_at": int(time.time()),
  106. }
  107. )
  108. shared_result = Chat(**shared_chat.model_dump())
  109. db.add(shared_result)
  110. db.commit()
  111. db.refresh(shared_result)
  112. # Update the original chat with the share_id
  113. result = (
  114. db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
  115. )
  116. return shared_chat if (shared_result and result) else None
  117. def update_shared_chat_by_chat_id(
  118. self, db: Session, chat_id: str
  119. ) -> Optional[ChatModel]:
  120. try:
  121. print("update_shared_chat_by_id")
  122. chat = db.get(Chat, chat_id)
  123. print(chat)
  124. db.query(Chat).filter_by(id=chat.share_id).update(
  125. {"title": chat.title, "chat": chat.chat}
  126. )
  127. return self.get_chat_by_id(db, chat.share_id)
  128. except:
  129. return None
  130. def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool:
  131. try:
  132. db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
  133. return True
  134. except:
  135. return False
  136. def update_chat_share_id_by_id(
  137. self, db: Session, id: str, share_id: Optional[str]
  138. ) -> Optional[ChatModel]:
  139. try:
  140. db.query(Chat).filter_by(id=id).update({"share_id": share_id})
  141. return self.get_chat_by_id(db, id)
  142. except:
  143. return None
  144. def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
  145. try:
  146. chat = self.get_chat_by_id(db, id)
  147. db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
  148. return self.get_chat_by_id(db, id)
  149. except:
  150. return None
  151. def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool:
  152. try:
  153. db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
  154. return True
  155. except:
  156. return False
  157. def get_archived_chat_list_by_user_id(
  158. self, db: Session, user_id: str, skip: int = 0, limit: int = 50
  159. ) -> List[ChatModel]:
  160. all_chats = (
  161. db.query(Chat)
  162. .filter_by(user_id=user_id, archived=True)
  163. .order_by(Chat.updated_at.desc())
  164. # .limit(limit).offset(skip)
  165. .all()
  166. )
  167. return [ChatModel.model_validate(chat) for chat in all_chats]
  168. def get_chat_list_by_user_id(
  169. self,
  170. db: Session,
  171. user_id: str,
  172. include_archived: bool = False,
  173. skip: int = 0,
  174. limit: int = 50,
  175. ) -> List[ChatModel]:
  176. query = db.query(Chat).filter_by(user_id=user_id)
  177. if not include_archived:
  178. query = query.filter_by(archived=False)
  179. all_chats = (
  180. query.order_by(Chat.updated_at.desc())
  181. # .limit(limit).offset(skip)
  182. .all()
  183. )
  184. return [ChatModel.model_validate(chat) for chat in all_chats]
  185. def get_chat_list_by_chat_ids(
  186. self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50
  187. ) -> List[ChatModel]:
  188. all_chats = (
  189. db.query(Chat)
  190. .filter(Chat.id.in_(chat_ids))
  191. .filter_by(archived=False)
  192. .order_by(Chat.updated_at.desc())
  193. .all()
  194. )
  195. return [ChatModel.model_validate(chat) for chat in all_chats]
  196. def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
  197. try:
  198. chat = db.get(Chat, id)
  199. return ChatModel.model_validate(chat)
  200. except:
  201. return None
  202. def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]:
  203. try:
  204. chat = db.query(Chat).filter_by(share_id=id).first()
  205. if chat:
  206. return self.get_chat_by_id(db, id)
  207. else:
  208. return None
  209. except Exception as e:
  210. return None
  211. def get_chat_by_id_and_user_id(
  212. self, db: Session, id: str, user_id: str
  213. ) -> Optional[ChatModel]:
  214. try:
  215. chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
  216. return ChatModel.model_validate(chat)
  217. except:
  218. return None
  219. def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  220. all_chats = (
  221. db.query(Chat)
  222. # .limit(limit).offset(skip)
  223. .order_by(Chat.updated_at.desc())
  224. )
  225. return [ChatModel.model_validate(chat) for chat in all_chats]
  226. def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]:
  227. all_chats = (
  228. db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
  229. )
  230. return [ChatModel.model_validate(chat) for chat in all_chats]
  231. def get_archived_chats_by_user_id(
  232. self, db: Session, user_id: str
  233. ) -> List[ChatModel]:
  234. all_chats = (
  235. db.query(Chat)
  236. .filter_by(user_id=user_id, archived=True)
  237. .order_by(Chat.updated_at.desc())
  238. )
  239. return [ChatModel.model_validate(chat) for chat in all_chats]
  240. def delete_chat_by_id(self, db: Session, id: str) -> bool:
  241. try:
  242. db.query(Chat).filter_by(id=id).delete()
  243. return True and self.delete_shared_chat_by_chat_id(db, id)
  244. except:
  245. return False
  246. def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool:
  247. try:
  248. db.query(Chat).filter_by(id=id, user_id=user_id).delete()
  249. return True and self.delete_shared_chat_by_chat_id(db, id)
  250. except:
  251. return False
  252. def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool:
  253. try:
  254. self.delete_shared_chats_by_user_id(db, user_id)
  255. db.query(Chat).filter_by(user_id=user_id).delete()
  256. return True
  257. except:
  258. return False
  259. def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool:
  260. try:
  261. chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
  262. shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
  263. db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
  264. return True
  265. except:
  266. return False
  267. Chats = ChatTable()