Procházet zdrojové kódy

Merge pull request #778 from open-webui/rag-folder

feat: rag folder scan support
Timothy Jaeryang Baek před 1 rokem
rodič
revize
a32ab5afbd

+ 90 - 11
backend/apps/rag/main.py

@@ -10,6 +10,8 @@ from fastapi import (
 )
 from fastapi.middleware.cors import CORSMiddleware
 import os, shutil
+
+from pathlib import Path
 from typing import List
 
 # from chromadb.utils import embedding_functions
@@ -28,19 +30,39 @@ from langchain_community.document_loaders import (
     UnstructuredExcelLoader,
 )
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
 from langchain.chains import RetrievalQA
+from langchain_community.vectorstores import Chroma
 
 
 from pydantic import BaseModel
 from typing import Optional
-
+import mimetypes
 import uuid
+import json
 import time
 
-from utils.misc import calculate_sha256, calculate_sha256_string
+
+from apps.web.models.documents import (
+    Documents,
+    DocumentForm,
+    DocumentResponse,
+)
+
+from utils.misc import (
+    calculate_sha256,
+    calculate_sha256_string,
+    sanitize_filename,
+    extract_folders_after_data_docs,
+)
 from utils.utils import get_current_user, get_admin_user
-from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from config import (
+    UPLOAD_DIR,
+    DOCS_DIR,
+    EMBED_MODEL,
+    CHROMA_CLIENT,
+    CHUNK_SIZE,
+    CHUNK_OVERLAP,
+)
 from constants import ERROR_MESSAGES
 
 # EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
@@ -220,8 +242,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
         )
 
 
-def get_loader(file, file_path):
-    file_ext = file.filename.split(".")[-1].lower()
+def get_loader(filename: str, file_content_type: str, file_path: str):
+    file_ext = filename.split(".")[-1].lower()
     known_type = True
 
     known_source_ext = [
@@ -279,20 +301,20 @@ def get_loader(file, file_path):
         loader = UnstructuredXMLLoader(file_path)
     elif file_ext == "md":
         loader = UnstructuredMarkdownLoader(file_path)
-    elif file.content_type == "application/epub+zip":
+    elif file_content_type == "application/epub+zip":
         loader = UnstructuredEPubLoader(file_path)
     elif (
-        file.content_type
+        file_content_type
         == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
         or file_ext in ["doc", "docx"]
     ):
         loader = Docx2txtLoader(file_path)
-    elif file.content_type in [
+    elif file_content_type in [
         "application/vnd.ms-excel",
         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
     ] or file_ext in ["xls", "xlsx"]:
         loader = UnstructuredExcelLoader(file_path)
-    elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
+    elif file_ext in known_source_ext or file_content_type.find("text/") >= 0:
         loader = TextLoader(file_path)
     else:
         loader = TextLoader(file_path)
@@ -323,7 +345,7 @@ def store_doc(
             collection_name = calculate_sha256(f)[:63]
         f.close()
 
-        loader, known_type = get_loader(file, file_path)
+        loader, known_type = get_loader(file.filename, file.content_type, file_path)
         data = loader.load()
         result = store_data_in_vector_db(data, collection_name)
 
@@ -353,6 +375,63 @@ def store_doc(
             )
 
 
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
    """Index every non-hidden file under DOCS_DIR and register it as a document.

    For each file the handler hashes the content to derive a collection name,
    loads it through ``get_loader``, stores the result with
    ``store_data_in_vector_db`` and, when that succeeds and no document with
    the same sanitized name exists yet, inserts a new ``Documents`` row whose
    tags come from the file's folder hierarchy.

    Files are processed independently: a failure on one file is reported and
    the scan continues with the next one.

    Args:
        user: The authenticated admin user (injected by FastAPI); its ``id``
            is recorded as the owner of newly inserted documents.

    Returns:
        Always ``True``, so existing callers that only check for a truthy
        response keep working.
    """
    for path in Path(DOCS_DIR).rglob("*"):
        # Skip directories and dotfiles.
        if not path.is_file() or path.name.startswith("."):
            continue

        try:
            tags = extract_folders_after_data_docs(path)
            filename = path.name

            # guess_type() yields None for unknown extensions, but get_loader
            # calls string methods on the content type, so fall back to "".
            file_content_type = mimetypes.guess_type(path)[0] or ""

            # The context manager closes the handle even if hashing fails.
            with open(path, "rb") as f:
                collection_name = calculate_sha256(f)[:63]

            loader, _ = get_loader(filename, file_content_type, str(path))
            data = loader.load()

            if not store_data_in_vector_db(data, collection_name):
                continue

            sanitized_filename = sanitize_filename(filename)
            if Documents.get_doc_by_name(sanitized_filename) is not None:
                # Already registered under this name; leave it untouched.
                continue

            content = (
                json.dumps({"tags": [{"name": name} for name in tags]})
                if tags
                else "{}"
            )

            Documents.insert_new_doc(
                user.id,
                DocumentForm(
                    **{
                        "name": sanitized_filename,
                        "title": filename,
                        "collection_name": collection_name,
                        "filename": filename,
                        "content": content,
                    }
                ),
            )
        except Exception as e:
            # Best-effort scan: report the failing file and keep going.
            print(f"Failed to scan {path}: {e}")

    return True
+
+
 @app.get("/reset/db")
 def reset_vector_db(user=Depends(get_admin_user)):
     CHROMA_CLIENT.reset()

+ 4 - 0
backend/apps/web/routers/documents.py

@@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
 ############################
 
 
class TagItem(BaseModel):
    """A single tag entry carrying only the tag's name."""

    name: str
+
+
 class TagDocumentForm(BaseModel):
     name: str
     tags: List[dict]

+ 8 - 0
backend/config.py

@@ -43,6 +43,14 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
 CACHE_DIR = f"{DATA_DIR}/cache"
 Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
 
+
+####################################
+# Docs DIR
+####################################
+
+DOCS_DIR = f"{DATA_DIR}/docs"
+Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
+
 ####################################
 # OLLAMA_API_BASE_URL
 ####################################

+ 38 - 0
backend/utils/misc.py

@@ -1,3 +1,4 @@
+from pathlib import Path
 import hashlib
 import re
 
@@ -38,3 +39,40 @@ def validate_email_format(email: str) -> bool:
     if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
         return False
     return True
+
+
def sanitize_filename(file_name):
    """Turn a file name into a lowercase, dash-separated identifier.

    The name is lowercased, every character that is neither a word character
    nor whitespace is dropped (so the dot before an extension disappears), and
    each run of whitespace is collapsed into a single dash.
    """
    without_specials = re.sub(r"[^\w\s]", "", file_name.lower())
    return re.sub(r"\s+", "-", without_specials)
+
+
def extract_folders_after_data_docs(path):
    """Return cumulative folder tags for a file stored below a docs directory.

    The path is searched for a ``data`` component followed, at any later
    position, by a ``docs`` component. Every folder between that ``docs``
    component and the file name contributes one tag made of the folder chain
    up to that depth, joined with ``/``.

    Args:
        path: A ``str`` or ``Path`` pointing at a file.

    Returns:
        A list such as ``["a", "a/b"]`` for ``.../data/docs/a/b/file.txt``;
        an empty list when the file sits directly in the docs folder or when
        the path contains no ``data`` component followed by ``docs``.
    """
    parts = Path(path).parts

    try:
        index_data = parts.index("data")
        # "docs" only counts when it appears after the "data" component.
        index_docs = parts.index("docs", index_data + 1)
    except ValueError:
        return []

    # Drop everything up to and including "docs", and the trailing file name.
    folders = parts[index_docs + 1 : -1]
    return ["/".join(folders[:depth]) for depth in range(1, len(folders) + 1)]

+ 26 - 0
src/lib/apis/rag/index.ts

@@ -138,6 +138,32 @@ export const queryCollection = async (
 	return res;
 };
 
/**
 * Ask the RAG backend to scan its docs directory.
 *
 * Resolves with the parsed JSON body on success. When the request fails and
 * the caught error carries a `detail` field, that value is thrown; any other
 * failure resolves with `null`.
 */
export const scanDocs = async (token: string) => {
	let payload = null;
	let failure = null;

	try {
		const response = await fetch(`${RAG_API_BASE_URL}/scan`, {
			method: 'GET',
			headers: {
				Accept: 'application/json',
				authorization: `Bearer ${token}`
			}
		});

		// A non-2xx response is surfaced through its JSON body.
		if (!response.ok) throw await response.json();

		payload = await response.json();
	} catch (err: any) {
		failure = err.detail;
	}

	if (failure) {
		throw failure;
	}

	return payload;
};
+
 export const resetVectorDB = async (token: string) => {
 	let error = null;
 

+ 1 - 1
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -366,7 +366,7 @@
 
 								{#if message.done}
 									<div
-										class=" flex justify-start space-x-1 -mt-1 overflow-x-auto buttons text-gray-700 dark:text-gray-500"
+										class=" flex justify-start space-x-1 overflow-x-auto buttons text-gray-700 dark:text-gray-500"
 									>
 										{#if siblings.length > 1}
 											<div class="flex self-center min-w-fit">

+ 106 - 0
src/lib/components/documents/Settings/General.svelte

@@ -0,0 +1,106 @@
+<script lang="ts">
+	import { getDocs } from '$lib/apis/documents';
+	import { scanDocs } from '$lib/apis/rag';
+	import { documents } from '$lib/stores';
+	import { onMount } from 'svelte';
+	import toast from 'svelte-french-toast';
+
+	export let saveHandler: Function;
+
+	let loading = false;
+
+	const scanHandler = async () => {
+		loading = true;
+		const res = await scanDocs(localStorage.token);
+		loading = false;
+
+		if (res) {
+			await documents.set(await getDocs(localStorage.token));
+			toast.success('Scan complete!');
+		}
+	};
+
+	onMount(async () => {});
+</script>
+
+<form
+	class="flex flex-col h-full justify-between space-y-3 text-sm"
+	on:submit|preventDefault={() => {
+		// console.log('submit');
+		saveHandler();
+	}}
+>
+	<div class=" space-y-3 pr-1.5 overflow-y-scroll max-h-80">
+		<div>
+			<div class=" mb-2 text-sm font-medium">General Settings</div>
+
+			<div class="  flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">Scan for documents from '/data/docs'</div>
+
+				<button
+					class=" self-center text-xs p-1 px-3 bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 rounded flex flex-row space-x-1 items-center {loading
+						? ' cursor-not-allowed'
+						: ''}"
+					on:click={() => {
+						scanHandler();
+						console.log('check');
+					}}
+					type="button"
+					disabled={loading}
+				>
+					<div class="self-center font-medium">Scan</div>
+
+					<!-- <svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 16 16"
+						fill="currentColor"
+						class="w-3 h-3"
+					>
+						<path
+							fill-rule="evenodd"
+							d="M13.836 2.477a.75.75 0 0 1 .75.75v3.182a.75.75 0 0 1-.75.75h-3.182a.75.75 0 0 1 0-1.5h1.37l-.84-.841a4.5 4.5 0 0 0-7.08.932.75.75 0 0 1-1.3-.75 6 6 0 0 1 9.44-1.242l.842.84V3.227a.75.75 0 0 1 .75-.75Zm-.911 7.5A.75.75 0 0 1 13.199 11a6 6 0 0 1-9.44 1.241l-.84-.84v1.371a.75.75 0 0 1-1.5 0V9.591a.75.75 0 0 1 .75-.75H5.35a.75.75 0 0 1 0 1.5H3.98l.841.841a4.5 4.5 0 0 0 7.08-.932.75.75 0 0 1 1.025-.273Z"
+							clip-rule="evenodd"
+						/>
+					</svg> -->
+
+					{#if loading}
+						<div class="ml-3 self-center">
+							<svg
+								class=" w-3 h-3"
+								viewBox="0 0 24 24"
+								fill="currentColor"
+								xmlns="http://www.w3.org/2000/svg"
+								><style>
+									.spinner_ajPY {
+										transform-origin: center;
+										animation: spinner_AtaB 0.75s infinite linear;
+									}
+									@keyframes spinner_AtaB {
+										100% {
+											transform: rotate(360deg);
+										}
+									}
+								</style><path
+									d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+									opacity=".25"
+								/><path
+									d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+									class="spinner_ajPY"
+								/></svg
+							>
+						</div>
+					{/if}
+				</button>
+			</div>
+		</div>
+	</div>
+
+	<!-- <div class="flex justify-end pt-3 text-sm font-medium">
+		<button
+			class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
+			type="submit"
+		>
+			Save
+		</button>
+	</div> -->
+</form>

+ 86 - 0
src/lib/components/documents/SettingsModal.svelte

@@ -0,0 +1,86 @@
+<script>
+	import Modal from '../common/Modal.svelte';
+	import General from './Settings/General.svelte';
+
+	export let show = false;
+
+	let selectedTab = 'general';
+</script>
+
+<Modal bind:show>
+	<div>
+		<div class=" flex justify-between dark:text-gray-300 px-5 py-4">
+			<div class=" text-lg font-medium self-center">Document Settings</div>
+			<button
+				class="self-center"
+				on:click={() => {
+					show = false;
+				}}
+			>
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 20 20"
+					fill="currentColor"
+					class="w-5 h-5"
+				>
+					<path
+						d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
+					/>
+				</svg>
+			</button>
+		</div>
+		<hr class=" dark:border-gray-800" />
+
+		<div class="flex flex-col md:flex-row w-full p-4 md:space-x-4">
+			<div
+				class="tabs flex flex-row overflow-x-auto space-x-1 md:space-x-0 md:space-y-1 md:flex-col flex-1 md:flex-none md:w-40 dark:text-gray-200 text-xs text-left mb-3 md:mb-0"
+			>
+				<button
+					class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
+					'general'
+						? 'bg-gray-200 dark:bg-gray-700'
+						: ' hover:bg-gray-300 dark:hover:bg-gray-800'}"
+					on:click={() => {
+						selectedTab = 'general';
+					}}
+				>
+					<div class=" self-center mr-2">
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							viewBox="0 0 16 16"
+							fill="currentColor"
+							class="w-4 h-4"
+						>
+							<path
+								fill-rule="evenodd"
+								d="M6.955 1.45A.5.5 0 0 1 7.452 1h1.096a.5.5 0 0 1 .497.45l.17 1.699c.484.12.94.312 1.356.562l1.321-1.081a.5.5 0 0 1 .67.033l.774.775a.5.5 0 0 1 .034.67l-1.08 1.32c.25.417.44.873.561 1.357l1.699.17a.5.5 0 0 1 .45.497v1.096a.5.5 0 0 1-.45.497l-1.699.17c-.12.484-.312.94-.562 1.356l1.082 1.322a.5.5 0 0 1-.034.67l-.774.774a.5.5 0 0 1-.67.033l-1.322-1.08c-.416.25-.872.44-1.356.561l-.17 1.699a.5.5 0 0 1-.497.45H7.452a.5.5 0 0 1-.497-.45l-.17-1.699a4.973 4.973 0 0 1-1.356-.562L4.108 13.37a.5.5 0 0 1-.67-.033l-.774-.775a.5.5 0 0 1-.034-.67l1.08-1.32a4.971 4.971 0 0 1-.561-1.357l-1.699-.17A.5.5 0 0 1 1 8.548V7.452a.5.5 0 0 1 .45-.497l1.699-.17c.12-.484.312-.94.562-1.356L2.629 4.107a.5.5 0 0 1 .034-.67l.774-.774a.5.5 0 0 1 .67-.033L5.43 3.71a4.97 4.97 0 0 1 1.356-.561l.17-1.699ZM6 8c0 .538.212 1.026.558 1.385l.057.057a2 2 0 0 0 2.828-2.828l-.058-.056A2 2 0 0 0 6 8Z"
+								clip-rule="evenodd"
+							/>
+						</svg>
+					</div>
+					<div class=" self-center">General</div>
+				</button>
+			</div>
+			<div class="flex-1 md:min-h-[380px]">
+				{#if selectedTab === 'general'}
+					<General
+						saveHandler={() => {
+							show = false;
+						}}
+					/>
+					<!-- <General
+						saveHandler={() => {
+							show = false;
+						}}
+					/> -->
+					<!-- {:else if selectedTab === 'users'}
+					<Users
+						saveHandler={() => {
+							show = false;
+						}}
+					/> -->
+				{/if}
+			</div>
+		</div>
+	</div>
+</Modal>

+ 31 - 0
src/routes/(app)/documents/+page.svelte

@@ -13,6 +13,7 @@
 
 	import EditDocModal from '$lib/components/documents/EditDocModal.svelte';
 	import AddFilesPlaceholder from '$lib/components/AddFilesPlaceholder.svelte';
+	import SettingsModal from '$lib/components/documents/SettingsModal.svelte';
 	let importFiles = '';
 
 	let inputFiles = '';
@@ -20,6 +21,7 @@
 
 	let tags = [];
 
+	let showSettingsModal = false;
 	let showEditDocModal = false;
 	let selectedDoc;
 	let selectedTag = '';
@@ -179,11 +181,38 @@
 	}}
 />
 
+<SettingsModal bind:show={showSettingsModal} />
+
 <div class="min-h-screen max-h-[100dvh] w-full flex justify-center dark:text-white">
 	<div class=" py-2.5 flex flex-col justify-between w-full overflow-y-auto">
 		<div class="max-w-2xl mx-auto w-full px-3 md:px-0 my-10">
 			<div class="mb-6 flex justify-between items-center">
 				<div class=" text-2xl font-semibold self-center">My Documents</div>
+
+				<div>
+					<button
+						class="flex items-center space-x-1 border border-gray-200 dark:border-gray-600 px-3 py-1 rounded-lg"
+						type="button"
+						on:click={() => {
+							showSettingsModal = !showSettingsModal;
+						}}
+					>
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							viewBox="0 0 16 16"
+							fill="currentColor"
+							class="w-4 h-4"
+						>
+							<path
+								fill-rule="evenodd"
+								d="M6.955 1.45A.5.5 0 0 1 7.452 1h1.096a.5.5 0 0 1 .497.45l.17 1.699c.484.12.94.312 1.356.562l1.321-1.081a.5.5 0 0 1 .67.033l.774.775a.5.5 0 0 1 .034.67l-1.08 1.32c.25.417.44.873.561 1.357l1.699.17a.5.5 0 0 1 .45.497v1.096a.5.5 0 0 1-.45.497l-1.699.17c-.12.484-.312.94-.562 1.356l1.082 1.322a.5.5 0 0 1-.034.67l-.774.774a.5.5 0 0 1-.67.033l-1.322-1.08c-.416.25-.872.44-1.356.561l-.17 1.699a.5.5 0 0 1-.497.45H7.452a.5.5 0 0 1-.497-.45l-.17-1.699a4.973 4.973 0 0 1-1.356-.562L4.108 13.37a.5.5 0 0 1-.67-.033l-.774-.775a.5.5 0 0 1-.034-.67l1.08-1.32a4.971 4.971 0 0 1-.561-1.357l-1.699-.17A.5.5 0 0 1 1 8.548V7.452a.5.5 0 0 1 .45-.497l1.699-.17c.12-.484.312-.94.562-1.356L2.629 4.107a.5.5 0 0 1 .034-.67l.774-.774a.5.5 0 0 1 .67-.033L5.43 3.71a4.97 4.97 0 0 1 1.356-.561l.17-1.699ZM6 8c0 .538.212 1.026.558 1.385l.057.057a2 2 0 0 0 2.828-2.828l-.058-.056A2 2 0 0 0 6 8Z"
+								clip-rule="evenodd"
+							/>
+						</svg>
+
+						<div class=" text-xs">Document Settings</div>
+					</button>
+				</div>
 			</div>
 
 			<div class=" flex w-full space-x-2">
@@ -419,6 +448,8 @@
 						</button>
 					</div>
 				</div>
+
+				<div class=" my-2.5" />
 			{/each}
 
 			{#if $documents.length > 0}