Timothy J. Baek 7 달 전
부모
커밋
4354f270ce

+ 23 - 22
backend/open_webui/apps/rag/main.py

@@ -12,11 +12,16 @@ from typing import Iterator, Optional, Sequence, Union
 
 import requests
 import validators
+
+from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+from open_webui.apps.rag.search.main import SearchResult
 from open_webui.apps.rag.search.brave import search_brave
 from open_webui.apps.rag.search.duckduckgo import search_duckduckgo
 from open_webui.apps.rag.search.google_pse import search_google_pse
 from open_webui.apps.rag.search.jina_search import search_jina
-from open_webui.apps.rag.search.main import SearchResult
 from open_webui.apps.rag.search.searchapi import search_searchapi
 from open_webui.apps.rag.search.searxng import search_searxng
 from open_webui.apps.rag.search.serper import search_serper
@@ -33,10 +38,8 @@ from open_webui.apps.rag.utils import (
 )
 from open_webui.apps.webui.models.documents import DocumentForm, Documents
 from open_webui.apps.webui.models.files import Files
-from chromadb.utils.batch_utils import create_batches
 from open_webui.config import (
     BRAVE_SEARCH_API_KEY,
-    CHROMA_CLIENT,
     CHUNK_OVERLAP,
     CHUNK_SIZE,
     CONTENT_EXTRACTION_ENGINE,
@@ -84,9 +87,17 @@ from open_webui.config import (
     AppConfig,
 )
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import SRC_LOG_LEVELS
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
+from open_webui.env import SRC_LOG_LEVELS, DEVICE_TYPE
+from open_webui.utils.misc import (
+    calculate_sha256,
+    calculate_sha256_string,
+    extract_folders_after_data_docs,
+    sanitize_filename,
+)
+from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
+
+from chromadb.utils.batch_utils import create_batches
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import (
     BSHTMLLoader,
@@ -105,14 +116,6 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
-from pydantic import BaseModel
-from open_webui.utils.misc import (
-    calculate_sha256,
-    calculate_sha256_string,
-    extract_folders_after_data_docs,
-    sanitize_filename,
-)
-from open_webui.utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -143,13 +146,11 @@ app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SI
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 
-
 app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
 
-
 app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
 app.state.YOUTUBE_LOADER_TRANSLATION = None
 
@@ -998,12 +999,12 @@ def store_docs_in_vector_db(
 
     try:
         if overwrite:
-            for collection in CHROMA_CLIENT.list_collections():
+            for collection in VECTOR_DB_CLIENT.list_collections():
                 if collection_name == collection.name:
                     log.info(f"deleting existing collection {collection_name}")
-                    CHROMA_CLIENT.delete_collection(name=collection_name)
+                    VECTOR_DB_CLIENT.delete_collection(name=collection_name)
 
-        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        collection = VECTOR_DB_CLIENT.create_collection(name=collection_name)
 
         embedding_func = get_embedding_function(
             app.state.config.RAG_EMBEDDING_ENGINE,
@@ -1018,7 +1019,7 @@ def store_docs_in_vector_db(
         embeddings = embedding_func(embedding_texts)
 
         for batch in create_batches(
-            api=CHROMA_CLIENT,
+            api=VECTOR_DB_CLIENT,
             ids=[str(uuid.uuid4()) for _ in texts],
             metadatas=metadatas,
             embeddings=embeddings,
@@ -1396,7 +1397,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
 
 @app.post("/reset/db")
 def reset_vector_db(user=Depends(get_admin_user)):
-    CHROMA_CLIENT.reset()
+    VECTOR_DB_CLIENT.reset()
 
 
 @app.post("/reset/uploads")
@@ -1437,7 +1438,7 @@ def reset(user=Depends(get_admin_user)) -> bool:
             log.error("Failed to delete %s. Reason: %s" % (file_path, e))
 
     try:
-        CHROMA_CLIENT.reset()
+        VECTOR_DB_CLIENT.reset()
     except Exception as e:
         log.exception(e)
 

+ 16 - 13
backend/open_webui/apps/rag/utils.py

@@ -3,18 +3,23 @@ import os
 from typing import Optional, Union
 
 import requests
-from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
-)
-from open_webui.config import CHROMA_CLIENT
-from open_webui.env import SRC_LOG_LEVELS
+
 from huggingface_hub import snapshot_download
 from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
+
+
+from open_webui.apps.ollama.main import (
+    GenerateEmbeddingsForm,
+    generate_ollama_embeddings,
+)
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 
+from open_webui.env import SRC_LOG_LEVELS
+
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -26,12 +31,10 @@ def query_doc(
     k: int,
 ):
     try:
-        collection = CHROMA_CLIENT.get_collection(name=collection_name)
-        query_embeddings = embedding_function(query)
-
-        result = collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=k,
+        result = VECTOR_DB_CLIENT.query_collection(
+            name=collection_name,
+            query_embeddings=embedding_function(query),
+            k=k,
         )
 
         log.info(f"query_doc:result {result}")
@@ -49,7 +52,7 @@ def query_doc_with_hybrid_search(
     r: float,
 ):
     try:
-        collection = CHROMA_CLIENT.get_collection(name=collection_name)
+        collection = VECTOR_DB_CLIENT.get_collection(name=collection_name)
         documents = collection.get()  # get all documents
 
         bm25_retriever = BM25Retriever.from_texts(

+ 4 - 0
backend/open_webui/apps/rag/vector/connector.py

@@ -0,0 +1,4 @@
+from open_webui.apps.rag.vector.dbs.chroma import Chroma
+from open_webui.config import VECTOR_DB
+
+VECTOR_DB_CLIENT = Chroma()

+ 58 - 0
backend/open_webui/apps/rag/vector/dbs/chroma.py

@@ -0,0 +1,58 @@
+import chromadb
+from chromadb import Settings
+
+from open_webui.config import (
+    CHROMA_DATA_PATH,
+    CHROMA_HTTP_HOST,
+    CHROMA_HTTP_PORT,
+    CHROMA_HTTP_HEADERS,
+    CHROMA_HTTP_SSL,
+    CHROMA_TENANT,
+    CHROMA_DATABASE,
+)
+
+
+class Chroma:
+    def __init__(self):
+        if CHROMA_HTTP_HOST != "":
+            self.client = chromadb.HttpClient(
+                host=CHROMA_HTTP_HOST,
+                port=CHROMA_HTTP_PORT,
+                headers=CHROMA_HTTP_HEADERS,
+                ssl=CHROMA_HTTP_SSL,
+                tenant=CHROMA_TENANT,
+                database=CHROMA_DATABASE,
+                settings=Settings(allow_reset=True, anonymized_telemetry=False),
+            )
+        else:
+            self.client = chromadb.PersistentClient(
+                path=CHROMA_DATA_PATH,
+                settings=Settings(allow_reset=True, anonymized_telemetry=False),
+                tenant=CHROMA_TENANT,
+                database=CHROMA_DATABASE,
+            )
+
+    def query_collection(self, name, query_embeddings, k):
+        collection = self.client.get_collection(name=name)
+        if collection:
+            result = collection.query(
+                query_embeddings=[query_embeddings],
+                n_results=k,
+            )
+            return result
+        return None
+
+    def list_collections(self):
+        return self.client.list_collections()
+
+    def create_collection(self, name):
+        return self.client.create_collection(name=name)
+
+    def get_or_create_collection(self, name):
+        return self.client.get_or_create_collection(name=name)
+
+    def delete_collection(self, name):
+        return self.client.delete_collection(name=name)
+
+    def reset(self):
+        return self.client.reset()

+ 0 - 0
backend/open_webui/apps/rag/vector/dbs/milvus.py


+ 18 - 11
backend/open_webui/apps/webui/routers/memories.py

@@ -1,12 +1,13 @@
+from fastapi import APIRouter, Depends, HTTPException, Request
+from pydantic import BaseModel
 import logging
 from typing import Optional
 
 from open_webui.apps.webui.models.memories import Memories, MemoryModel
-from open_webui.config import CHROMA_CLIENT
-from open_webui.env import SRC_LOG_LEVELS
-from fastapi import APIRouter, Depends, HTTPException, Request
-from pydantic import BaseModel
+from open_webui.apps.rag.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.utils import get_verified_user
+from open_webui.env import SRC_LOG_LEVELS
+
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -51,7 +52,9 @@ async def add_memory(
     memory = Memories.insert_new_memory(user.id, form_data.content)
     memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
 
-    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+    collection = VECTOR_DB_CLIENT.get_or_create_collection(
+        name=f"user-memory-{user.id}"
+    )
     collection.upsert(
         documents=[memory.content],
         ids=[memory.id],
@@ -77,7 +80,9 @@ async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):
     query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+    collection = VECTOR_DB_CLIENT.get_or_create_collection(
+        name=f"user-memory-{user.id}"
+    )
 
     results = collection.query(
         query_embeddings=[query_embedding],
@@ -94,8 +99,10 @@ async def query_memory(
 async def reset_memory_from_vector_db(
     request: Request, user=Depends(get_verified_user)
 ):
-    CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
-    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+    VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
+    collection = VECTOR_DB_CLIENT.get_or_create_collection(
+        name=f"user-memory-{user.id}"
+    )
 
     memories = Memories.get_memories_by_user_id(user.id)
     for memory in memories:
@@ -119,7 +126,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
 
     if result:
         try:
-            CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
+            VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
         except Exception as e:
             log.error(e)
         return True
@@ -145,7 +152,7 @@ async def update_memory_by_id(
 
     if form_data.content is not None:
         memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
-        collection = CHROMA_CLIENT.get_or_create_collection(
+        collection = VECTOR_DB_CLIENT.get_or_create_collection(
             name=f"user-memory-{user.id}"
         )
         collection.upsert(
@@ -170,7 +177,7 @@ async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
     result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
 
     if result:
-        collection = CHROMA_CLIENT.get_or_create_collection(
+        collection = VECTOR_DB_CLIENT.get_or_create_collection(
             name=f"user-memory-{user.id}"
         )
         collection.delete(ids=[memory_id])

+ 19 - 16
backend/open_webui/config.py

@@ -11,7 +11,6 @@ import chromadb
 import requests
 import yaml
 from open_webui.apps.webui.internal.db import Base, get_db
-from chromadb import Settings
 from open_webui.env import (
     OPEN_WEBUI_DIR,
     DATA_DIR,
@@ -926,22 +925,9 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
 # RAG document content extraction
 ####################################
 
-CONTENT_EXTRACTION_ENGINE = PersistentConfig(
-    "CONTENT_EXTRACTION_ENGINE",
-    "rag.CONTENT_EXTRACTION_ENGINE",
-    os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
-)
-
-TIKA_SERVER_URL = PersistentConfig(
-    "TIKA_SERVER_URL",
-    "rag.tika_server_url",
-    os.getenv("TIKA_SERVER_URL", "http://tika:9998"),  # Default for sidecar deployment
-)
-
-####################################
-# RAG
-####################################
+VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")
 
+# Chroma
 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
 CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
 CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
@@ -958,6 +944,23 @@ else:
 CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
 
+####################################
+# RAG
+####################################
+
+# RAG Content Extraction
+CONTENT_EXTRACTION_ENGINE = PersistentConfig(
+    "CONTENT_EXTRACTION_ENGINE",
+    "rag.CONTENT_EXTRACTION_ENGINE",
+    os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
+)
+
+TIKA_SERVER_URL = PersistentConfig(
+    "TIKA_SERVER_URL",
+    "rag.tika_server_url",
+    os.getenv("TIKA_SERVER_URL", "http://tika:9998"),  # Default for sidecar deployment
+)
+
 RAG_TOP_K = PersistentConfig(
     "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
 )