Browse Source

feat: access archived chats as admin

Timothy J. Baek 11 months ago
parent
commit
e20bb23409

+ 38 - 10
backend/apps/webui/models/chats.py

@@ -191,6 +191,20 @@ class ChatTable:
         except:
             return None
 
+    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
+        try:
+            chats = self.get_chats_by_user_id(user_id)
+            for chat in chats:
+                query = Chat.update(
+                    archived=True,
+                ).where(Chat.id == chat.id)
+
+                query.execute()
+
+            return True
+        except:
+            return False
+
     def get_archived_chat_list_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
@@ -205,17 +219,31 @@ class ChatTable:
         ]
 
     def get_chat_list_by_user_id(
-        self, user_id: str, skip: int = 0, limit: int = 50
+        self,
+        user_id: str,
+        include_archived: bool = False,
+        skip: int = 0,
+        limit: int = 50,
     ) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == False)
-            .where(Chat.user_id == user_id)
-            .order_by(Chat.updated_at.desc())
-            # .limit(limit)
-            # .offset(skip)
-        ]
+        if include_archived:
+            return [
+                ChatModel(**model_to_dict(chat))
+                for chat in Chat.select()
+                .where(Chat.user_id == user_id)
+                .order_by(Chat.updated_at.desc())
+                # .limit(limit)
+                # .offset(skip)
+            ]
+        else:
+            return [
+                ChatModel(**model_to_dict(chat))
+                for chat in Chat.select()
+                .where(Chat.archived == False)
+                .where(Chat.user_id == user_id)
+                .order_by(Chat.updated_at.desc())
+                # .limit(limit)
+                # .offset(skip)
+            ]
 
     def get_chat_list_by_chat_ids(
         self, chat_ids: List[str], skip: int = 0, limit: int = 50

+ 49 - 37
backend/apps/webui/routers/chats.py

@@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
 async def get_user_chat_list_by_user_id(
     user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
 ):
-    return Chats.get_chat_list_by_user_id(user_id, skip, limit)
-
-
-############################
-# GetArchivedChats
-############################
-
-
-@router.get("/archived", response_model=List[ChatTitleIdResponse])
-async def get_archived_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
-):
-    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
+    return Chats.get_chat_list_by_user_id(
+        user_id, include_archived=True, skip=skip, limit=limit
+    )
 
 
 ############################
-# GetSharedChatById
+# CreateNewChat
 ############################
 
 
-@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
-    if user.role == "pending":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
-        )
-
-    if user.role == "user":
-        chat = Chats.get_chat_by_share_id(share_id)
-    elif user.role == "admin":
-        chat = Chats.get_chat_by_id(share_id)
-
-    if chat:
+@router.post("/new", response_model=Optional[ChatResponse])
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+    try:
+        chat = Chats.insert_new_chat(user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    else:
+    except Exception as e:
+        log.exception(e)
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
         )
 
 
@@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
 
 
 ############################
-# CreateNewChat
+# GetArchivedChats
 ############################
 
 
-@router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
-    try:
-        chat = Chats.insert_new_chat(user.id, form_data)
+@router.get("/archived", response_model=List[ChatTitleIdResponse])
+async def get_archived_session_user_chat_list(
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
+
+
+############################
+# ArchiveAllChats
+############################
+
+
+@router.get("/archive/all", response_model=List[ChatTitleIdResponse])
+async def archive_all_chats(user=Depends(get_current_user)):
+    return Chats.archive_all_chats_by_user_id(user.id)
+
+
+############################
+# GetSharedChatById
+############################
+
+
+@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
+async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
+    if user.role == "pending":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role == "user":
+        chat = Chats.get_chat_by_share_id(share_id)
+    elif user.role == "admin":
+        chat = Chats.get_chat_by_id(share_id)
+
+    if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    except Exception as e:
-        log.exception(e)
+    else:
         raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
         )
 
 

+ 1 - 1
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -82,7 +82,7 @@
 
 	// Open all links in a new tab/window (from https://github.com/markedjs/marked/issues/655#issuecomment-383226346)
 	const origLinkRenderer = renderer.link;
-	renderer.link =	(href, title, text) => {
+	renderer.link = (href, title, text) => {
 		const html = origLinkRenderer.call(renderer, href, title, text);
 		return html.replace(/^<a /, '<a target="_blank" rel="nofollow" ');
 	};

+ 1 - 3
src/lib/utils/index.ts

@@ -17,9 +17,7 @@ export const sanitizeResponseContent = (content: string) => {
 };
 
 export const revertSanitizedResponseContent = (content: string) => {
-	return content
-		.replaceAll('&lt;', '<')
-		.replaceAll('&gt;', '>');
+	return content.replaceAll('&lt;', '<').replaceAll('&gt;', '>');
 };
 
 export const capitalizeFirstLetter = (string) => {