Timothy J. Baek 7 月之前
父节点
当前提交
939bfd153e

+ 9 - 7
backend/open_webui/apps/rag/utils.py

@@ -48,9 +48,9 @@ class VectorSearchRetriever(BaseRetriever):
             limit=self.top_k,
         )
 
-        ids = result["ids"][0]
-        metadatas = result["metadatas"][0]
-        documents = result["documents"][0]
+        ids = result.ids[0]
+        metadatas = result.metadatas[0]
+        documents = result.documents[0]
 
         results = []
         for idx in range(len(ids)):
@@ -194,7 +194,7 @@ def query_collection(
                     k=k,
                     embedding_function=embedding_function,
                 )
-                results.append(result)
+                results.append(result.model_dump())
             except Exception as e:
                 log.exception(f"Error when querying the collection: {e}")
         else:
@@ -212,7 +212,7 @@ def query_collection_with_hybrid_search(
     r: float,
 ) -> dict:
     results = []
-    failed = 0
+    error = False
     for collection_name in collection_names:
         try:
             result = query_doc_with_hybrid_search(
@@ -228,12 +228,14 @@ def query_collection_with_hybrid_search(
             log.exception(
                 "Error when querying the collection with " f"hybrid_search: {e}"
             )
-            failed += 1
-    if failed == len(collection_names):
+            error = True
+
+    if error:
         raise Exception(
             "Hybrid search failed for all collections. Using "
             "Non hybrid search as fallback."
         )
+
     return merge_and_sort_query_results(results, k=k, reverse=True)
 
 

+ 22 - 10
backend/open_webui/apps/rag/vector/dbs/chroma.py

@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
 
 from typing import Optional
 
-from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -47,7 +47,7 @@ class ChromaClient:
 
     def search(
         self, collection_name: str, vectors: list[list[float | int]], limit: int
-    ) -> Optional[QueryResult]:
+    ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         collection = self.client.get_collection(name=collection_name)
         if collection:
@@ -56,19 +56,31 @@ class ChromaClient:
                 n_results=limit,
             )
 
-            return {
-                "ids": result["ids"],
-                "distances": result["distances"],
-                "documents": result["documents"],
-                "metadatas": result["metadatas"],
-            }
+            return SearchResult(
+                **{
+                    "ids": result["ids"],
+                    "distances": result["distances"],
+                    "documents": result["documents"],
+                    "metadatas": result["metadatas"],
+                }
+            )
         return None
 
-    def get(self, collection_name: str) -> Optional[QueryResult]:
+    def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         collection = self.client.get_collection(name=collection_name)
         if collection:
-            return collection.get()
+            result = collection.get()
+
+            # get() results carry no "distances" (GetResult has no such field);
+            # each flat list is wrapped once to fit the list-of-lists fields.
+            return GetResult(
+                **{
+                    "ids": [result["ids"]],
+                    "documents": [result["documents"]],
+                    "metadatas": [result["metadatas"]],
+                }
+            )
         return None
 
     def insert(self, collection_name: str, items: list[VectorItem]):

+ 13 - 10
backend/open_webui/apps/rag/vector/dbs/milvus.py

@@ -4,7 +4,7 @@ import json
 
 from typing import Optional
 
-from open_webui.apps.rag.vector.main import VectorItem, QueryResult
+from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     MILVUS_URI,
 )
@@ -15,7 +15,7 @@ class MilvusClient:
         self.collection_prefix = "open_webui"
         self.client = Client(uri=MILVUS_URI)
 
-    def _result_to_query_result(self, result) -> QueryResult:
+    def _result_to_query_result(self, result) -> SearchResult:
         print(result)
 
         ids = []
@@ -40,12 +40,14 @@ class MilvusClient:
             documents.append(_documents)
             metadatas.append(_metadatas)
 
-        return {
-            "ids": ids,
-            "distances": distances,
-            "documents": documents,
-            "metadatas": metadatas,
-        }
+        return SearchResult(
+            **{
+                "ids": ids,
+                "distances": distances,
+                "documents": documents,
+                "metadatas": metadatas,
+            }
+        )
 
     def _create_collection(self, collection_name: str, dimension: int):
         schema = self.client.create_schema(
@@ -94,7 +96,7 @@ class MilvusClient:
 
     def search(
         self, collection_name: str, vectors: list[list[float | int]], limit: int
-    ) -> Optional[QueryResult]:
+    ) -> Optional[SearchResult]:
         # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
         result = self.client.search(
             collection_name=f"{self.collection_prefix}_{collection_name}",
@@ -105,10 +107,11 @@ class MilvusClient:
 
         return self._result_to_query_result(result)
 
-    def get(self, collection_name: str) -> Optional[QueryResult]:
+    def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         result = self.client.query(
             collection_name=f"{self.collection_prefix}_{collection_name}",
+            # Presumably meant to match every row (any non-empty id) - confirm.
             filter='id != ""',
         )
+        # NOTE(review): this reuses the converter that search() also returns
+        # through, which is annotated as producing a SearchResult; confirm that
+        # query() output has the shape that converter expects.
         return self._result_to_query_result(result)
 

+ 5 - 2
backend/open_webui/apps/rag/vector/main.py

@@ -9,8 +9,11 @@ class VectorItem(BaseModel):
     metadata: Any
 
 
-class QueryResult(BaseModel):
+class GetResult(BaseModel):
+    # Result model without distances. Each field is a list of lists and may be
+    # None; the outer list presumably holds one entry per query - confirm
+    # against the vector DB clients.
     ids: Optional[List[List[str]]]
-    distances: Optional[List[List[float | int]]]
     documents: Optional[List[List[str]]]
     metadatas: Optional[List[List[Any]]]
+
+
+# Extends GetResult with a `distances` field of the same list-of-lists form.
+class SearchResult(GetResult):
+    distances: Optional[List[List[float | int]]]