chats.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
import json
import logging
import time
import uuid
from typing import List, Optional, Union

from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel

from apps.webui.internal.db import DB
  9. ####################
  10. # Chat DB Schema
  11. ####################
  12. class Chat(Model):
  13. id = CharField(unique=True)
  14. user_id = CharField()
  15. title = TextField()
  16. chat = TextField() # Save Chat JSON as Text
  17. created_at = BigIntegerField()
  18. updated_at = BigIntegerField()
  19. share_id = CharField(null=True, unique=True)
  20. archived = BooleanField(default=False)
  21. class Meta:
  22. database = DB
  23. class ChatModel(BaseModel):
  24. id: str
  25. user_id: str
  26. title: str
  27. chat: str
  28. created_at: int # timestamp in epoch
  29. updated_at: int # timestamp in epoch
  30. share_id: Optional[str] = None
  31. archived: bool = False
  32. ####################
  33. # Forms
  34. ####################
  35. class ChatForm(BaseModel):
  36. chat: dict
  37. class ChatTitleForm(BaseModel):
  38. title: str
  39. class ChatResponse(BaseModel):
  40. id: str
  41. user_id: str
  42. title: str
  43. chat: dict
  44. updated_at: int # timestamp in epoch
  45. created_at: int # timestamp in epoch
  46. share_id: Optional[str] = None # id of the chat to be shared
  47. archived: bool
  48. class ChatTitleIdResponse(BaseModel):
  49. id: str
  50. title: str
  51. updated_at: int
  52. created_at: int
  53. class ChatTable:
  54. def __init__(self, db):
  55. self.db = db
  56. db.create_tables([Chat])
  57. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  58. id = str(uuid.uuid4())
  59. chat = ChatModel(
  60. **{
  61. "id": id,
  62. "user_id": user_id,
  63. "title": (
  64. form_data.chat["title"] if "title" in form_data.chat else "New Chat"
  65. ),
  66. "chat": json.dumps(form_data.chat),
  67. "created_at": int(time.time()),
  68. "updated_at": int(time.time()),
  69. }
  70. )
  71. result = Chat.create(**chat.model_dump())
  72. return chat if result else None
  73. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  74. try:
  75. query = Chat.update(
  76. chat=json.dumps(chat),
  77. title=chat["title"] if "title" in chat else "New Chat",
  78. updated_at=int(time.time()),
  79. ).where(Chat.id == id)
  80. query.execute()
  81. chat = Chat.get(Chat.id == id)
  82. return ChatModel(**model_to_dict(chat))
  83. except:
  84. return None
  85. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  86. # Get the existing chat to share
  87. chat = Chat.get(Chat.id == chat_id)
  88. # Check if the chat is already shared
  89. if chat.share_id:
  90. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  91. # Create a new chat with the same data, but with a new ID
  92. shared_chat = ChatModel(
  93. **{
  94. "id": str(uuid.uuid4()),
  95. "user_id": f"shared-{chat_id}",
  96. "title": chat.title,
  97. "chat": chat.chat,
  98. "created_at": chat.created_at,
  99. "updated_at": int(time.time()),
  100. }
  101. )
  102. shared_result = Chat.create(**shared_chat.model_dump())
  103. # Update the original chat with the share_id
  104. result = (
  105. Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
  106. )
  107. return shared_chat if (shared_result and result) else None
  108. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  109. try:
  110. print("update_shared_chat_by_id")
  111. chat = Chat.get(Chat.id == chat_id)
  112. print(chat)
  113. query = Chat.update(
  114. title=chat.title,
  115. chat=chat.chat,
  116. ).where(Chat.id == chat.share_id)
  117. query.execute()
  118. chat = Chat.get(Chat.id == chat.share_id)
  119. return ChatModel(**model_to_dict(chat))
  120. except:
  121. return None
  122. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  123. try:
  124. query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
  125. query.execute() # Remove the rows, return number of rows removed.
  126. return True
  127. except:
  128. return False
  129. def update_chat_share_id_by_id(
  130. self, id: str, share_id: Optional[str]
  131. ) -> Optional[ChatModel]:
  132. try:
  133. query = Chat.update(
  134. share_id=share_id,
  135. ).where(Chat.id == id)
  136. query.execute()
  137. chat = Chat.get(Chat.id == id)
  138. return ChatModel(**model_to_dict(chat))
  139. except:
  140. return None
  141. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  142. try:
  143. chat = self.get_chat_by_id(id)
  144. query = Chat.update(
  145. archived=(not chat.archived),
  146. ).where(Chat.id == id)
  147. query.execute()
  148. chat = Chat.get(Chat.id == id)
  149. return ChatModel(**model_to_dict(chat))
  150. except:
  151. return None
  152. def archive_all_chats_by_user_id(self, user_id: str) -> bool:
  153. try:
  154. chats = self.get_chats_by_user_id(user_id)
  155. for chat in chats:
  156. query = Chat.update(
  157. archived=True,
  158. ).where(Chat.id == chat.id)
  159. query.execute()
  160. return True
  161. except:
  162. return False
  163. def get_archived_chat_list_by_user_id(
  164. self, user_id: str, skip: int = 0, limit: int = 50
  165. ) -> List[ChatModel]:
  166. return [
  167. ChatModel(**model_to_dict(chat))
  168. for chat in Chat.select()
  169. .where(Chat.archived == True)
  170. .where(Chat.user_id == user_id)
  171. .order_by(Chat.updated_at.desc())
  172. # .limit(limit)
  173. # .offset(skip)
  174. ]
  175. def get_chat_list_by_user_id(
  176. self,
  177. user_id: str,
  178. include_archived: bool = False,
  179. skip: int = 0,
  180. limit: int = 50,
  181. ) -> List[ChatModel]:
  182. if include_archived:
  183. return [
  184. ChatModel(**model_to_dict(chat))
  185. for chat in Chat.select()
  186. .where(Chat.user_id == user_id)
  187. .order_by(Chat.updated_at.desc())
  188. # .limit(limit)
  189. # .offset(skip)
  190. ]
  191. else:
  192. return [
  193. ChatModel(**model_to_dict(chat))
  194. for chat in Chat.select()
  195. .where(Chat.archived == False)
  196. .where(Chat.user_id == user_id)
  197. .order_by(Chat.updated_at.desc())
  198. # .limit(limit)
  199. # .offset(skip)
  200. ]
  201. def get_chat_list_by_chat_ids(
  202. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  203. ) -> List[ChatModel]:
  204. return [
  205. ChatModel(**model_to_dict(chat))
  206. for chat in Chat.select()
  207. .where(Chat.archived == False)
  208. .where(Chat.id.in_(chat_ids))
  209. .order_by(Chat.updated_at.desc())
  210. ]
  211. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  212. try:
  213. chat = Chat.get(Chat.id == id)
  214. return ChatModel(**model_to_dict(chat))
  215. except:
  216. return None
  217. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  218. try:
  219. chat = Chat.get(Chat.share_id == id)
  220. if chat:
  221. chat = Chat.get(Chat.id == id)
  222. return ChatModel(**model_to_dict(chat))
  223. else:
  224. return None
  225. except:
  226. return None
  227. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  228. try:
  229. chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
  230. return ChatModel(**model_to_dict(chat))
  231. except:
  232. return None
  233. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  234. return [
  235. ChatModel(**model_to_dict(chat))
  236. for chat in Chat.select().order_by(Chat.updated_at.desc())
  237. # .limit(limit).offset(skip)
  238. ]
  239. def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  240. return [
  241. ChatModel(**model_to_dict(chat))
  242. for chat in Chat.select()
  243. .where(Chat.user_id == user_id)
  244. .order_by(Chat.updated_at.desc())
  245. # .limit(limit).offset(skip)
  246. ]
  247. def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  248. return [
  249. ChatModel(**model_to_dict(chat))
  250. for chat in Chat.select()
  251. .where(Chat.archived == True)
  252. .where(Chat.user_id == user_id)
  253. .order_by(Chat.updated_at.desc())
  254. ]
  255. def delete_chat_by_id(self, id: str) -> bool:
  256. try:
  257. query = Chat.delete().where((Chat.id == id))
  258. query.execute() # Remove the rows, return number of rows removed.
  259. return True and self.delete_shared_chat_by_chat_id(id)
  260. except:
  261. return False
  262. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  263. try:
  264. query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
  265. query.execute() # Remove the rows, return number of rows removed.
  266. return True and self.delete_shared_chat_by_chat_id(id)
  267. except:
  268. return False
  269. def delete_chats_by_user_id(self, user_id: str) -> bool:
  270. try:
  271. self.delete_shared_chats_by_user_id(user_id)
  272. query = Chat.delete().where(Chat.user_id == user_id)
  273. query.execute() # Remove the rows, return number of rows removed.
  274. return True
  275. except:
  276. return False
  277. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  278. try:
  279. shared_chat_ids = [
  280. f"shared-{chat.id}"
  281. for chat in Chat.select().where(Chat.user_id == user_id)
  282. ]
  283. query = Chat.delete().where(Chat.user_id << shared_chat_ids)
  284. query.execute() # Remove the rows, return number of rows removed.
  285. return True
  286. except:
  287. return False
  288. Chats = ChatTable(DB)