|
@@ -4,7 +4,7 @@ import json
|
|
|
|
|
|
from typing import Optional
|
|
from typing import Optional
|
|
|
|
|
|
-from open_webui.apps.rag.vector.main import VectorItem, QueryResult
|
|
|
|
|
|
+from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
|
|
from open_webui.config import (
|
|
from open_webui.config import (
|
|
MILVUS_URI,
|
|
MILVUS_URI,
|
|
)
|
|
)
|
|
@@ -15,7 +15,7 @@ class MilvusClient:
|
|
self.collection_prefix = "open_webui"
|
|
self.collection_prefix = "open_webui"
|
|
self.client = Client(uri=MILVUS_URI)
|
|
self.client = Client(uri=MILVUS_URI)
|
|
|
|
|
|
- def _result_to_query_result(self, result) -> QueryResult:
|
|
|
|
|
|
+ def _result_to_query_result(self, result) -> SearchResult:
|
|
print(result)
|
|
print(result)
|
|
|
|
|
|
ids = []
|
|
ids = []
|
|
@@ -40,12 +40,14 @@ class MilvusClient:
|
|
documents.append(_documents)
|
|
documents.append(_documents)
|
|
metadatas.append(_metadatas)
|
|
metadatas.append(_metadatas)
|
|
|
|
|
|
- return {
|
|
|
|
- "ids": ids,
|
|
|
|
- "distances": distances,
|
|
|
|
- "documents": documents,
|
|
|
|
- "metadatas": metadatas,
|
|
|
|
- }
|
|
|
|
|
|
+ return SearchResult(
|
|
|
|
+ **{
|
|
|
|
+ "ids": ids,
|
|
|
|
+ "distances": distances,
|
|
|
|
+ "documents": documents,
|
|
|
|
+ "metadatas": metadatas,
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
|
|
def _create_collection(self, collection_name: str, dimension: int):
|
|
def _create_collection(self, collection_name: str, dimension: int):
|
|
schema = self.client.create_schema(
|
|
schema = self.client.create_schema(
|
|
@@ -94,7 +96,7 @@ class MilvusClient:
|
|
|
|
|
|
def search(
|
|
def search(
|
|
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
|
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
|
- ) -> Optional[QueryResult]:
|
|
|
|
|
|
+ ) -> Optional[SearchResult]:
|
|
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
|
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
|
result = self.client.search(
|
|
result = self.client.search(
|
|
collection_name=f"{self.collection_prefix}_{collection_name}",
|
|
collection_name=f"{self.collection_prefix}_{collection_name}",
|
|
@@ -105,10 +107,11 @@ class MilvusClient:
|
|
|
|
|
|
return self._result_to_query_result(result)
|
|
return self._result_to_query_result(result)
|
|
|
|
|
|
- def get(self, collection_name: str) -> Optional[QueryResult]:
|
|
|
|
|
|
+ def get(self, collection_name: str) -> Optional[GetResult]:
|
|
# Get all the items in the collection.
|
|
# Get all the items in the collection.
|
|
result = self.client.query(
|
|
result = self.client.query(
|
|
collection_name=f"{self.collection_prefix}_{collection_name}",
|
|
collection_name=f"{self.collection_prefix}_{collection_name}",
|
|
|
|
+ filter='id != ""',
|
|
)
|
|
)
|
|
return self._result_to_query_result(result)
|
|
return self._result_to_query_result(result)
|
|
|
|
|