# chats.py
import json
import logging
import time
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String
from sqlalchemy.orm import Session

from apps.webui.internal.db import Base, get_session
####################
# Chat DB Schema
####################
class Chat(Base):
    """SQLAlchemy ORM mapping for the ``chat`` table.

    Each row stores one chat's JSON payload plus ownership, sharing and
    archive metadata.
    """

    __tablename__ = "chat"

    id = Column(String, primary_key=True)
    # Owner of the chat. Shared copies created by ChatTable are owned by the
    # pseudo-user "shared-<original chat id>".
    user_id = Column(String)
    title = Column(String)
    chat = Column(String)  # Save Chat JSON as Text
    created_at = Column(BigInteger)  # epoch seconds (written via int(time.time()))
    updated_at = Column(BigInteger)  # epoch seconds (written via int(time.time()))
    # On an original chat: the id of its shared copy row; NULL when not shared.
    share_id = Column(String, unique=True, nullable=True)
    archived = Column(Boolean, default=False)
class ChatModel(BaseModel):
    """Pydantic snapshot of a ``Chat`` row.

    ``from_attributes=True`` lets ``ChatModel.model_validate(orm_row)`` read
    fields directly from a SQLAlchemy ``Chat`` instance.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str  # JSON-encoded chat payload, exactly as stored in the DB column
    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of this chat's shared copy, if shared
    archived: bool = False
####################
# Forms
####################
class ChatForm(BaseModel):
    """Payload for creating a chat.

    ``chat`` is the raw chat dict; ChatTable.insert_new_chat takes its
    optional ``"title"`` key as the chat title and stores the whole dict as
    JSON.
    """

    chat: dict
class ChatTitleForm(BaseModel):
    """Payload carrying only a chat title."""

    title: str
class ChatResponse(BaseModel):
    """Outward-facing chat representation.

    Unlike ``ChatModel``, ``chat`` is a dict here; presumably callers
    ``json.loads`` the stored text before building this (verify at call sites).
    """

    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of this chat's shared copy (None if not shared)
    archived: bool
class ChatTitleIdResponse(BaseModel):
    """Lightweight chat summary: id, title and epoch timestamps (no payload)."""

    id: str
    title: str
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
  53. class ChatTable:
  54. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  55. with get_session() as db:
  56. id = str(uuid.uuid4())
  57. chat = ChatModel(
  58. **{
  59. "id": id,
  60. "user_id": user_id,
  61. "title": (
  62. form_data.chat["title"]
  63. if "title" in form_data.chat
  64. else "New Chat"
  65. ),
  66. "chat": json.dumps(form_data.chat),
  67. "created_at": int(time.time()),
  68. "updated_at": int(time.time()),
  69. }
  70. )
  71. result = Chat(**chat.model_dump())
  72. db.add(result)
  73. db.commit()
  74. db.refresh(result)
  75. return ChatModel.model_validate(result) if result else None
  76. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  77. with get_session() as db:
  78. try:
  79. chat_obj = db.get(Chat, id)
  80. chat_obj.chat = json.dumps(chat)
  81. chat_obj.title = chat["title"] if "title" in chat else "New Chat"
  82. chat_obj.updated_at = int(time.time())
  83. db.commit()
  84. db.refresh(chat_obj)
  85. return ChatModel.model_validate(chat_obj)
  86. except Exception as e:
  87. return None
  88. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  89. with get_session() as db:
  90. # Get the existing chat to share
  91. chat = db.get(Chat, chat_id)
  92. # Check if the chat is already shared
  93. if chat.share_id:
  94. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  95. # Create a new chat with the same data, but with a new ID
  96. shared_chat = ChatModel(
  97. **{
  98. "id": str(uuid.uuid4()),
  99. "user_id": f"shared-{chat_id}",
  100. "title": chat.title,
  101. "chat": chat.chat,
  102. "created_at": chat.created_at,
  103. "updated_at": int(time.time()),
  104. }
  105. )
  106. shared_result = Chat(**shared_chat.model_dump())
  107. db.add(shared_result)
  108. db.commit()
  109. db.refresh(shared_result)
  110. # Update the original chat with the share_id
  111. result = (
  112. db.query(Chat)
  113. .filter_by(id=chat_id)
  114. .update({"share_id": shared_chat.id})
  115. )
  116. return shared_chat if (shared_result and result) else None
  117. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  118. with get_session() as db:
  119. try:
  120. print("update_shared_chat_by_id")
  121. chat = db.get(Chat, chat_id)
  122. print(chat)
  123. chat.title = chat.title
  124. chat.chat = chat.chat
  125. db.commit()
  126. db.refresh(chat)
  127. return self.get_chat_by_id(chat.share_id)
  128. except:
  129. return None
  130. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  131. try:
  132. with get_session() as db:
  133. db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
  134. return True
  135. except:
  136. return False
  137. def update_chat_share_id_by_id(
  138. self, id: str, share_id: Optional[str]
  139. ) -> Optional[ChatModel]:
  140. try:
  141. with get_session() as db:
  142. chat = db.get(Chat, id)
  143. chat.share_id = share_id
  144. db.commit()
  145. db.refresh(chat)
  146. return chat
  147. except:
  148. return None
  149. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  150. try:
  151. with get_session() as db:
  152. chat = self.get_chat_by_id(id)
  153. db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
  154. return self.get_chat_by_id(id)
  155. except:
  156. return None
  157. def archive_all_chats_by_user_id(self, user_id: str) -> bool:
  158. try:
  159. with get_session() as db:
  160. db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
  161. return True
  162. except:
  163. return False
  164. def get_archived_chat_list_by_user_id(
  165. self, user_id: str, skip: int = 0, limit: int = 50
  166. ) -> List[ChatModel]:
  167. with get_session() as db:
  168. all_chats = (
  169. db.query(Chat)
  170. .filter_by(user_id=user_id, archived=True)
  171. .order_by(Chat.updated_at.desc())
  172. # .limit(limit).offset(skip)
  173. .all()
  174. )
  175. return [ChatModel.model_validate(chat) for chat in all_chats]
  176. def get_chat_list_by_user_id(
  177. self,
  178. user_id: str,
  179. include_archived: bool = False,
  180. skip: int = 0,
  181. limit: int = 50,
  182. ) -> List[ChatModel]:
  183. with get_session() as db:
  184. query = db.query(Chat).filter_by(user_id=user_id)
  185. if not include_archived:
  186. query = query.filter_by(archived=False)
  187. all_chats = (
  188. query.order_by(Chat.updated_at.desc())
  189. # .limit(limit).offset(skip)
  190. .all()
  191. )
  192. return [ChatModel.model_validate(chat) for chat in all_chats]
  193. def get_chat_list_by_chat_ids(
  194. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  195. ) -> List[ChatModel]:
  196. with get_session() as db:
  197. all_chats = (
  198. db.query(Chat)
  199. .filter(Chat.id.in_(chat_ids))
  200. .filter_by(archived=False)
  201. .order_by(Chat.updated_at.desc())
  202. .all()
  203. )
  204. return [ChatModel.model_validate(chat) for chat in all_chats]
  205. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  206. try:
  207. with get_session() as db:
  208. chat = db.get(Chat, id)
  209. return ChatModel.model_validate(chat)
  210. except:
  211. return None
  212. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  213. try:
  214. with get_session() as db:
  215. chat = db.query(Chat).filter_by(share_id=id).first()
  216. if chat:
  217. return self.get_chat_by_id(id)
  218. else:
  219. return None
  220. except Exception as e:
  221. return None
  222. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  223. try:
  224. with get_session() as db:
  225. chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
  226. return ChatModel.model_validate(chat)
  227. except:
  228. return None
  229. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  230. with get_session() as db:
  231. all_chats = (
  232. db.query(Chat)
  233. # .limit(limit).offset(skip)
  234. .order_by(Chat.updated_at.desc())
  235. )
  236. return [ChatModel.model_validate(chat) for chat in all_chats]
  237. def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  238. with get_session() as db:
  239. all_chats = (
  240. db.query(Chat)
  241. .filter_by(user_id=user_id)
  242. .order_by(Chat.updated_at.desc())
  243. )
  244. return [ChatModel.model_validate(chat) for chat in all_chats]
  245. def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  246. with get_session() as db:
  247. all_chats = (
  248. db.query(Chat)
  249. .filter_by(user_id=user_id, archived=True)
  250. .order_by(Chat.updated_at.desc())
  251. )
  252. return [ChatModel.model_validate(chat) for chat in all_chats]
  253. def delete_chat_by_id(self, id: str) -> bool:
  254. try:
  255. with get_session() as db:
  256. db.query(Chat).filter_by(id=id).delete()
  257. return True and self.delete_shared_chat_by_chat_id(id)
  258. except:
  259. return False
  260. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  261. try:
  262. with get_session() as db:
  263. db.query(Chat).filter_by(id=id, user_id=user_id).delete()
  264. return True and self.delete_shared_chat_by_chat_id(id)
  265. except:
  266. return False
  267. def delete_chats_by_user_id(self, user_id: str) -> bool:
  268. try:
  269. with get_session() as db:
  270. self.delete_shared_chats_by_user_id(user_id)
  271. db.query(Chat).filter_by(user_id=user_id).delete()
  272. return True
  273. except:
  274. return False
  275. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  276. try:
  277. with get_session() as db:
  278. chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
  279. shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
  280. db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
  281. return True
  282. except:
  283. return False
# Module-level ChatTable instance; presumably the object other modules import
# for chat persistence (verify at call sites).
Chats = ChatTable()