Quellcode durchsuchen

feat: convo tagging full integration

Timothy J. Baek vor 1 Jahr
Ursprung
Commit
987685dbf9

+ 35 - 20
backend/apps/web/models/tags.py

@@ -15,7 +15,8 @@ from apps.web.internal.db import DB
 
 
 class Tag(Model):
-    name = CharField(unique=True)
+    id = CharField(unique=True)
+    name = CharField()
     user_id = CharField()
     data = TextField(null=True)
 
@@ -24,7 +25,8 @@ class Tag(Model):
 
 
 class ChatIdTag(Model):
-    tag_name = ForeignKeyField(Tag, backref="chat_id_tags")
+    id = CharField(unique=True)
+    tag_name = CharField()
     chat_id = CharField()
     user_id = CharField()
     timestamp = DateField()
@@ -34,12 +36,14 @@ class ChatIdTag(Model):
 
 
 class TagModel(BaseModel):
+    id: str
     name: str
     user_id: str
     data: Optional[str] = None
 
 
 class ChatIdTagModel(BaseModel):
+    id: str
     tag_name: str
     chat_id: str
     user_id: str
@@ -70,14 +74,15 @@ class TagTable:
         db.create_tables([Tag, ChatIdTag])
 
     def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
-        tag = TagModel(**{"user_id": user_id, "name": name})
+        id = str(uuid.uuid4())
+        tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
             result = Tag.create(**tag.model_dump())
             if result:
                 return tag
             else:
                 return None
-        except:
+        except Exception as e:
             return None
 
     def get_tag_by_name_and_user_id(
@@ -86,17 +91,27 @@ class TagTable:
         try:
             tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
             return TagModel(**model_to_dict(tag))
-        except:
+        except Exception as e:
             return None
 
     def add_tag_to_chat(
         self, user_id: str, form_data: ChatIdTagForm
-    ) -> Optional[ChatTagsResponse]:
+    ) -> Optional[ChatIdTagModel]:
         tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
         if tag == None:
             tag = self.insert_new_tag(form_data.tag_name, user_id)
 
-        chatIdTag = ChatIdTagModel(**{"user_id": user_id, "tag_name": tag.name})
+        print(tag)
+        id = str(uuid.uuid4())
+        chatIdTag = ChatIdTagModel(
+            **{
+                "id": id,
+                "user_id": user_id,
+                "chat_id": form_data.chat_id,
+                "tag_name": tag.name,
+                "timestamp": int(time.time()),
+            }
+        )
         try:
             result = ChatIdTag.create(**chatIdTag.model_dump())
             if result:
@@ -109,19 +124,17 @@ class TagTable:
     def get_tags_by_chat_id_and_user_id(
         self, chat_id: str, user_id: str
     ) -> List[TagModel]:
+        tag_names = [
+            ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
+            for chat_id_tag in ChatIdTag.select()
+            .where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
+            .order_by(ChatIdTag.timestamp.desc())
+        ]
+
+        print(tag_names)
         return [
             TagModel(**model_to_dict(tag))
-            for tag in Tag.select().where(
-                Tag.name
-                in [
-                    ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
-                    for chat_id_tag in ChatIdTag.select()
-                    .where(
-                        (ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)
-                    )
-                    .order_by(ChatIdTag.timestamp.desc())
-                ]
-            )
+            for tag in Tag.select().where(Tag.name.in_(tag_names))
         ]
 
     def get_chat_ids_by_tag_name_and_user_id(
@@ -152,7 +165,8 @@ class TagTable:
                 & (ChatIdTag.chat_id == chat_id)
                 & (ChatIdTag.user_id == user_id)
             )
-            query.execute()  # Remove the rows, return number of rows removed.
+            res = query.execute()  # Remove the rows, return number of rows removed.
+            print(res)
 
             tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
             if tag_count == 0:
@@ -163,7 +177,8 @@ class TagTable:
                 query.execute()  # Remove the rows, return number of rows removed.
 
             return True
-        except:
+        except Exception as e:
+            print("delete_tag", e)
             return False
 
     def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

+ 16 - 6
backend/apps/web/routers/chats.py

@@ -19,6 +19,7 @@ from apps.web.models.chats import (
 
 from apps.web.models.tags import (
     TagModel,
+    ChatIdTagModel,
     ChatIdTagForm,
     ChatTagsResponse,
     Tags,
@@ -132,7 +133,8 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
 async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
-    if tags:
+    if tags != None:
+        print(tags)
         return tags
     else:
         raise HTTPException(
@@ -145,17 +147,25 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
 ############################
 
 
-@router.post("/{id}/tags", response_model=Optional[ChatTagsResponse])
+@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 async def add_chat_tag_by_id(
     id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
 ):
-    tag = Tags.add_tag_to_chat(user.id, {"tag_name": form_data.tag_name, "chat_id": id})
+    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
+
+    if form_data.tag_name not in tags:
+        tag = Tags.add_tag_to_chat(user.id, form_data)
 
-    if tag:
-        return tag
+        if tag:
+            return tag
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
+            )
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
         )
 
 

+ 77 - 89
src/lib/components/layout/Navbar.svelte

@@ -12,22 +12,11 @@
 	export let title: string = 'Ollama Web UI';
 	export let shareEnabled: boolean = false;
 
-	let showShareChatModal = false;
+	export let tags = [];
+	export let addTag: Function;
+	export let deleteTag: Function;
 
-	let tags = [
-		// {
-		// 	name: 'general'
-		// },
-		// {
-		// 	name: 'medicine'
-		// },
-		// {
-		// 	name: 'cooking'
-		// },
-		// {
-		// 	name: 'education'
-		// }
-	];
+	let showShareChatModal = false;
 
 	let tagName = '';
 	let showTagInput = false;
@@ -74,16 +63,17 @@
 		saveAs(blob, `chat-${chat.title}.txt`);
 	};
 
-	const addTag = () => {
-		if (!tags.find((e) => e.name === tagName)) {
-			tags = [
-				...tags,
-				{
-					name: JSON.parse(JSON.stringify(tagName))
-				}
-			];
-		}
+	const addTagHandler = () => {
+		// if (!tags.find((e) => e.name === tagName)) {
+		// 	tags = [
+		// 		...tags,
+		// 		{
+		// 			name: JSON.parse(JSON.stringify(tagName))
+		// 		}
+		// 	];
+		// }
 
+		addTag(tagName);
 		tagName = '';
 		showTagInput = false;
 	};
@@ -126,48 +116,19 @@
 			</div>
 
 			<div class="pl-2 self-center flex items-center space-x-2">
-				<div class="flex flex-row space-x-0.5 line-clamp-1">
-					{#each tags as tag}
-						<div
-							class="px-2 py-0.5 space-x-1 flex h-fit items-center rounded-full transition border dark:border-gray-600 dark:text-white"
-						>
-							<div class=" text-[0.65rem] font-medium self-center line-clamp-1">
-								{tag.name}
-							</div>
-							<button
-								class=" m-auto self-center cursor-pointer"
-								on:click={() => {
-									console.log(tag.name);
-
-									tags = tags.filter((t) => t.name !== tag.name);
-								}}
+				{#if shareEnabled}
+					<div class="flex flex-row space-x-0.5 line-clamp-1">
+						{#each tags as tag}
+							<div
+								class="px-2 py-0.5 space-x-1 flex h-fit items-center rounded-full transition border dark:border-gray-600 dark:text-white"
 							>
-								<svg
-									xmlns="http://www.w3.org/2000/svg"
-									viewBox="0 0 16 16"
-									fill="currentColor"
-									class="w-3 h-3"
-								>
-									<path
-										d="M5.28 4.22a.75.75 0 0 0-1.06 1.06L6.94 8l-2.72 2.72a.75.75 0 1 0 1.06 1.06L8 9.06l2.72 2.72a.75.75 0 1 0 1.06-1.06L9.06 8l2.72-2.72a.75.75 0 0 0-1.06-1.06L8 6.94 5.28 4.22Z"
-									/>
-								</svg>
-							</button>
-						</div>
-					{/each}
-
-					<div class="flex space-x-1 pl-1.5">
-						{#if showTagInput}
-							<div class="flex items-center">
-								<input
-									bind:value={tagName}
-									class=" cursor-pointer self-center text-xs h-fit bg-transparent outline-none line-clamp-1 w-[4rem]"
-									placeholder="Add a tag"
-								/>
-
+								<div class=" text-[0.65rem] font-medium self-center line-clamp-1">
+									{tag.name}
+								</div>
 								<button
+									class=" m-auto self-center cursor-pointer"
 									on:click={() => {
-										addTag();
+										deleteTag(tag.name);
 									}}
 								>
 									<svg
@@ -177,40 +138,67 @@
 										class="w-3 h-3"
 									>
 										<path
-											fill-rule="evenodd"
-											d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
-											clip-rule="evenodd"
+											d="M5.28 4.22a.75.75 0 0 0-1.06 1.06L6.94 8l-2.72 2.72a.75.75 0 1 0 1.06 1.06L8 9.06l2.72 2.72a.75.75 0 1 0 1.06-1.06L9.06 8l2.72-2.72a.75.75 0 0 0-1.06-1.06L8 6.94 5.28 4.22Z"
 										/>
 									</svg>
 								</button>
 							</div>
+						{/each}
+
+						<div class="flex space-x-1 pl-1.5">
+							{#if showTagInput}
+								<div class="flex items-center">
+									<input
+										bind:value={tagName}
+										class=" cursor-pointer self-center text-xs h-fit bg-transparent outline-none line-clamp-1 w-[4rem]"
+										placeholder="Add a tag"
+									/>
 
-							<!-- TODO: Tag Suggestions -->
-						{/if}
+									<button
+										on:click={() => {
+											addTagHandler();
+										}}
+									>
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 16 16"
+											fill="currentColor"
+											class="w-3 h-3"
+										>
+											<path
+												fill-rule="evenodd"
+												d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
+												clip-rule="evenodd"
+											/>
+										</svg>
+									</button>
+								</div>
+
+								<!-- TODO: Tag Suggestions -->
+							{/if}
 
-						<button
-							class=" cursor-pointer self-center p-0.5 space-x-1 flex h-fit items-center dark:hover:bg-gray-700 rounded-full transition border dark:border-gray-600 border-dashed"
-							on:click={() => {
-								showTagInput = !showTagInput;
-							}}
-						>
-							<div class=" m-auto self-center">
-								<svg
-									xmlns="http://www.w3.org/2000/svg"
-									viewBox="0 0 16 16"
-									fill="currentColor"
-									class="w-3 h-3 {showTagInput ? 'rotate-45' : ''} transition-all transform"
-								>
-									<path
-										d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
-									/>
-								</svg>
-							</div>
-						</button>
+							<button
+								class=" cursor-pointer self-center p-0.5 space-x-1 flex h-fit items-center dark:hover:bg-gray-700 rounded-full transition border dark:border-gray-600 border-dashed"
+								on:click={() => {
+									showTagInput = !showTagInput;
+								}}
+							>
+								<div class=" m-auto self-center">
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 16 16"
+										fill="currentColor"
+										class="w-3 h-3 {showTagInput ? 'rotate-45' : ''} transition-all transform"
+									>
+										<path
+											d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+										/>
+									</svg>
+								</div>
+							</button>
+						</div>
 					</div>
-				</div>
 
-				{#if shareEnabled}
 					<button
 						class=" cursor-pointer p-1.5 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"
 						on:click={async () => {

+ 26 - 2
src/routes/(app)/+page.svelte

@@ -10,7 +10,14 @@
 	import { copyToClipboard, splitStream } from '$lib/utils';
 
 	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
-	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
+	import {
+		addTagById,
+		createNewChat,
+		deleteTagById,
+		getChatList,
+		getTagsById,
+		updateChatById
+	} from '$lib/apis/chats';
 	import { queryVectorDB } from '$lib/apis/rag';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
@@ -47,6 +54,7 @@
 	}, {});
 
 	let chat = null;
+	let tags = [];
 
 	let title = '';
 	let prompt = '';
@@ -673,6 +681,22 @@
 		}
 	};
 
+	const getTags = async () => {
+		return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
+			return [];
+		});
+	};
+
+	const addTag = async (tagName) => {
+		const res = await addTagById(localStorage.token, $chatId, tagName);
+		tags = await getTags();
+	};
+
+	const deleteTag = async (tagName) => {
+		const res = await deleteTagById(localStorage.token, $chatId, tagName);
+		tags = await getTags();
+	};
+
 	const setChatTitle = async (_chatId, _title) => {
 		if (_chatId === $chatId) {
 			title = _title;
@@ -691,7 +715,7 @@
 	}}
 />
 
-<Navbar {title} shareEnabled={messages.length > 0} {initNewChat} />
+<Navbar {title} shareEnabled={messages.length > 0} {initNewChat} {tags} {addTag} {deleteTag} />
 <div class="min-h-screen w-full flex justify-center">
 	<div class=" py-2.5 flex flex-col justify-between w-full">
 		<div class="max-w-2xl mx-auto w-full px-3 md:px-0 mt-10">

+ 30 - 1
src/routes/(app)/c/[id]/+page.svelte

@@ -10,7 +10,15 @@
 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
 
 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
-	import { createNewChat, getChatById, getChatList, updateChatById } from '$lib/apis/chats';
+	import {
+		addTagById,
+		createNewChat,
+		deleteTagById,
+		getChatById,
+		getChatList,
+		getTagsById,
+		updateChatById
+	} from '$lib/apis/chats';
 	import { queryVectorDB } from '$lib/apis/rag';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
@@ -49,6 +57,7 @@
 	}, {});
 
 	let chat = null;
+	let tags = [];
 
 	let title = '';
 	let prompt = '';
@@ -97,6 +106,7 @@
 		});
 
 		if (chat) {
+			tags = await getTags();
 			const chatContent = chat.chat;
 
 			if (chatContent) {
@@ -688,6 +698,22 @@
 		await chats.set(await getChatList(localStorage.token));
 	};
 
+	const getTags = async () => {
+		return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
+			return [];
+		});
+	};
+
+	const addTag = async (tagName) => {
+		const res = await addTagById(localStorage.token, $chatId, tagName);
+		tags = await getTags();
+	};
+
+	const deleteTag = async (tagName) => {
+		const res = await deleteTagById(localStorage.token, $chatId, tagName);
+		tags = await getTags();
+	};
+
 	onMount(async () => {
 		if (!($settings.saveChatHistory ?? true)) {
 			await goto('/');
@@ -713,6 +739,9 @@
 
 			goto('/');
 		}}
+		{tags}
+		{addTag}
+		{deleteTag}
 	/>
 	<div class="min-h-screen w-full flex justify-center">
 		<div class=" py-2.5 flex flex-col justify-between w-full">