ソースを参照

feat: doc tagging

Timothy J. Baek 1 年間 前
コミット
00803c92f2

+ 46 - 1
backend/apps/rag/main.py

@@ -128,6 +128,51 @@ class QueryCollectionsForm(BaseModel):
     k: Optional[int] = 4
 
 
+def merge_and_sort_query_results(query_results, k):
+    """Merge several query result dicts and keep the k lowest-distance hits.
+
+    Each item of ``query_results`` is read as a dict whose "ids",
+    "distances", "metadatas" and "documents" values hold the hits of one
+    query at index 0. All hits are pooled, ordered by ascending distance
+    and truncated to ``k`` entries (``k=None`` keeps every hit).
+
+    Returns a dict of the same layout with one merged query at index 0;
+    "embeddings", "uris" and "data" are always None. An empty
+    ``query_results`` (or one without any hits) yields empty hit lists
+    instead of raising.
+    """
+    # Pool the hits of every result as (distance, id, metadata, document) rows.
+    rows = [
+        row
+        for data in query_results
+        for row in zip(
+            data["distances"][0],
+            data["ids"][0],
+            data["metadatas"][0],
+            data["documents"][0],
+        )
+    ]
+
+    # list.sort is stable, so hits with equal distance keep their input order.
+    rows.sort(key=lambda row: row[0])
+    top_rows = rows[:k]
+
+    # Build the columns from the rows directly: unpacking zip(*rows) into
+    # four names fails when there are no rows at all.
+    return {
+        "ids": [[row[1] for row in top_rows]],
+        "distances": [[row[0] for row in top_rows]],
+        "metadatas": [[row[2] for row in top_rows]],
+        "documents": [[row[3] for row in top_rows]],
+        "embeddings": None,
+        "uris": None,
+        "data": None,
+    }
+
+
+
+
 @app.post("/query/collections")
 def query_collections(
     form_data: QueryCollectionsForm,
@@ -147,7 +192,7 @@ def query_collections(
         except:
             pass
 
-    return results
+    return merge_and_sort_query_results(results, form_data.k)
 
 
 @app.post("/web")

+ 30 - 0
backend/apps/web/models/documents.py

@@ -44,6 +44,16 @@ class DocumentModel(BaseModel):
 ####################
 
 
+class DocumentResponse(BaseModel):
+    """Document shape returned by the document API routes.
+
+    ``content`` is exposed as a parsed dict rather than the stored JSON
+    string; the routers fill it with ``json.loads`` of the stored content.
+    """
+
+    collection_name: str
+    name: str
+    title: str
+    filename: str
+    content: Optional[dict] = None  # parsed JSON content, e.g. {"tags": [...]}
+    user_id: str
+    timestamp: int  # timestamp in epoch
+
+
 class DocumentUpdateForm(BaseModel):
     name: str
     title: str
@@ -111,6 +121,26 @@ class DocumentsTable:
             print(e)
             return None
 
+    def update_doc_content_by_name(
+        self, name: str, updated: dict
+    ) -> Optional[DocumentModel]:
+        """Shallow-merge ``updated`` into the JSON content of document ``name``.
+
+        The stored content string is decoded (empty content counts as ``{}``),
+        keys from ``updated`` overwrite existing keys, and the result is
+        written back together with a fresh epoch timestamp.
+
+        Returns the re-read document, or None when the document does not
+        exist or any step fails (the error is printed).
+        """
+        try:
+            doc = self.get_doc_by_name(name)
+            if not doc:
+                # Nothing to update: report "not found" explicitly instead of
+                # relying on an AttributeError being swallowed below.
+                return None
+
+            doc_content = json.loads(doc.content if doc.content else "{}")
+            doc_content = {**doc_content, **updated}
+
+            query = Document.update(
+                content=json.dumps(doc_content),
+                timestamp=int(time.time()),
+            ).where(Document.name == name)
+            query.execute()
+
+            # Re-read the row so the caller gets what was actually stored.
+            doc = Document.get(Document.name == name)
+            return DocumentModel(**model_to_dict(doc))
+        except Exception as e:
+            print(e)
+            return None
+
+
     def delete_doc_by_name(self, name: str) -> bool:
         try:
             query = Document.delete().where((Document.name == name))

+ 61 - 8
backend/apps/web/routers/documents.py

@@ -11,6 +11,7 @@ from apps.web.models.documents import (
     DocumentForm,
     DocumentUpdateForm,
     DocumentModel,
+    DocumentResponse,
 )
 
 from utils.utils import get_current_user
@@ -23,9 +24,18 @@ router = APIRouter()
 ############################
 
 
-@router.get("/", response_model=List[DocumentModel])
+@router.get("/", response_model=List[DocumentResponse])
 async def get_documents(user=Depends(get_current_user)):
-    return Documents.get_docs()
+    """Return every document with its stored JSON content decoded to a dict.
+
+    Empty or missing content is returned as ``{}``. The ``get_current_user``
+    dependency is resolved for each request; the user is not used otherwise.
+    """
+    docs = [
+        DocumentResponse(
+            **{
+                **doc.model_dump(),
+                "content": json.loads(doc.content if doc.content else "{}"),
+            }
+        )
+        for doc in Documents.get_docs()
+    ]
+    return docs
 
 
 ############################
@@ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)):
 ############################
 
 
-@router.post("/create", response_model=Optional[DocumentModel])
+@router.post("/create", response_model=Optional[DocumentResponse])
 async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
     if user.role != "admin":
         raise HTTPException(
@@ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)
         doc = Documents.insert_new_doc(user.id, form_data)
 
         if doc:
-            return doc
+            return DocumentResponse(
+                **{
+                    **doc.model_dump(),
+                    "content": json.loads(doc.content if doc.content else "{}"),
+                }
+            )
         else:
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)
 ############################
 
 
-@router.get("/name/{name}", response_model=Optional[DocumentModel])
+@router.get("/name/{name}", response_model=Optional[DocumentResponse])
 async def get_doc_by_name(name: str, user=Depends(get_current_user)):
+    """Return the document called ``name`` with its content decoded to a dict.
+
+    Raises an HTTPException when the lookup yields no document.
+    """
     doc = Documents.get_doc_by_name(name)
 
     if doc:
-        return doc
+        return DocumentResponse(
+            **{
+                **doc.model_dump(),
+                "content": json.loads(doc.content if doc.content else "{}"),
+            }
+        )
+    else:
+        # NOTE(review): a missing document is reported with HTTP 401 although
+        # the detail is NOT_FOUND; 404 looks intended - confirm with clients.
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# TagDocByName
+############################
+
+
+class TagDocumentForm(BaseModel):
+    """Request body of the document tagging endpoint."""
+
+    # Document name; the tagging handler looks the document up by this field.
+    name: str
+    # Complete tag list that replaces the "tags" key of the document content.
+    tags: List[dict]
+
+
+@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
+async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
+    doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
+
+    if doc:
+        return DocumentResponse(
+            **{
+                **doc.model_dump(),
+                "content": json.loads(doc.content if doc.content else "{}"),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
 ############################
 
 
-@router.post("/name/{name}/update", response_model=Optional[DocumentModel])
+@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
 async def update_doc_by_name(
     name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
 ):
@@ -94,7 +142,12 @@ async def update_doc_by_name(
 
     doc = Documents.update_doc_by_name(name, form_data)
     if doc:
-        return doc
+        return DocumentResponse(
+            **{
+                **doc.model_dump(),
+                "content": json.loads(doc.content if doc.content else "{}"),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,

+ 41 - 0
src/lib/apis/documents/index.ts

@@ -144,6 +144,47 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda
 	return res;
 };
 
+// Payload of the tagging endpoint: the document name plus its full tag list.
+type TagDocForm = {
+	name: string;
+	// Tags are sent as objects, matching the backend's list-of-dict field.
+	tags: { name: string }[];
+};
+
+/**
+ * Replace the tag list of the document `name`.
+ *
+ * Resolves with the parsed JSON response. When the server answers with an
+ * error payload, its `detail` is thrown; a failure without a `detail` is
+ * logged and resolves to null.
+ */
+export const tagDocByName = async (token: string, name: string, form: TagDocForm) => {
+	let error = null;
+
+	// Encode the name so characters such as spaces, "?" or "#" cannot alter the URL.
+	const url = `${WEBUI_API_BASE_URL}/documents/name/${encodeURIComponent(name)}/tags`;
+
+	const res = await fetch(url, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			name: form.name,
+			tags: form.tags
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+
 export const deleteDocByName = async (token: string, name: string) => {
 	let error = null;
 

+ 24 - 0
src/lib/components/common/Tags.svelte

@@ -0,0 +1,24 @@
+<script lang="ts">
+	import TagInput from './Tags/TagInput.svelte';
+	import TagList from './Tags/TagList.svelte';
+
+	// Tags to render; TagList reads each item's `name`.
+	export let tags = [];
+
+	// Callbacks supplied by the parent; each receives the tag name carried by
+	// the child component's event (`e.detail`).
+	export let deleteTag: Function;
+	export let addTag: Function;
+</script>
+
+<!-- Tag row: the existing tags followed by the inline "add a tag" control. -->
+<div class="flex flex-row space-x-0.5 line-clamp-1">
+	<TagList
+		{tags}
+		on:delete={(e) => {
+			deleteTag(e.detail);
+		}}
+	/>
+
+	<TagInput
+		on:add={(e) => {
+			addTag(e.detail);
+		}}
+	/>
+</div>

+ 64 - 0
src/lib/components/common/Tags/TagInput.svelte

@@ -0,0 +1,64 @@
+<script lang="ts">
+	import { createEventDispatcher } from 'svelte';
+	const dispatch = createEventDispatcher();
+
+	// Whether the inline text field is visible.
+	let showTagInput = false;
+	// Current text of the inline field.
+	let tagName = '';
+
+	// Emit `add` with the trimmed name, then reset and hide the field.
+	// Blank input is ignored so no empty-named tag can be created.
+	const submitTag = () => {
+		const name = tagName.trim();
+		if (name === '') {
+			return;
+		}
+
+		dispatch('add', name);
+		tagName = '';
+		showTagInput = false;
+	};
+</script>
+
+<div class="flex space-x-1 pl-1.5">
+	{#if showTagInput}
+		<div class="flex items-center">
+			<input
+				bind:value={tagName}
+				class=" cursor-pointer self-center text-xs h-fit bg-transparent outline-none line-clamp-1 w-[4rem]"
+				placeholder="Add a tag"
+			/>
+
+			<!-- Confirm button: submits the typed tag name. -->
+			<button type="button" on:click={submitTag}>
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 16 16"
+					fill="currentColor"
+					class="w-3 h-3"
+				>
+					<path
+						fill-rule="evenodd"
+						d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
+						clip-rule="evenodd"
+					/>
+				</svg>
+			</button>
+		</div>
+
+		<!-- TODO: Tag Suggestions -->
+	{/if}
+
+	<!-- Toggle button: shows or hides the text field (the plus icon rotates into a cross). -->
+	<button
+		class=" cursor-pointer self-center p-0.5 space-x-1 flex h-fit items-center dark:hover:bg-gray-700 rounded-full transition border dark:border-gray-600 border-dashed"
+		type="button"
+		on:click={() => {
+			showTagInput = !showTagInput;
+		}}
+	>
+		<div class=" m-auto self-center">
+			<svg
+				xmlns="http://www.w3.org/2000/svg"
+				viewBox="0 0 16 16"
+				fill="currentColor"
+				class="w-3 h-3 {showTagInput ? 'rotate-45' : ''} transition-all transform"
+			>
+				<path
+					d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+				/>
+			</svg>
+		</div>
+	</button>
+</div>

+ 33 - 0
src/lib/components/common/Tags/TagList.svelte

@@ -0,0 +1,33 @@
+<script lang="ts">
+	import { createEventDispatcher } from 'svelte';
+	const dispatch = createEventDispatcher();
+
+	// Tags to render; each item's `name` is shown and reported on delete.
+	export let tags = [];
+</script>
+
+{#each tags as tag}
+	<div
+		class="px-2 py-0.5 space-x-1 flex h-fit items-center rounded-full transition border dark:border-gray-600 dark:text-white"
+	>
+		<div class=" text-[0.65rem] font-medium self-center line-clamp-1">
+			{tag.name}
+		</div>
+		<!-- type="button" keeps this from acting as a submit button inside a form. -->
+		<button
+			class=" m-auto self-center cursor-pointer"
+			type="button"
+			on:click={() => {
+				dispatch('delete', tag.name);
+			}}
+		>
+			<svg
+				xmlns="http://www.w3.org/2000/svg"
+				viewBox="0 0 16 16"
+				fill="currentColor"
+				class="w-3 h-3"
+			>
+				<path
+					d="M5.28 4.22a.75.75 0 0 0-1.06 1.06L6.94 8l-2.72 2.72a.75.75 0 1 0 1.06 1.06L8 9.06l2.72 2.72a.75.75 0 1 0 1.06-1.06L9.06 8l2.72-2.72a.75.75 0 0 0-1.06-1.06L8 6.94 5.28 4.22Z"
+				/>
+			</svg>
+		</button>
+	</div>
+{/each}

+ 42 - 2
src/lib/components/documents/EditDocModal.svelte

@@ -3,16 +3,22 @@
 	import dayjs from 'dayjs';
 	import { onMount } from 'svelte';
 
-	import { getDocs, updateDocByName } from '$lib/apis/documents';
+	import { getDocs, tagDocByName, updateDocByName } from '$lib/apis/documents';
 	import Modal from '../common/Modal.svelte';
 	import { documents } from '$lib/stores';
+	import TagInput from '../common/Tags/TagInput.svelte';
+	import Tags from '../common/Tags.svelte';
+	import { addTagById } from '$lib/apis/chats';
 
 	export let show = false;
 	export let selectedDoc;
 
+	let tags = [];
+
 	let doc = {
 		name: '',
-		title: ''
+		title: '',
+		content: null
 	};
 
 	const submitHandler = async () => {
@@ -30,9 +36,37 @@
 		}
 	};
 
+	// Append a tag to the document, persist the whole list, then refresh the
+	// shared documents store. Names already present are only logged.
+	const addTagHandler = async (tagName) => {
+		const alreadyTagged = tags.some((tag) => tag.name === tagName);
+		if (alreadyTagged) {
+			console.log('tag already exists');
+			return;
+		}
+
+		tags = [...tags, { name: tagName }];
+
+		const payload = { name: doc.name, tags: tags };
+		await tagDocByName(localStorage.token, doc.name, payload);
+
+		const refreshedDocs = await getDocs(localStorage.token);
+		documents.set(refreshedDocs);
+	};
+
+	// Drop every tag called `tagName`, persist the remaining list, then
+	// refresh the shared documents store.
+	const deleteTagHandler = async (tagName) => {
+		const remaining = tags.filter((tag) => tag.name !== tagName);
+		tags = remaining;
+
+		const payload = { name: doc.name, tags: remaining };
+		await tagDocByName(localStorage.token, doc.name, payload);
+
+		const refreshedDocs = await getDocs(localStorage.token);
+		documents.set(refreshedDocs);
+	};
+
 	onMount(() => {
 		if (selectedDoc) {
 			doc = JSON.parse(JSON.stringify(selectedDoc));
+
+			tags = doc?.content?.tags ?? [];
 		}
 	});
 </script>
@@ -112,6 +146,12 @@
 								/>
 							</div>
 						</div>
+
+						<div class="flex flex-col w-full">
+							<div class=" mb-1.5 text-xs text-gray-500">Tags</div>
+
+							<Tags {tags} addTag={addTagHandler} deleteTag={deleteTagHandler} />
+						</div>
 					</div>
 
 					<div class="flex justify-end pt-5 text-sm font-medium">

+ 3 - 96
src/lib/components/layout/Navbar.svelte

@@ -6,6 +6,8 @@
 	import { getChatById } from '$lib/apis/chats';
 	import { chatId, modelfiles } from '$lib/stores';
 	import ShareChatModal from '../chat/ShareChatModal.svelte';
+	import TagInput from '../common/Tags/TagInput.svelte';
+	import Tags from '../common/Tags.svelte';
 
 	export let initNewChat: Function;
 	export let title: string = 'Ollama Web UI';
@@ -61,21 +63,6 @@
 
 		saveAs(blob, `chat-${chat.title}.txt`);
 	};
-
-	const addTagHandler = () => {
-		// if (!tags.find((e) => e.name === tagName)) {
-		// 	tags = [
-		// 		...tags,
-		// 		{
-		// 			name: JSON.parse(JSON.stringify(tagName))
-		// 		}
-		// 	];
-		// }
-
-		addTag(tagName);
-		tagName = '';
-		showTagInput = false;
-	};
 </script>
 
 <ShareChatModal bind:show={showShareChatModal} {downloadChat} {shareChat} />
@@ -116,87 +103,7 @@
 
 			<div class="pl-2 self-center flex items-center space-x-2">
 				{#if shareEnabled}
-					<div class="flex flex-row space-x-0.5 line-clamp-1">
-						{#each tags as tag}
-							<div
-								class="px-2 py-0.5 space-x-1 flex h-fit items-center rounded-full transition border dark:border-gray-600 dark:text-white"
-							>
-								<div class=" text-[0.65rem] font-medium self-center line-clamp-1">
-									{tag.name}
-								</div>
-								<button
-									class=" m-auto self-center cursor-pointer"
-									on:click={() => {
-										deleteTag(tag.name);
-									}}
-								>
-									<svg
-										xmlns="http://www.w3.org/2000/svg"
-										viewBox="0 0 16 16"
-										fill="currentColor"
-										class="w-3 h-3"
-									>
-										<path
-											d="M5.28 4.22a.75.75 0 0 0-1.06 1.06L6.94 8l-2.72 2.72a.75.75 0 1 0 1.06 1.06L8 9.06l2.72 2.72a.75.75 0 1 0 1.06-1.06L9.06 8l2.72-2.72a.75.75 0 0 0-1.06-1.06L8 6.94 5.28 4.22Z"
-										/>
-									</svg>
-								</button>
-							</div>
-						{/each}
-
-						<div class="flex space-x-1 pl-1.5">
-							{#if showTagInput}
-								<div class="flex items-center">
-									<input
-										bind:value={tagName}
-										class=" cursor-pointer self-center text-xs h-fit bg-transparent outline-none line-clamp-1 w-[4rem]"
-										placeholder="Add a tag"
-									/>
-
-									<button
-										on:click={() => {
-											addTagHandler();
-										}}
-									>
-										<svg
-											xmlns="http://www.w3.org/2000/svg"
-											viewBox="0 0 16 16"
-											fill="currentColor"
-											class="w-3 h-3"
-										>
-											<path
-												fill-rule="evenodd"
-												d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
-												clip-rule="evenodd"
-											/>
-										</svg>
-									</button>
-								</div>
-
-								<!-- TODO: Tag Suggestions -->
-							{/if}
-
-							<button
-								class=" cursor-pointer self-center p-0.5 space-x-1 flex h-fit items-center dark:hover:bg-gray-700 rounded-full transition border dark:border-gray-600 border-dashed"
-								on:click={() => {
-									showTagInput = !showTagInput;
-								}}
-							>
-								<div class=" m-auto self-center">
-									<svg
-										xmlns="http://www.w3.org/2000/svg"
-										viewBox="0 0 16 16"
-										fill="currentColor"
-										class="w-3 h-3 {showTagInput ? 'rotate-45' : ''} transition-all transform"
-									>
-										<path
-											d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
-										/>
-									</svg>
-								</div>
-							</button>
-						</div>
-					</div>
+					<Tags {tags} {deleteTag} {addTag} />
 
 					<button
 						class=" cursor-pointer p-1.5 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"

+ 0 - 1
src/routes/(app)/documents/+page.svelte

@@ -13,7 +13,6 @@
 
 	import EditDocModal from '$lib/components/documents/EditDocModal.svelte';
 	import AddFilesPlaceholder from '$lib/components/AddFilesPlaceholder.svelte';
-
 	let importFiles = '';
 
 	let inputFiles = '';