Ver código fonte

feat: archive chat

Timothy J. Baek 1 ano atrás
pai
commit
fbd520bf07

+ 46 - 0
backend/apps/web/internal/migrations/004_add_archived.py

@@ -0,0 +1,46 @@
+"""Peewee migrations -- 002_add_local_sharing.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    migrator.add_fields("chat", archived=pw.BooleanField(default=False))
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_fields("chat", "archived")

+ 22 - 1
backend/apps/web/models/chats.py

@@ -21,6 +21,7 @@ class Chat(Model):
     chat = TextField()  # Save Chat JSON as Text
     timestamp = DateField()
     share_id = CharField(null=True, unique=True)
+    archived = BooleanField(default=False)
 
     class Meta:
         database = DB
@@ -33,6 +34,7 @@ class ChatModel(BaseModel):
     chat: str
     timestamp: int  # timestamp in epoch
     share_id: Optional[str] = None
+    archived: bool = False
 
 
 ####################
@@ -163,12 +165,27 @@ class ChatTable:
         except:
             return None
 
+    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
+        try:
+            chat = self.get_chat_by_id(id)
+            query = Chat.update(
+                archived=(not chat.archived),
+            ).where(Chat.id == id)
+
+            query.execute()
+
+            chat = Chat.get(Chat.id == id)
+            return ChatModel(**model_to_dict(chat))
+        except:
+            return None
+
     def get_chat_lists_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
         return [
             ChatModel(**model_to_dict(chat))
             for chat in Chat.select()
+            .where(Chat.archived == False)
             .where(Chat.user_id == user_id)
             .order_by(Chat.timestamp.desc())
             # .limit(limit)
@@ -181,6 +198,7 @@ class ChatTable:
         return [
             ChatModel(**model_to_dict(chat))
             for chat in Chat.select()
+            .where(Chat.archived == False)
             .where(Chat.id.in_(chat_ids))
             .order_by(Chat.timestamp.desc())
         ]
@@ -188,13 +206,16 @@ class ChatTable:
     def get_all_chats(self) -> List[ChatModel]:
         return [
             ChatModel(**model_to_dict(chat))
-            for chat in Chat.select().order_by(Chat.timestamp.desc())
+            for chat in Chat.select()
+            .where(Chat.archived == False)
+            .order_by(Chat.timestamp.desc())
         ]
 
     def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
         return [
             ChatModel(**model_to_dict(chat))
             for chat in Chat.select()
+            .where(Chat.archived == False)
             .where(Chat.user_id == user_id)
             .order_by(Chat.timestamp.desc())
         ]

+ 17 - 0
backend/apps/web/routers/chats.py

@@ -189,6 +189,23 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
     return result
 
 
+############################
+# ArchiveChat
+############################
+
+
+@router.get("/{id}/archive", response_model=Optional[ChatResponse])
+async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        chat = Chats.toggle_chat_archive_by_id(id)
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
 ############################
 # ShareChatById
 ############################

+ 32 - 0
src/lib/apis/chats/index.ts

@@ -282,6 +282,38 @@ export const shareChatById = async (token: string, id: string) => {
 	return res;
 };
 
+export const archiveChatById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const deleteSharedChatById = async (token: string, id: string) => {
 	let error = null;
 

+ 8 - 2
src/lib/components/layout/Sidebar.svelte

@@ -17,7 +17,8 @@
 		getChatById,
 		getChatListByTagName,
 		updateChatById,
-		getAllChatTags
+		getAllChatTags,
+		archiveChatById
 	} from '$lib/apis/chats';
 	import { toast } from 'svelte-sonner';
 	import { fade, slide } from 'svelte/transition';
@@ -139,6 +140,11 @@
 		localStorage.setItem('settings', JSON.stringify($settings));
 		location.href = '/';
 	};
+
+	const archiveChatHandler = async (id) => {
+		await archiveChatById(localStorage.token, id);
+		await chats.set(await getChatList(localStorage.token));
+	};
 </script>
 
 <ShareChatModal bind:show={showShareChatModal} chatId={shareChatId} />
@@ -594,7 +600,7 @@
 											aria-label="Archive"
 											class=" self-center dark:hover:text-white transition"
 											on:click={() => {
-												selectedChatId = chat.id;
+												archiveChatHandler(chat.id);
 											}}
 										>
 											<ArchiveBox />