# chats.py

import json
import logging
import time
import uuid
from typing import List, Optional, Union

from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel

from apps.web.internal.db import DB


  9. ####################
  10. # Chat DB Schema
  11. ####################
  12. class Chat(Model):
  13. id = CharField(unique=True)
  14. user_id = CharField()
  15. title = CharField()
  16. chat = TextField() # Save Chat JSON as Text
  17. timestamp = DateField()
  18. share_id = CharField(null=True, unique=True)
  19. archived = BooleanField(default=False)
  20. class Meta:
  21. database = DB
  22. class ChatModel(BaseModel):
  23. id: str
  24. user_id: str
  25. title: str
  26. chat: str
  27. timestamp: int # timestamp in epoch
  28. share_id: Optional[str] = None
  29. archived: bool = False
  30. ####################
  31. # Forms
  32. ####################
  33. class ChatForm(BaseModel):
  34. chat: dict
  35. class ChatTitleForm(BaseModel):
  36. title: str
  37. class ChatResponse(BaseModel):
  38. id: str
  39. user_id: str
  40. title: str
  41. chat: dict
  42. timestamp: int # timestamp in epoch
  43. share_id: Optional[str] = None # id of the chat to be shared
  44. class ChatTitleIdResponse(BaseModel):
  45. id: str
  46. title: str
  47. class ChatTable:
  48. def __init__(self, db):
  49. self.db = db
  50. db.create_tables([Chat])
  51. def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
  52. id = str(uuid.uuid4())
  53. chat = ChatModel(
  54. **{
  55. "id": id,
  56. "user_id": user_id,
  57. "title": (
  58. form_data.chat["title"] if "title" in form_data.chat else "New Chat"
  59. ),
  60. "chat": json.dumps(form_data.chat),
  61. "timestamp": int(time.time()),
  62. }
  63. )
  64. result = Chat.create(**chat.model_dump())
  65. return chat if result else None
  66. def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
  67. try:
  68. query = Chat.update(
  69. chat=json.dumps(chat),
  70. title=chat["title"] if "title" in chat else "New Chat",
  71. timestamp=int(time.time()),
  72. ).where(Chat.id == id)
  73. query.execute()
  74. chat = Chat.get(Chat.id == id)
  75. return ChatModel(**model_to_dict(chat))
  76. except:
  77. return None
  78. def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  79. # Get the existing chat to share
  80. chat = Chat.get(Chat.id == chat_id)
  81. # Check if the chat is already shared
  82. if chat.share_id:
  83. return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
  84. # Create a new chat with the same data, but with a new ID
  85. shared_chat = ChatModel(
  86. **{
  87. "id": str(uuid.uuid4()),
  88. "user_id": f"shared-{chat_id}",
  89. "title": chat.title,
  90. "chat": chat.chat,
  91. "timestamp": int(time.time()),
  92. }
  93. )
  94. shared_result = Chat.create(**shared_chat.model_dump())
  95. # Update the original chat with the share_id
  96. result = (
  97. Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
  98. )
  99. return shared_chat if (shared_result and result) else None
  100. def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
  101. try:
  102. print("update_shared_chat_by_id")
  103. chat = Chat.get(Chat.id == chat_id)
  104. print(chat)
  105. query = Chat.update(
  106. title=chat.title,
  107. chat=chat.chat,
  108. ).where(Chat.id == chat.share_id)
  109. query.execute()
  110. chat = Chat.get(Chat.id == chat.share_id)
  111. return ChatModel(**model_to_dict(chat))
  112. except:
  113. return None
  114. def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
  115. try:
  116. query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
  117. query.execute() # Remove the rows, return number of rows removed.
  118. return True
  119. except:
  120. return False
  121. def update_chat_share_id_by_id(
  122. self, id: str, share_id: Optional[str]
  123. ) -> Optional[ChatModel]:
  124. try:
  125. query = Chat.update(
  126. share_id=share_id,
  127. ).where(Chat.id == id)
  128. query.execute()
  129. chat = Chat.get(Chat.id == id)
  130. return ChatModel(**model_to_dict(chat))
  131. except:
  132. return None
  133. def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
  134. try:
  135. chat = self.get_chat_by_id(id)
  136. query = Chat.update(
  137. archived=(not chat.archived),
  138. ).where(Chat.id == id)
  139. query.execute()
  140. chat = Chat.get(Chat.id == id)
  141. return ChatModel(**model_to_dict(chat))
  142. except:
  143. return None
  144. def get_chat_lists_by_user_id(
  145. self, user_id: str, skip: int = 0, limit: int = 50
  146. ) -> List[ChatModel]:
  147. return [
  148. ChatModel(**model_to_dict(chat))
  149. for chat in Chat.select()
  150. .where(Chat.archived == False)
  151. .where(Chat.user_id == user_id)
  152. .order_by(Chat.timestamp.desc())
  153. # .limit(limit)
  154. # .offset(skip)
  155. ]
  156. def get_chat_lists_by_chat_ids(
  157. self, chat_ids: List[str], skip: int = 0, limit: int = 50
  158. ) -> List[ChatModel]:
  159. return [
  160. ChatModel(**model_to_dict(chat))
  161. for chat in Chat.select()
  162. .where(Chat.archived == False)
  163. .where(Chat.id.in_(chat_ids))
  164. .order_by(Chat.timestamp.desc())
  165. ]
  166. def get_all_chats(self) -> List[ChatModel]:
  167. return [
  168. ChatModel(**model_to_dict(chat))
  169. for chat in Chat.select()
  170. .where(Chat.archived == False)
  171. .order_by(Chat.timestamp.desc())
  172. ]
  173. def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
  174. return [
  175. ChatModel(**model_to_dict(chat))
  176. for chat in Chat.select()
  177. .where(Chat.archived == False)
  178. .where(Chat.user_id == user_id)
  179. .order_by(Chat.timestamp.desc())
  180. ]
  181. def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
  182. try:
  183. chat = Chat.get(Chat.id == id)
  184. return ChatModel(**model_to_dict(chat))
  185. except:
  186. return None
  187. def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
  188. try:
  189. chat = Chat.get(Chat.share_id == id)
  190. if chat:
  191. chat = Chat.get(Chat.id == id)
  192. return ChatModel(**model_to_dict(chat))
  193. else:
  194. return None
  195. except:
  196. return None
  197. def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
  198. try:
  199. chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
  200. return ChatModel(**model_to_dict(chat))
  201. except:
  202. return None
  203. def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
  204. return [
  205. ChatModel(**model_to_dict(chat))
  206. for chat in Chat.select().limit(limit).offset(skip)
  207. ]
  208. def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  209. try:
  210. query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
  211. query.execute() # Remove the rows, return number of rows removed.
  212. return True and self.delete_shared_chat_by_chat_id(id)
  213. except:
  214. return False
  215. def delete_chats_by_user_id(self, user_id: str) -> bool:
  216. try:
  217. self.delete_shared_chats_by_user_id(user_id)
  218. query = Chat.delete().where(Chat.user_id == user_id)
  219. query.execute() # Remove the rows, return number of rows removed.
  220. return True
  221. except:
  222. return False
  223. def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
  224. try:
  225. shared_chat_ids = [
  226. f"shared-{chat.id}"
  227. for chat in Chat.select().where(Chat.user_id == user_id)
  228. ]
  229. query = Chat.delete().where(Chat.user_id << shared_chat_ids)
  230. query.execute() # Remove the rows, return number of rows removed.
  231. return True
  232. except:
  233. return False
  234. Chats = ChatTable(DB)