Browse Source

feat: add backend functions for sharing chats

Jun Siang Cheah 1 year ago
parent
commit
94976e5ed3
2 changed files with 112 additions and 0 deletions
  1. 41 0
      backend/apps/web/models/chats.py
  2. 71 0
      backend/apps/web/routers/chats.py

+ 41 - 0
backend/apps/web/models/chats.py

@@ -20,6 +20,7 @@ class Chat(Model):
     title = CharField()
     chat = TextField()  # Save Chat JSON as Text
     timestamp = DateField()
+    share_id = CharField(null=True, unique=True)
 
     class Meta:
         database = DB
@@ -31,6 +32,7 @@ class ChatModel(BaseModel):
     title: str
     chat: str
     timestamp: int  # timestamp in epoch
+    share_id: Optional[str] = None
 
 
 ####################
@@ -52,6 +54,7 @@ class ChatResponse(BaseModel):
     title: str
     chat: dict
     timestamp: int  # timestamp in epoch
+    share_id: Optional[str] = None  # id of the chat to be shared
 
 
 class ChatTitleIdResponse(BaseModel):
@@ -95,6 +98,44 @@ class ChatTable:
         except:
             return None
 
+    def insert_shared_chat(self, chat_id: str) -> Optional[ChatModel]:
+        # Get the existing chat to share
+        chat = Chat.get(Chat.id == chat_id)
+        # Check if the chat is already shared
+        if chat.share_id:
+            return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+        # Create a new chat with the same data, but with a new ID
+        shared_chat = ChatModel(
+            **{
+                "id": str(uuid.uuid4()),
+                "user_id": "shared",
+                "title": chat.title,
+                "chat": chat.chat,
+                "timestamp": int(time.time()),
+            }
+        )
+        shared_result = Chat.create(**shared_chat.model_dump())
+        # Update the original chat with the share_id
+        result = (
+            Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
+        )
+
+        return shared_chat if (shared_result and result) else None
+
+    def update_chat_share_id_by_id(
+        self, od: str, share_id: Optional[str]
+    ) -> Optional[ChatModel]:
+        try:
+            query = Chat.update(
+                share_id=share_id,
+            ).where(Chat.id == id)
+            query.execute()
+
+            chat = Chat.get(Chat.id == id)
+            return ChatModel(**model_to_dict(chat))
+        except:
+            return None
+
     def get_chat_lists_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:

+ 71 - 0
backend/apps/web/routers/chats.py

@@ -189,6 +189,77 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
     return result
 
 
+############################
+# ShareChatById
+############################
+
+
+@router.post("/{id}/share", response_model=Optional[ChatResponse])
+async def share_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        if chat.share_id:
+            shared_chat = Chats.get_chat_by_id_and_user_id(chat.share_id, "shared")
+            return ChatResponse(
+                **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
+            )
+
+        shared_chat = Chats.insert_shared_chat(chat.id)
+        if not shared_chat:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=ERROR_MESSAGES.DEFAULT(),
+            )
+
+        return ChatResponse(
+            **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
+        )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+
+############################
+# DeletedSharedChatById
+############################
+
+
+@router.delete("/{id}/share", response_model=Optional[bool])
+async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        if not chat.share_id:
+            return False
+        result = Chats.delete_chat_by_id_and_user_id(chat.share_id, "shared")
+        update_result = Chats.update_chat_share_id_by_id(chat.id, None)
+
+        return result and update_result
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+
+############################
+# GetSharedChatById
+############################
+
+
+@router.get("/share/{id}", response_model=Optional[ChatResponse])
+async def get_shared_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, "shared")
+
+    if chat:
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+
 ############################
 # GetChatTagsById
 ############################