Bläddra i källkod

feat: full integration

Timothy J. Baek 1 år sedan
förälder
incheckning
9634e2da3e

+ 49 - 11
backend/apps/rag/main.py

@@ -9,6 +9,7 @@ from fastapi import (
     Form,
 )
 from fastapi.middleware.cors import CORSMiddleware
+import os, shutil
 
 from chromadb.utils import embedding_functions
 
@@ -23,7 +24,7 @@ from typing import Optional
 
 import uuid
 
-from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
 from constants import ERROR_MESSAGES
 
 EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
@@ -51,7 +52,7 @@ class StoreWebForm(CollectionNameForm):
     url: str
 
 
-def store_data_in_vector_db(data, collection_name):
+def store_data_in_vector_db(data, collection_name) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
     )
@@ -60,13 +61,22 @@ def store_data_in_vector_db(data, collection_name):
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
 
-    collection = CHROMA_CLIENT.create_collection(
-        name=collection_name, embedding_function=EMBEDDING_FUNC
-    )
+    try:
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name, embedding_function=EMBEDDING_FUNC
+        )
 
-    collection.add(
-        documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-    )
+        collection.add(
+            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+        )
+        return True
+    except Exception as e:
+        print(e)
+        print(e.__class__.__name__)
+        if e.__class__.__name__ == "UniqueConstraintError":
+            return True
+
+        return False
 
 
 @app.get("/")
@@ -116,7 +126,7 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
 
     try:
         filename = file.filename
-        file_path = f"./data/{filename}"
+        file_path = f"{UPLOAD_DIR}/{filename}"
         contents = file.file.read()
         with open(file_path, "wb") as f:
             f.write(contents)
@@ -128,8 +138,15 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
             loader = TextLoader(file_path)
 
         data = loader.load()
-        store_data_in_vector_db(data, collection_name)
-        return {"status": True, "collection_name": collection_name}
+        result = store_data_in_vector_db(data, collection_name)
+
+        if result:
+            return {"status": True, "collection_name": collection_name}
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=ERROR_MESSAGES.DEFAULT(),
+            )
     except Exception as e:
         print(e)
         raise HTTPException(
@@ -138,6 +155,27 @@ def store_doc(collection_name: str = Form(...), file: UploadFile = File(...)):
         )
 
 
+@app.get("/reset/db")
 def reset_vector_db():
     CHROMA_CLIENT.reset()
+
+
+@app.get("/reset")
+def reset():
+    folder = f"{UPLOAD_DIR}"
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
+        try:
+            if os.path.isfile(file_path) or os.path.islink(file_path):
+                os.unlink(file_path)
+            elif os.path.isdir(file_path):
+                shutil.rmtree(file_path)
+        except Exception as e:
+            print("Failed to delete %s. Reason: %s" % (file_path, e))
+
+    try:
+        CHROMA_CLIENT.reset()
+    except Exception as e:
+        print(e)
+
     return {"status": True}

+ 20 - 1
backend/config.py

@@ -1,14 +1,31 @@
 from dotenv import load_dotenv, find_dotenv
 import os
+
+
 import chromadb
+from chromadb import Settings
+
 
 from secrets import token_bytes
 from base64 import b64encode
 
 from constants import ERROR_MESSAGES
 
+
+from pathlib import Path
+
 load_dotenv(find_dotenv("../.env"))
 
+
+####################################
+# File Upload
+####################################
+
+
+UPLOAD_DIR = "./data/uploads"
+Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # ENV (dev,test,prod)
 ####################################
@@ -64,6 +81,8 @@ if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
 
 CHROMA_DATA_PATH = "./data/vector_db"
 EMBED_MODEL = "all-MiniLM-L6-v2"
-CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
+CHROMA_CLIENT = chromadb.PersistentClient(
+    path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True)
+)
 CHUNK_SIZE = 1500
 CHUNK_OVERLAP = 100

+ 9 - 9
src/lib/components/chat/MessageInput.svelte

@@ -124,16 +124,16 @@
 						reader.readAsDataURL(file);
 					} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
 						console.log(file);
-						const hash = await calculateSHA256(file);
-						// const res = uploadDocToVectorDB(localStorage.token, hash,file);
+						const hash = (await calculateSHA256(file)).substring(0, 63);
+						const res = await uploadDocToVectorDB(localStorage.token, hash, file);
 
-						if (true) {
+						if (res) {
 							files = [
 								...files,
 								{
 									type: 'doc',
 									name: file.name,
-									collection_name: hash
+									collection_name: res.collection_name
 								}
 							];
 						}
@@ -243,16 +243,16 @@
 								reader.readAsDataURL(file);
 							} else if (['application/pdf', 'text/plain'].includes(file['type'])) {
 								console.log(file);
-								const hash = await calculateSHA256(file);
-								// const res = uploadDocToVectorDB(localStorage.token,hash,file);
+								const hash = (await calculateSHA256(file)).substring(0, 63);
+								const res = await uploadDocToVectorDB(localStorage.token, hash, file);
 
-								if (true) {
+								if (res) {
 									files = [
 										...files,
 										{
 											type: 'doc',
 											name: file.name,
-											collection_name: hash
+											collection_name: res.collection_name
 										}
 									];
 									filesInputElement.value = '';
@@ -280,7 +280,7 @@
 										<img src={file.url} alt="input" class=" h-16 w-16 rounded-xl object-cover" />
 									{:else if file.type === 'doc'}
 										<div
-											class="h-16 w-[15rem] flex items-center space-x-3 px-2 bg-gray-600 rounded-xl"
+											class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 bg-gray-600 rounded-xl"
 										>
 											<div class="p-2.5 bg-red-400 rounded-lg">
 												<svg

+ 31 - 1
src/lib/components/chat/Messages/UserMessage.svelte

@@ -53,11 +53,41 @@
 			class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:my-0 prose-p:-mb-4 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-6 prose-li:-mb-4 whitespace-pre-line"
 		>
 			{#if message.files}
-				<div class="my-3 w-full flex overflow-x-auto space-x-2">
+				<div class="my-2.5 w-full flex overflow-x-auto space-x-2 flex-wrap">
 					{#each message.files as file}
 						<div>
 							{#if file.type === 'image'}
 								<img src={file.url} alt="input" class=" max-h-96 rounded-lg" draggable="false" />
+							{:else if file.type === 'doc'}
+								<div
+									class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 bg-gray-600 rounded-xl"
+								>
+									<div class="p-2.5 bg-red-400 rounded-lg">
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 24 24"
+											fill="currentColor"
+											class="w-6 h-6"
+										>
+											<path
+												fill-rule="evenodd"
+												d="M5.625 1.5c-1.036 0-1.875.84-1.875 1.875v17.25c0 1.035.84 1.875 1.875 1.875h12.75c1.035 0 1.875-.84 1.875-1.875V12.75A3.75 3.75 0 0 0 16.5 9h-1.875a1.875 1.875 0 0 1-1.875-1.875V5.25A3.75 3.75 0 0 0 9 1.5H5.625ZM7.5 15a.75.75 0 0 1 .75-.75h7.5a.75.75 0 0 1 0 1.5h-7.5A.75.75 0 0 1 7.5 15Zm.75 2.25a.75.75 0 0 0 0 1.5H12a.75.75 0 0 0 0-1.5H8.25Z"
+												clip-rule="evenodd"
+											/>
+											<path
+												d="M12.971 1.816A5.23 5.23 0 0 1 14.25 5.25v1.875c0 .207.168.375.375.375H16.5a5.23 5.23 0 0 1 3.434 1.279 9.768 9.768 0 0 0-6.963-6.963Z"
+											/>
+										</svg>
+									</div>
+
+									<div class="flex flex-col justify-center -space-y-0.5">
+										<div class=" text-gray-100 text-sm line-clamp-1">
+											{file.name}
+										</div>
+
+										<div class=" text-gray-500 text-sm">Document</div>
+									</div>
+								</div>
 							{/if}
 						</div>
 					{/each}

+ 1 - 2
src/lib/utils/index.ts

@@ -129,7 +129,6 @@ export const findWordIndices = (text) => {
 };
 
 export const calculateSHA256 = async (file) => {
-	console.log(file);
 	// Create a FileReader to read the file asynchronously
 	const reader = new FileReader();
 
@@ -156,7 +155,7 @@ export const calculateSHA256 = async (file) => {
 		const hashArray = Array.from(new Uint8Array(hashBuffer));
 		const hashHex = hashArray.map((byte) => byte.toString(16).padStart(2, '0')).join('');
 
-		return `sha256:${hashHex}`;
+		return `${hashHex}`;
 	} catch (error) {
 		console.error('Error calculating SHA-256 hash:', error);
 		throw error;

+ 6 - 1
src/routes/(app)/+page.svelte

@@ -186,8 +186,11 @@
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
 		// TODO: update below to include all ancestral files
-		const docs = history.messages[parentId].files.filter((item) => item.type === 'file');
 
+		console.log(history.messages[parentId]);
+		const docs = history.messages[parentId]?.files?.filter((item) => item.type === 'doc') ?? [];
+
+		console.log(docs);
 		if (docs.length > 0) {
 			const query = history.messages[parentId].content;
 
@@ -207,6 +210,8 @@
 				return `${a}${context.documents.join(' ')}\n`;
 			}, '');
 
+			console.log(contextString);
+
 			history.messages[parentId].raContent = RAGTemplate(contextString, query);
 			history.messages[parentId].contexts = relevantContexts;
 			await tick();