chats.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. from pydantic import BaseModel
  2. from typing import List, Union, Optional
  3. from peewee import *
  4. from playhouse.shortcuts import model_to_dict
  5. import json
  6. import uuid
  7. import time
  8. from apps.web.internal.db import DB
  9. ####################
  10. # Chat DB Schema
  11. ####################
  12. class Chat(Model):
  13. id = CharField(unique=True)
  14. user_id = CharField()
  15. title = CharField()
  16. chat = TextField() # Save Chat JSON as Text
  17. created_at = DateTimeField()
  18. updated_at = DateTimeField()
  19. share_id = CharField(null=True, unique=True)
  20. archived = BooleanField(default=False)
  21. class Meta:
  22. database = DB
  23. class ChatModel(BaseModel):
  24. id: str
  25. user_id: str
  26. title: str
  27. chat: str
  28. created_at: int # timestamp in epoch
  29. updated_at: int # timestamp in epoch
  30. share_id: Optional[str] = None
  31. archived: bool = False
  32. ####################
  33. # Forms
  34. ####################
  35. class ChatForm(BaseModel):
  36. chat: dict
  37. class ChatTitleForm(BaseModel):
  38. title: str
  39. class ChatResponse(BaseModel):
  40. id: str
  41. user_id: str
  42. title: str
  43. chat: dict
  44. updated_at: int # timestamp in epoch
  45. created_at: int # timestamp in epoch
  46. share_id: Optional[str] = None # id of the chat to be shared
  47. class ChatTitleIdResponse(BaseModel):
  48. id: str
  49. title: str
  50. updated_at: int
  51. created_at: int
  52. class ChatTable:
  53. def __init__(self, db):
  54. self.db = db
  55. db.create_tables([Chat])
  56. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  57. id = str(uuid.uuid4())
  58. chat = ChatModel(
  59. **{
  60. "id": id,
  61. "user_id": user_id,
  62. "title": (
  63. form_data.chat["title"] if "title" in form_data.chat else "New Chat"
  64. ),
  65. "chat": json.dumps(form_data.chat),
  66. "created_at": int(time.time()),
  67. "updated_at": int(time.time()),
  68. }
  69. )
  70. result = Chat.create(**chat.model_dump())
  71. return chat if result else None
  72. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  73. try:
  74. query = Chat.update(
  75. chat=json.dumps(chat),
  76. title=chat["title"] if "title" in chat else "New Chat",
  77. updated_at=int(time.time()),
  78. ).where(Chat.id == id)
  79. query.execute()
  80. chat = Chat.get(Chat.id == id)
  81. return ChatModel(**model_to_dict(chat))
  82. except:
  83. return None
  84. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  85. # Get the existing chat to share
  86. chat = Chat.get(Chat.id == chat_id)
  87. # Check if the chat is already shared
  88. if chat.share_id:
  89. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  90. # Create a new chat with the same data, but with a new ID
  91. shared_chat = ChatModel(
  92. **{
  93. "id": str(uuid.uuid4()),
  94. "user_id": f"shared-{chat_id}",
  95. "title": chat.title,
  96. "chat": chat.chat,
  97. "created_at": int(time.time()),
  98. }
  99. )
  100. shared_result = Chat.create(**shared_chat.model_dump())
  101. # Update the original chat with the share_id
  102. result = (
  103. Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
  104. )
  105. return shared_chat if (shared_result and result) else None
  106. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  107. try:
  108. print("update_shared_chat_by_id")
  109. chat = Chat.get(Chat.id == chat_id)
  110. print(chat)
  111. query = Chat.update(
  112. title=chat.title,
  113. chat=chat.chat,
  114. ).where(Chat.id == chat.share_id)
  115. query.execute()
  116. chat = Chat.get(Chat.id == chat.share_id)
  117. return ChatModel(**model_to_dict(chat))
  118. except:
  119. return None
  120. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  121. try:
  122. query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
  123. query.execute() # Remove the rows, return number of rows removed.
  124. return True
  125. except:
  126. return False
  127. def update_chat_share_id_by_id(
  128. self, id: str, share_id: Optional[str]
  129. ) -> Optional[ChatModel]:
  130. try:
  131. query = Chat.update(
  132. share_id=share_id,
  133. ).where(Chat.id == id)
  134. query.execute()
  135. chat = Chat.get(Chat.id == id)
  136. return ChatModel(**model_to_dict(chat))
  137. except:
  138. return None
  139. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  140. try:
  141. chat = self.get_chat_by_id(id)
  142. query = Chat.update(
  143. archived=(not chat.archived),
  144. ).where(Chat.id == id)
  145. query.execute()
  146. chat = Chat.get(Chat.id == id)
  147. return ChatModel(**model_to_dict(chat))
  148. except:
  149. return None
  150. def get_archived_chat_lists_by_user_id(
  151. self, user_id: str, skip: int = 0, limit: int = 50
  152. ) -> List[ChatModel]:
  153. return [
  154. ChatModel(**model_to_dict(chat))
  155. for chat in Chat.select()
  156. .where(Chat.archived == True)
  157. .where(Chat.user_id == user_id)
  158. .order_by(Chat.updated_at.desc())
  159. # .limit(limit)
  160. # .offset(skip)
  161. ]
  162. def get_chat_lists_by_user_id(
  163. self, user_id: str, skip: int = 0, limit: int = 50
  164. ) -> List[ChatModel]:
  165. return [
  166. ChatModel(**model_to_dict(chat))
  167. for chat in Chat.select()
  168. .where(Chat.archived == False)
  169. .where(Chat.user_id == user_id)
  170. .order_by(Chat.updated_at.desc())
  171. # .limit(limit)
  172. # .offset(skip)
  173. ]
  174. def get_chat_lists_by_chat_ids(
  175. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  176. ) -> List[ChatModel]:
  177. return [
  178. ChatModel(**model_to_dict(chat))
  179. for chat in Chat.select()
  180. .where(Chat.archived == False)
  181. .where(Chat.id.in_(chat_ids))
  182. .order_by(Chat.updated_at.desc())
  183. ]
  184. def get_all_chats(self) -> List[ChatModel]:
  185. return [
  186. ChatModel(**model_to_dict(chat))
  187. for chat in Chat.select()
  188. .where(Chat.archived == False)
  189. .order_by(Chat.updated_at.desc())
  190. ]
  191. def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  192. return [
  193. ChatModel(**model_to_dict(chat))
  194. for chat in Chat.select()
  195. .where(Chat.archived == False)
  196. .where(Chat.user_id == user_id)
  197. .order_by(Chat.updated_at.desc())
  198. ]
  199. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  200. try:
  201. chat = Chat.get(Chat.id == id)
  202. return ChatModel(**model_to_dict(chat))
  203. except:
  204. return None
  205. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  206. try:
  207. chat = Chat.get(Chat.share_id == id)
  208. if chat:
  209. chat = Chat.get(Chat.id == id)
  210. return ChatModel(**model_to_dict(chat))
  211. else:
  212. return None
  213. except:
  214. return None
  215. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  216. try:
  217. chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
  218. return ChatModel(**model_to_dict(chat))
  219. except:
  220. return None
  221. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  222. return [
  223. ChatModel(**model_to_dict(chat))
  224. for chat in Chat.select().limit(limit).offset(skip)
  225. ]
  226. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  227. try:
  228. query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
  229. query.execute() # Remove the rows, return number of rows removed.
  230. return True and self.delete_shared_chat_by_chat_id(id)
  231. except:
  232. return False
  233. def delete_chats_by_user_id(self, user_id: str) -> bool:
  234. try:
  235. self.delete_shared_chats_by_user_id(user_id)
  236. query = Chat.delete().where(Chat.user_id == user_id)
  237. query.execute() # Remove the rows, return number of rows removed.
  238. return True
  239. except:
  240. return False
  241. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  242. try:
  243. shared_chat_ids = [
  244. f"shared-{chat.id}"
  245. for chat in Chat.select().where(Chat.user_id == user_id)
  246. ]
  247. query = Chat.delete().where(Chat.user_id << shared_chat_ids)
  248. query.execute() # Remove the rows, return number of rows removed.
  249. return True
  250. except:
  251. return False
  252. Chats = ChatTable(DB)