Timothy J. Baek, 7 months ago
Parent
Current commit
af57a2c153
28 changed files with 929 additions and 869 deletions
  1. + 183 - 0    backend/open_webui/apps/retrieval/loader/main.py
  2. + 431 - 691  backend/open_webui/apps/retrieval/main.py
  3. + 81 - 0     backend/open_webui/apps/retrieval/model/colbert.py
  4. + 1 - 1      backend/open_webui/apps/retrieval/web/brave.py
  5. + 1 - 1      backend/open_webui/apps/retrieval/web/duckduckgo.py
  6. + 1 - 1      backend/open_webui/apps/retrieval/web/google_pse.py
  7. + 1 - 1      backend/open_webui/apps/retrieval/web/jina_search.py
  8. + 0 - 0      backend/open_webui/apps/retrieval/web/main.py
  9. + 1 - 1      backend/open_webui/apps/retrieval/web/searchapi.py
  10. + 1 - 1     backend/open_webui/apps/retrieval/web/searxng.py
  11. + 1 - 1     backend/open_webui/apps/retrieval/web/serper.py
  12. + 1 - 1     backend/open_webui/apps/retrieval/web/serply.py
  13. + 1 - 1     backend/open_webui/apps/retrieval/web/serpstack.py
  14. + 1 - 1     backend/open_webui/apps/retrieval/web/tavily.py
  15. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/brave.json
  16. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/google_pse.json
  17. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/searchapi.json
  18. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/searxng.json
  19. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/serper.json
  20. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/serply.json
  21. + 0 - 0     backend/open_webui/apps/retrieval/web/testdata/serpstack.json
  22. + 97 - 0    backend/open_webui/apps/retrieval/web/utils.py
  23. + 114 - 146 src/lib/apis/retrieval/index.ts
  24. + 2 - 2     src/lib/components/admin/Settings/Documents.svelte
  25. + 2 - 2     src/lib/components/chat/Chat.svelte
  26. + 3 - 0     src/lib/components/chat/Controls/Controls.svelte
  27. + 3 - 3     src/lib/components/chat/MessageInput/Commands.svelte
  28. + 3 - 15    src/lib/components/common/FileItem.svelte

+ 183 - 0
backend/open_webui/apps/retrieval/loader/main.py

@@ -0,0 +1,183 @@
+import requests
+import logging
+
+from langchain_community.document_loaders import (
+    BSHTMLLoader,
+    CSVLoader,
+    Docx2txtLoader,
+    OutlookMessageLoader,
+    PyPDFLoader,
+    TextLoader,
+    UnstructuredEPubLoader,
+    UnstructuredExcelLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredPowerPointLoader,
+    UnstructuredRSTLoader,
+    UnstructuredXMLLoader,
+    YoutubeLoader,
+)
+from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+known_source_ext = [
+    "go",
+    "py",
+    "java",
+    "sh",
+    "bat",
+    "ps1",
+    "cmd",
+    "js",
+    "ts",
+    "css",
+    "cpp",
+    "hpp",
+    "h",
+    "c",
+    "cs",
+    "sql",
+    "log",
+    "ini",
+    "pl",
+    "pm",
+    "r",
+    "dart",
+    "dockerfile",
+    "env",
+    "php",
+    "hs",
+    "hsc",
+    "lua",
+    "nginxconf",
+    "conf",
+    "m",
+    "mm",
+    "plsql",
+    "perl",
+    "rb",
+    "rs",
+    "db2",
+    "scala",
+    "bash",
+    "swift",
+    "vue",
+    "svelte",
+    "msg",
+    "ex",
+    "exs",
+    "erl",
+    "tsx",
+    "jsx",
+    "hs",
+    "lhs",
+]
+
+
+class TikaLoader:
+    def __init__(self, url, file_path, mime_type=None):
+        self.url = url
+        self.file_path = file_path
+        self.mime_type = mime_type
+
+    def load(self) -> list[Document]:
+        with open(self.file_path, "rb") as f:
+            data = f.read()
+
+        if self.mime_type is not None:
+            headers = {"Content-Type": self.mime_type}
+        else:
+            headers = {}
+
+        endpoint = self.url
+        if not endpoint.endswith("/"):
+            endpoint += "/"
+        endpoint += "tika/text"
+
+        r = requests.put(endpoint, data=data, headers=headers)
+
+        if r.ok:
+            raw_metadata = r.json()
+            text = raw_metadata.get("X-TIKA:content", "<No text content found>")
+
+            if "Content-Type" in raw_metadata:
+                headers["Content-Type"] = raw_metadata["Content-Type"]
+
+            log.info("Tika extracted text: %s", text)
+
+            return [Document(page_content=text, metadata=headers)]
+        else:
+            raise Exception(f"Error calling Tika: {r.reason}")
+
+
+class Loader:
+    def __init__(self, engine: str = "", **kwargs):
+        self.engine = engine
+        self.kwargs = kwargs
+
+    def load(
+        self, filename: str, file_content_type: str, file_path: str
+    ) -> list[Document]:
+        loader = self._get_loader(filename, file_content_type, file_path)
+        return loader.load()
+
+    def _get_loader(self, filename: str, file_content_type: str, file_path: str):
+        file_ext = filename.split(".")[-1].lower()
+
+        if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
+            if file_ext in known_source_ext or (
+                file_content_type and file_content_type.find("text/") >= 0
+            ):
+                loader = TextLoader(file_path, autodetect_encoding=True)
+            else:
+                loader = TikaLoader(
+                    url=self.kwargs.get("TIKA_SERVER_URL"),
+                    file_path=file_path,
+                    mime_type=file_content_type,
+                )
+        else:
+            if file_ext == "pdf":
+                loader = PyPDFLoader(
+                    file_path, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES")
+                )
+            elif file_ext == "csv":
+                loader = CSVLoader(file_path)
+            elif file_ext == "rst":
+                loader = UnstructuredRSTLoader(file_path, mode="elements")
+            elif file_ext == "xml":
+                loader = UnstructuredXMLLoader(file_path)
+            elif file_ext in ["htm", "html"]:
+                loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
+            elif file_ext == "md":
+                loader = UnstructuredMarkdownLoader(file_path)
+            elif file_content_type == "application/epub+zip":
+                loader = UnstructuredEPubLoader(file_path)
+            elif (
+                file_content_type
+                == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+                or file_ext == "docx"
+            ):
+                loader = Docx2txtLoader(file_path)
+            elif file_content_type in [
+                "application/vnd.ms-excel",
+                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+            ] or file_ext in ["xls", "xlsx"]:
+                loader = UnstructuredExcelLoader(file_path)
+            elif file_content_type in [
+                "application/vnd.ms-powerpoint",
+                "application/vnd.openxmlformats-officedocument.presentationml.presentation",
+            ] or file_ext in ["ppt", "pptx"]:
+                loader = UnstructuredPowerPointLoader(file_path)
+            elif file_ext == "msg":
+                loader = OutlookMessageLoader(file_path)
+            elif file_ext in known_source_ext or (
+                file_content_type and file_content_type.find("text/") >= 0
+            ):
+                loader = TextLoader(file_path, autodetect_encoding=True)
+            else:
+                loader = TextLoader(file_path, autodetect_encoding=True)
+
+        return loader
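
The new Loader class is a single dispatch point: it picks a LangChain loader from the file extension or MIME type, and when engine="tika" with a TIKA_SERVER_URL configured it routes non-text files through Apache Tika instead. A minimal usage sketch, assuming an illustrative file path and the stock Tika port (neither appears in this commit):

    from open_webui.apps.retrieval.loader.main import Loader

    # Default engine: dispatch purely on extension / content type.
    docs = Loader(PDF_EXTRACT_IMAGES=False).load(
        filename="report.pdf",
        file_content_type="application/pdf",
        file_path="/tmp/report.pdf",  # illustrative path
    )

    # Tika engine: source/text files still go to TextLoader; everything
    # else is PUT to {TIKA_SERVER_URL}/tika/text.
    docs = Loader(engine="tika", TIKA_SERVER_URL="http://localhost:9998").load(
        filename="report.pdf",
        file_content_type="application/pdf",
        file_path="/tmp/report.pdf",
    )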

File diff suppressed because it is too large
+ 431 - 691
backend/open_webui/apps/retrieval/main.py


+ 81 - 0
backend/open_webui/apps/retrieval/model/colbert.py

@@ -0,0 +1,81 @@
+import os
+import torch
+import numpy as np
+from colbert.infra import ColBERTConfig
+from colbert.modeling.checkpoint import Checkpoint
+
+
+class ColBERT:
+    def __init__(self, name, **kwargs) -> None:
+        print("ColBERT: Loading model", name)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        DOCKER = kwargs.get("env") == "docker"
+        if DOCKER:
+            # This is a workaround for the issue with the docker container
+            # where the torch extension is not loaded properly
+            # and the following error is thrown:
+            # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
+
+            lock_file = (
+                "/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
+            )
+            if os.path.exists(lock_file):
+                os.remove(lock_file)
+
+        self.ckpt = Checkpoint(
+            name,
+            colbert_config=ColBERTConfig(model_name=name),
+        ).to(self.device)
+
+    def calculate_similarity_scores(self, query_embeddings, document_embeddings):
+
+        query_embeddings = query_embeddings.to(self.device)
+        document_embeddings = document_embeddings.to(self.device)
+
+        # Validate dimensions to ensure compatibility
+        if query_embeddings.dim() != 3:
+            raise ValueError(
+                f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
+            )
+        if document_embeddings.dim() != 3:
+            raise ValueError(
+                f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
+            )
+        if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
+            raise ValueError(
+                "There should be either one query or queries equal to the number of documents."
+            )
+
+        # Transpose the query embeddings to align for matrix multiplication
+        transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
+        # Compute similarity scores using batch matrix multiplication
+        computed_scores = torch.matmul(document_embeddings, transposed_query_embeddings)
+        # Apply max pooling to extract the highest semantic similarity across each document's sequence
+        maximum_scores = torch.max(computed_scores, dim=1).values
+
+        # Sum up the maximum scores across features to get the overall document relevance scores
+        final_scores = maximum_scores.sum(dim=1)
+
+        normalized_scores = torch.softmax(final_scores, dim=0)
+
+        return normalized_scores.detach().cpu().numpy().astype(np.float32)
+
+    def predict(self, sentences):
+
+        query = sentences[0][0]
+        docs = [i[1] for i in sentences]
+
+        # Embedding the documents
+        embedded_docs = self.ckpt.docFromText(docs, bsize=32)[0]
+        # Embedding the queries
+        embedded_queries = self.ckpt.queryFromText([query], bsize=32)
+        embedded_query = embedded_queries[0]
+
+        # Calculate retrieval scores for the query against all documents
+        scores = self.calculate_similarity_scores(
+            embedded_query.unsqueeze(0), embedded_docs
+        )
+
+        return scores
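
predict follows the sentence-transformers CrossEncoder convention: it takes (query, document) pairs and returns one score per document, here softmax-normalized over the candidate set. A hedged sketch, assuming the public colbert-ir/colbertv2.0 checkpoint (this commit does not pin a model name):

    from open_webui.apps.retrieval.model.colbert import ColBERT

    reranker = ColBERT("colbert-ir/colbertv2.0")  # assumed checkpoint name

    query = "how does late interaction work?"
    docs = [
        "ColBERT scores each query token against every document token (MaxSim).",
        "Unrelated paragraph about cooking.",
    ]

    # One (query, doc) pair per candidate; scores sum to 1 via softmax.
    scores = reranker.predict([(query, d) for d in docs])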

+ 1 - 1
backend/open_webui/apps/retrieval/search/brave.py → backend/open_webui/apps/retrieval/web/brave.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/duckduckgo.py → backend/open_webui/apps/retrieval/web/duckduckgo.py

@@ -1,7 +1,7 @@
 import logging
 from typing import Optional
 
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
 from open_webui.env import SRC_LOG_LEVELS
 

+ 1 - 1
backend/open_webui/apps/retrieval/search/google_pse.py → backend/open_webui/apps/retrieval/web/google_pse.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/jina_search.py → backend/open_webui/apps/retrieval/web/jina_search.py

@@ -1,7 +1,7 @@
 import logging
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult
+from open_webui.apps.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 from yarl import URL
 

+ 0 - 0
backend/open_webui/apps/retrieval/search/main.py → backend/open_webui/apps/retrieval/web/main.py


+ 1 - 1
backend/open_webui/apps/retrieval/search/searchapi.py → backend/open_webui/apps/retrieval/web/searchapi.py

@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/searxng.py → backend/open_webui/apps/retrieval/web/searxng.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/serper.py → backend/open_webui/apps/retrieval/web/serper.py

@@ -3,7 +3,7 @@ import logging
 from typing import Optional
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/serply.py → backend/open_webui/apps/retrieval/web/serply.py

@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/serpstack.py → backend/open_webui/apps/retrieval/web/serpstack.py

@@ -2,7 +2,7 @@ import logging
 from typing import Optional
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult, get_filtered_results
+from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 1 - 1
backend/open_webui/apps/retrieval/search/tavily.py → backend/open_webui/apps/retrieval/web/tavily.py

@@ -1,7 +1,7 @@
 import logging
 
 import requests
-from open_webui.apps.retrieval.search.main import SearchResult
+from open_webui.apps.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)

+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/brave.json → backend/open_webui/apps/retrieval/web/testdata/brave.json


+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/google_pse.json → backend/open_webui/apps/retrieval/web/testdata/google_pse.json


+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/searchapi.json → backend/open_webui/apps/retrieval/web/testdata/searchapi.json


+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/searxng.json → backend/open_webui/apps/retrieval/web/testdata/searxng.json


+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/serper.json → backend/open_webui/apps/retrieval/web/testdata/serper.json


+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/serply.json → backend/open_webui/apps/retrieval/web/testdata/serply.json


+ 0 - 0
backend/open_webui/apps/retrieval/search/testdata/serpstack.json → backend/open_webui/apps/retrieval/web/testdata/serpstack.json


+ 97 - 0
backend/open_webui/apps/retrieval/web/utils.py

@@ -0,0 +1,97 @@
+import socket
+import urllib.parse
+import validators
+from typing import Union, Sequence, Iterator
+
+from langchain_community.document_loaders import (
+    WebBaseLoader,
+)
+from langchain_core.documents import Document
+
+
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.config import ENABLE_RAG_LOCAL_WEB_FETCH
+from open_webui.env import SRC_LOG_LEVELS
+
+import logging
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def validate_url(url: Union[str, Sequence[str]]):
+    if isinstance(url, str):
+        if isinstance(validators.url(url), validators.ValidationError):
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        if not ENABLE_RAG_LOCAL_WEB_FETCH:
+            # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
+            parsed_url = urllib.parse.urlparse(url)
+            # Get IPv4 and IPv6 addresses
+            ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
+            # Check if any of the resolved addresses are private
+            # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
+            for ip in ipv4_addresses:
+                if validators.ipv4(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+            for ip in ipv6_addresses:
+                if validators.ipv6(ip, private=True):
+                    raise ValueError(ERROR_MESSAGES.INVALID_URL)
+        return True
+    elif isinstance(url, Sequence):
+        return all(validate_url(u) for u in url)
+    else:
+        return False
+
+
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
+class SafeWebBaseLoader(WebBaseLoader):
+    """WebBaseLoader with enhanced error handling for URLs."""
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Lazy load text from the url(s) in web_path with error handling."""
+        for path in self.web_paths:
+            try:
+                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
+                text = soup.get_text(**self.bs_get_text_kwargs)
+
+                # Build metadata
+                metadata = {"source": path}
+                if title := soup.find("title"):
+                    metadata["title"] = title.get_text()
+                if description := soup.find("meta", attrs={"name": "description"}):
+                    metadata["description"] = description.get(
+                        "content", "No description found."
+                    )
+                if html := soup.find("html"):
+                    metadata["language"] = html.get("lang", "No language found.")
+
+                yield Document(page_content=text, metadata=metadata)
+            except Exception as e:
+                # Log the error and continue with the next URL
+                log.error(f"Error loading {path}: {e}")
+
+
+def get_web_loader(
+    url: Union[str, Sequence[str]],
+    verify_ssl: bool = True,
+    requests_per_second: int = 2,
+):
+    # Check if the URL is valid
+    if not validate_url(url):
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+    return SafeWebBaseLoader(
+        url,
+        verify_ssl=verify_ssl,
+        requests_per_second=requests_per_second,
+        continue_on_failure=True,
+    )
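
get_web_loader is the safe front door here: validate_url rejects malformed URLs and, unless ENABLE_RAG_LOCAL_WEB_FETCH is set, any hostname resolving to a private IPv4/IPv6 address, and SafeWebBaseLoader then logs and skips pages that fail instead of aborting the whole batch. A short sketch with an illustrative URL:

    from open_webui.apps.retrieval.web.utils import get_web_loader

    loader = get_web_loader(
        "https://example.com",  # illustrative; private targets raise INVALID_URL
        verify_ssl=True,
        requests_per_second=2,
    )
    for doc in loader.load():  # failed URLs are logged, not raised
        print(doc.metadata.get("title"), len(doc.page_content))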

+ 114 - 146
src/lib/apis/retrieval/index.ts

@@ -170,27 +170,23 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings
 	return res;
 };
 
-export const processFile = async (token: string, file_id: string) => {
+export const getEmbeddingConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/process/file`, {
-		method: 'POST',
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
+		method: 'GET',
 		headers: {
-			Accept: 'application/json',
 			'Content-Type': 'application/json',
-			authorization: `Bearer ${token}`
-		},
-		body: JSON.stringify({
-			file_id: file_id
-		})
+			Authorization: `Bearer ${token}`
+		}
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
 			return res.json();
 		})
 		.catch((err) => {
-			error = err.detail;
 			console.log(err);
+			error = err.detail;
 			return null;
 		});
 
@@ -201,51 +197,29 @@ export const processFile = async (token: string, file_id: string) => {
 	return res;
 };
 
-export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
-	const data = new FormData();
-	data.append('file', file);
-	data.append('collection_name', collection_name);
-
-	let error = null;
-
-	const res = await fetch(`${RAG_API_BASE_URL}/doc`, {
-		method: 'POST',
-		headers: {
-			Accept: 'application/json',
-			authorization: `Bearer ${token}`
-		},
-		body: data
-	})
-		.then(async (res) => {
-			if (!res.ok) throw await res.json();
-			return res.json();
-		})
-		.catch((err) => {
-			error = err.detail;
-			console.log(err);
-			return null;
-		});
-
-	if (error) {
-		throw error;
-	}
+type OpenAIConfigForm = {
+	key: string;
+	url: string;
+	batch_size: number;
+};
 
-	return res;
+type EmbeddingModelUpdateForm = {
+	openai_config?: OpenAIConfigForm;
+	embedding_engine: string;
+	embedding_model: string;
 };
 
-export const uploadWebToVectorDB = async (token: string, collection_name: string, url: string) => {
+export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/web`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
 		method: 'POST',
 		headers: {
-			Accept: 'application/json',
 			'Content-Type': 'application/json',
-			authorization: `Bearer ${token}`
+			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			url: url,
-			collection_name: collection_name
+			...payload
 		})
 	})
 		.then(async (res) => {
@@ -253,8 +227,8 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
 			return res.json();
 		})
 		.catch((err) => {
-			error = err.detail;
 			console.log(err);
+			error = err.detail;
 			return null;
 		});
 
@@ -265,27 +239,23 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
 	return res;
 };
 
-export const uploadYoutubeTranscriptionToVectorDB = async (token: string, url: string) => {
+export const getRerankingConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/youtube`, {
-		method: 'POST',
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking`, {
+		method: 'GET',
 		headers: {
-			Accept: 'application/json',
 			'Content-Type': 'application/json',
-			authorization: `Bearer ${token}`
-		},
-		body: JSON.stringify({
-			url: url
-		})
+			Authorization: `Bearer ${token}`
+		}
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
 			return res.json();
 		})
 		.catch((err) => {
-			error = err.detail;
 			console.log(err);
+			error = err.detail;
 			return null;
 		});
 
@@ -296,25 +266,21 @@ export const uploadYoutubeTranscriptionToVectorDB = async (token: string, url: s
 	return res;
 };
 
-export const queryDoc = async (
-	token: string,
-	collection_name: string,
-	query: string,
-	k: number | null = null
-) => {
+type RerankingModelUpdateForm = {
+	reranking_model: string;
+};
+
+export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, {
 		method: 'POST',
 		headers: {
-			Accept: 'application/json',
 			'Content-Type': 'application/json',
-			authorization: `Bearer ${token}`
+			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			collection_name: collection_name,
-			query: query,
-			k: k
+			...payload
 		})
 	})
 		.then(async (res) => {
@@ -322,6 +288,7 @@ export const queryDoc = async (
 			return res.json();
 		})
 		.catch((err) => {
+			console.log(err);
 			error = err.detail;
 			return null;
 		});
@@ -333,15 +300,16 @@ export const queryDoc = async (
 	return res;
 };
 
-export const queryCollection = async (
-	token: string,
-	collection_names: string,
-	query: string,
-	k: number | null = null
-) => {
+export interface SearchDocument {
+	status: boolean;
+	collection_name: string;
+	filenames: string[];
+}
+
+export const processFile = async (token: string, file_id: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/process/file`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -349,9 +317,7 @@ export const queryCollection = async (
 			authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			collection_names: collection_names,
-			query: query,
-			k: k
+			file_id: file_id
 		})
 	})
 		.then(async (res) => {
@@ -360,6 +326,7 @@ export const queryCollection = async (
 		})
 		.catch((err) => {
 			error = err.detail;
+			console.log(err);
 			return null;
 		});
 
@@ -370,10 +337,10 @@ export const queryCollection = async (
 	return res;
 };
 
-export const scanDocs = async (token: string) => {
+export const processDocsDir = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/scan`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/process/dir`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -396,15 +363,19 @@ export const scanDocs = async (token: string) => {
 	return res;
 };
 
-export const resetUploadDir = async (token: string) => {
+export const processYoutubeVideo = async (token: string, url: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/reset/uploads`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/process/youtube`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
-		}
+		},
+		body: JSON.stringify({
+			url: url
+		})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -412,6 +383,7 @@ export const resetUploadDir = async (token: string) => {
 		})
 		.catch((err) => {
 			error = err.detail;
+			console.log(err);
 			return null;
 		});
 
@@ -422,15 +394,20 @@ export const resetUploadDir = async (token: string) => {
 	return res;
 };
 
-export const resetVectorDB = async (token: string) => {
+export const processWeb = async (token: string, collection_name: string, url: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/process/web`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			authorization: `Bearer ${token}`
-		}
+		},
+		body: JSON.stringify({
+			url: url,
+			collection_name: collection_name
+		})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -438,6 +415,7 @@ export const resetVectorDB = async (token: string) => {
 		})
 		.catch((err) => {
 			error = err.detail;
+			console.log(err);
 			return null;
 		});
 
@@ -448,15 +426,23 @@ export const resetVectorDB = async (token: string) => {
 	return res;
 };
 
-export const getEmbeddingConfig = async (token: string) => {
+export const processWebSearch = async (
+	token: string,
+	query: string,
+	collection_name?: string
+): Promise<SearchDocument | null> => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/embedding`, {
-		method: 'GET',
+	const res = await fetch(`${RAG_API_BASE_URL}/process/web/search`, {
+		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
-		}
+		},
+		body: JSON.stringify({
+			query,
+			collection_name: collection_name ?? ''
+		})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
@@ -475,29 +461,25 @@ export const getEmbeddingConfig = async (token: string) => {
 	return res;
 };
 
-type OpenAIConfigForm = {
-	key: string;
-	url: string;
-	batch_size: number;
-};
-
-type EmbeddingModelUpdateForm = {
-	openai_config?: OpenAIConfigForm;
-	embedding_engine: string;
-	embedding_model: string;
-};
-
-export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {
+export const queryDoc = async (
+	token: string,
+	collection_name: string,
+	query: string,
+	k: number | null = null
+) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/embedding/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, {
 		method: 'POST',
 		headers: {
+			Accept: 'application/json',
 			'Content-Type': 'application/json',
-			Authorization: `Bearer ${token}`
+			authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
-			...payload
+			collection_name: collection_name,
+			query: query,
+			k: k
 		})
 	})
 		.then(async (res) => {
@@ -505,7 +487,6 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod
 			return res.json();
 		})
 		.catch((err) => {
-			console.log(err);
 			error = err.detail;
 			return null;
 		});
@@ -517,22 +498,32 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod
 	return res;
 };
 
-export const getRerankingConfig = async (token: string) => {
+export const queryCollection = async (
+	token: string,
+	collection_names: string,
+	query: string,
+	k: number | null = null
+) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/reranking`, {
-		method: 'GET',
+	const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
+		method: 'POST',
 		headers: {
+			Accept: 'application/json',
 			'Content-Type': 'application/json',
-			Authorization: `Bearer ${token}`
-		}
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			collection_names: collection_names,
+			query: query,
+			k: k
+		})
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
 			return res.json();
 		})
 		.catch((err) => {
-			console.log(err);
 			error = err.detail;
 			return null;
 		});
@@ -544,29 +535,21 @@ export const getRerankingConfig = async (token: string) => {
 	return res;
 };
 
-type RerankingModelUpdateForm = {
-	reranking_model: string;
-};
-
-export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => {
+export const resetUploadDir = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/reset/uploads`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'application/json',
-			Authorization: `Bearer ${token}`
-		},
-		body: JSON.stringify({
-			...payload
-		})
+			Accept: 'application/json',
+			authorization: `Bearer ${token}`
+		}
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
 			return res.json();
 		})
 		.catch((err) => {
-			console.log(err);
 			error = err.detail;
 			return null;
 		});
@@ -578,30 +561,21 @@ export const updateRerankingConfig = async (token: string, payload: RerankingMod
 	return res;
 };
 
-export const runWebSearch = async (
-	token: string,
-	query: string,
-	collection_name?: string
-): Promise<SearchDocument | null> => {
+export const resetVectorDB = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/web/search`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'application/json',
-			Authorization: `Bearer ${token}`
-		},
-		body: JSON.stringify({
-			query,
-			collection_name: collection_name ?? ''
-		})
+			Accept: 'application/json',
+			authorization: `Bearer ${token}`
+		}
 	})
 		.then(async (res) => {
 			if (!res.ok) throw await res.json();
 			return res.json();
 		})
 		.catch((err) => {
-			console.log(err);
 			error = err.detail;
 			return null;
 		});
@@ -612,9 +586,3 @@ export const runWebSearch = async (
 
 	return res;
 };
-
-export interface SearchDocument {
-	status: boolean;
-	collection_name: string;
-	filenames: string[];
-}
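
The client helpers above now mirror the relocated backend routes: processFile posts to /process/file, processWeb to /process/web, processYoutubeVideo to /process/youtube, and processWebSearch to /process/web/search. A hedged sketch of the search call from Python, assuming base points at whatever RAG_API_BASE_URL resolves to in your deployment:

    import requests

    base = "http://localhost:8080/api/v1/retrieval"  # assumption; match RAG_API_BASE_URL
    token = "..."  # a valid API token

    r = requests.post(
        f"{base}/process/web/search",
        headers={"Authorization": f"Bearer {token}"},
        json={"query": "open webui", "collection_name": ""},
    )
    r.raise_for_status()
    # Response shape matches the SearchDocument interface:
    # { status, collection_name, filenames }
    print(r.json())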

+ 2 - 2
src/lib/components/admin/Settings/Documents.svelte

@@ -7,7 +7,7 @@
 	import { deleteAllFiles, deleteFileById } from '$lib/apis/files';
 	import {
 		getQuerySettings,
-		scanDocs,
+		processDocsDir,
 		updateQuerySettings,
 		resetVectorDB,
 		getEmbeddingConfig,
@@ -63,7 +63,7 @@
 
 	const scanHandler = async () => {
 		scanDirLoading = true;
-		const res = await scanDocs(localStorage.token);
+		const res = await processDocsDir(localStorage.token);
 		scanDirLoading = false;
 
 		if (res) {

+ 2 - 2
src/lib/components/chat/Chat.svelte

@@ -52,7 +52,7 @@
 		updateChatById
 	} from '$lib/apis/chats';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
-	import { runWebSearch } from '$lib/apis/retrieval';
+	import { processWebSearch } from '$lib/apis/retrieval';
 	import { createOpenAITextStream } from '$lib/apis/streaming';
 	import { queryMemory } from '$lib/apis/memories';
 	import { getAndUpdateUserLocation, getUserSettings } from '$lib/apis/users';
@@ -1737,7 +1737,7 @@
 		});
 		history.messages[responseMessageId] = responseMessage;
 
-		const results = await runWebSearch(localStorage.token, searchQuery).catch((error) => {
+		const results = await processWebSearch(localStorage.token, searchQuery).catch((error) => {
 			console.log(error);
 			toast.error(error);
 

+ 3 - 0
src/lib/components/chat/Controls/Controls.svelte

@@ -46,6 +46,9 @@
 								chatFiles.splice(fileIdx, 1);
 								chatFiles = chatFiles;
 							}}
+							on:click={() => {
+								console.log(file);
+							}}
 						/>
 					{/each}
 				</div>

+ 3 - 3
src/lib/components/chat/MessageInput/Commands.svelte

@@ -9,7 +9,7 @@
 	import Models from './Commands/Models.svelte';
 
 	import { removeLastWordFromString } from '$lib/utils';
-	import { uploadWebToVectorDB, uploadYoutubeTranscriptionToVectorDB } from '$lib/apis/retrieval';
+	import { processWeb, processYoutubeVideo } from '$lib/apis/retrieval';
 
 	export let prompt = '';
 	export let files = [];
@@ -41,7 +41,7 @@
 
 		try {
 			files = [...files, doc];
-			const res = await uploadWebToVectorDB(localStorage.token, '', url);
+			const res = await processWeb(localStorage.token, '', url);
 
 			if (res) {
 				doc.status = 'processed';
@@ -69,7 +69,7 @@
 
 		try {
 			files = [...files, doc];
-			const res = await uploadYoutubeTranscriptionToVectorDB(localStorage.token, url);
+			const res = await processYoutubeVideo(localStorage.token, url);
 
 			if (res) {
 				doc.status = 'processed';

+ 3 - 15
src/lib/components/common/FileItem.svelte

@@ -8,8 +8,6 @@
 	export let colorClassName = 'bg-white dark:bg-gray-800';
 	export let url: string | null = null;
 
-	export let clickHandler: Function | null = null;
-
 	export let dismissible = false;
 	export let status = 'processed';
 
@@ -17,7 +15,7 @@
 	export let type: string;
 	export let size: number;
 
-	function formatSize(size) {
+	const formatSize = (size) => {
 		if (size == null) return 'Unknown size';
 		if (typeof size !== 'number' || size < 0) return 'Invalid size';
 		if (size === 0) return '0 B';
@@ -29,7 +27,7 @@
 			unitIndex++;
 		}
 		return `${size.toFixed(1)} ${units[unitIndex]}`;
-	}
+	};
 </script>
 
 <div class="relative group">
@@ -37,17 +35,7 @@
 		class="h-14 {className} flex items-center space-x-3 {colorClassName} rounded-xl border border-gray-100 dark:border-gray-800 text-left"
 		type="button"
 		on:click={async () => {
-			if (clickHandler === null) {
-				if (url) {
-					if (type === 'file') {
-						window.open(`${url}/content`, '_blank').focus();
-					} else {
-						window.open(`${url}`, '_blank').focus();
-					}
-				}
-			} else {
-				clickHandler();
-			}
+			dispatch('click');
 		}}
 	>
 		<div class="p-4 py-[1.1rem] bg-red-400 text-white rounded-l-xl">

Some files were not shown because too many files changed in this diff