|
@@ -1,9 +1,8 @@
|
|
|
import logging
|
|
|
import os
|
|
|
-import uuid
|
|
|
+import heapq
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
-import asyncio
|
|
|
import requests
|
|
|
|
|
|
from huggingface_hub import snapshot_download
|
|
@@ -34,8 +33,6 @@ class VectorSearchRetriever(BaseRetriever):
|
|
|
def _get_relevant_documents(
|
|
|
self,
|
|
|
query: str,
|
|
|
- *,
|
|
|
- run_manager: CallbackManagerForRetrieverRun,
|
|
|
) -> list[Document]:
|
|
|
result = VECTOR_DB_CLIENT.search(
|
|
|
collection_name=self.collection_name,
|
|
@@ -47,15 +44,12 @@ class VectorSearchRetriever(BaseRetriever):
|
|
|
metadatas = result.metadatas[0]
|
|
|
documents = result.documents[0]
|
|
|
|
|
|
- results = []
|
|
|
- for idx in range(len(ids)):
|
|
|
- results.append(
|
|
|
- Document(
|
|
|
- metadata=metadatas[idx],
|
|
|
- page_content=documents[idx],
|
|
|
- )
|
|
|
- )
|
|
|
- return results
|
|
|
+ return [
|
|
|
+ Document(
|
|
|
+ metadata=metadatas[idx],
|
|
|
+ page_content=documents[idx],
|
|
|
+ ) for idx in range(len(ids))
|
|
|
+ ]
|
|
|
|
|
|
|
|
|
def query_doc(
|
|
@@ -64,16 +58,14 @@ def query_doc(
|
|
|
k: int,
|
|
|
):
|
|
|
try:
|
|
|
- result = VECTOR_DB_CLIENT.search(
|
|
|
+ if result := VECTOR_DB_CLIENT.search(
|
|
|
collection_name=collection_name,
|
|
|
vectors=[query_embedding],
|
|
|
limit=k,
|
|
|
- )
|
|
|
-
|
|
|
- if result:
|
|
|
+ ):
|
|
|
log.info(f"query_doc:result {result.ids} {result.metadatas}")
|
|
|
|
|
|
- return result
|
|
|
+ return result
|
|
|
except Exception as e:
|
|
|
print(e)
|
|
|
raise e
|
|
@@ -135,44 +127,38 @@ def query_doc_with_hybrid_search(
|
|
|
def merge_and_sort_query_results(
|
|
|
query_results: list[dict], k: int, reverse: bool = False
|
|
|
) -> list[dict]:
|
|
|
- # Initialize lists to store combined data
|
|
|
- combined_distances = []
|
|
|
- combined_documents = []
|
|
|
- combined_metadatas = []
|
|
|
-
|
|
|
- for data in query_results:
|
|
|
- combined_distances.extend(data["distances"][0])
|
|
|
- combined_documents.extend(data["documents"][0])
|
|
|
- combined_metadatas.extend(data["metadatas"][0])
|
|
|
-
|
|
|
- # Create a list of tuples (distance, document, metadata)
|
|
|
- combined = list(zip(combined_distances, combined_documents, combined_metadatas))
|
|
|
-
|
|
|
- # Sort the list based on distances
|
|
|
- combined.sort(key=lambda x: x[0], reverse=reverse)
|
|
|
-
|
|
|
- # We don't have anything :-(
|
|
|
- if not combined:
|
|
|
- sorted_distances = []
|
|
|
- sorted_documents = []
|
|
|
- sorted_metadatas = []
|
|
|
+ if not query_results:
|
|
|
+ return {
|
|
|
+ "distances": [[]],
|
|
|
+ "documents": [[]],
|
|
|
+ "metadatas": [[]],
|
|
|
+ }
|
|
|
+
|
|
|
+ combined = (
|
|
|
+ (data.get("distances", [float('inf')])[0],
|
|
|
+ data.get("documents", [None])[0],
|
|
|
+ data.get("metadatas", [{}])[0])
|
|
|
+ for data in query_results
|
|
|
+ )
|
|
|
+
|
|
|
+ if reverse:
|
|
|
+ top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
|
|
|
else:
|
|
|
- # Unzip the sorted list
|
|
|
- sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
|
|
|
-
|
|
|
- # Slicing the lists to include only k elements
|
|
|
- sorted_distances = list(sorted_distances)[:k]
|
|
|
- sorted_documents = list(sorted_documents)[:k]
|
|
|
- sorted_metadatas = list(sorted_metadatas)[:k]
|
|
|
-
|
|
|
- # Create the output dictionary
|
|
|
- result = {
|
|
|
- "distances": [sorted_distances],
|
|
|
- "documents": [sorted_documents],
|
|
|
- "metadatas": [sorted_metadatas],
|
|
|
- }
|
|
|
-
|
|
|
- return result
|
|
|
+ top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])
|
|
|
+
|
|
|
+ if not top_k:
|
|
|
+ return {
|
|
|
+ "distances": [[]],
|
|
|
+ "documents": [[]],
|
|
|
+ "metadatas": [[]],
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
|
|
|
+ return {
|
|
|
+ "distances": [sorted_distances],
|
|
|
+ "documents": [sorted_documents],
|
|
|
+ "metadatas": [sorted_metadatas],
|
|
|
+ }
|
|
|
|
|
|
|
|
|
def query_collection(
|
|
@@ -185,19 +171,18 @@ def query_collection(
|
|
|
for query in queries:
|
|
|
query_embedding = embedding_function(query)
|
|
|
for collection_name in collection_names:
|
|
|
- if collection_name:
|
|
|
- try:
|
|
|
- result = query_doc(
|
|
|
- collection_name=collection_name,
|
|
|
- k=k,
|
|
|
- query_embedding=query_embedding,
|
|
|
- )
|
|
|
- if result is not None:
|
|
|
- results.append(result.model_dump())
|
|
|
- except Exception as e:
|
|
|
- log.exception(f"Error when querying the collection: {e}")
|
|
|
- else:
|
|
|
- pass
|
|
|
+ if not collection_name:
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ if result := query_doc(
|
|
|
+ collection_name=collection_name,
|
|
|
+ k=k,
|
|
|
+ query_embedding=query_embedding,
|
|
|
+ ):
|
|
|
+ results.append(result.model_dump())
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(f"Error when querying the collection: {e}")
|
|
|
|
|
|
return merge_and_sort_query_results(results, k=k)
|
|
|
|
|
@@ -213,8 +198,8 @@ def query_collection_with_hybrid_search(
|
|
|
results = []
|
|
|
error = False
|
|
|
for collection_name in collection_names:
|
|
|
- try:
|
|
|
- for query in queries:
|
|
|
+ for query in queries:
|
|
|
+ try:
|
|
|
result = query_doc_with_hybrid_search(
|
|
|
collection_name=collection_name,
|
|
|
query=query,
|
|
@@ -224,11 +209,11 @@ def query_collection_with_hybrid_search(
|
|
|
r=r,
|
|
|
)
|
|
|
results.append(result)
|
|
|
- except Exception as e:
|
|
|
- log.exception(
|
|
|
- "Error when querying the collection with " f"hybrid_search: {e}"
|
|
|
- )
|
|
|
- error = True
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(
|
|
|
+ "Error when querying the collection with " f"hybrid_search: {e}"
|
|
|
+ )
|
|
|
+ error = True
|
|
|
|
|
|
if error:
|
|
|
raise Exception(
|
|
@@ -259,10 +244,10 @@ def get_embedding_function(
|
|
|
|
|
|
def generate_multiple(query, func):
|
|
|
if isinstance(query, list):
|
|
|
- embeddings = []
|
|
|
- for i in range(0, len(query), embedding_batch_size):
|
|
|
- embeddings.extend(func(query[i : i + embedding_batch_size]))
|
|
|
- return embeddings
|
|
|
+ return [
|
|
|
+ func(query[i : i + embedding_batch_size])
|
|
|
+ for i in range(0, len(query), embedding_batch_size)
|
|
|
+ ]
|
|
|
else:
|
|
|
return func(query)
|
|
|
|
|
@@ -436,25 +421,26 @@ def generate_openai_batch_embeddings(
|
|
|
def generate_ollama_batch_embeddings(
|
|
|
model: str, texts: list[str], url: str, key: str = ""
|
|
|
) -> Optional[list[list[float]]]:
|
|
|
+ r = requests.post(
|
|
|
+ f"{url}/api/embed",
|
|
|
+ headers={
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ "Authorization": f"Bearer {key}",
|
|
|
+ },
|
|
|
+ json={"input": texts, "model": model},
|
|
|
+ )
|
|
|
try:
|
|
|
- r = requests.post(
|
|
|
- f"{url}/api/embed",
|
|
|
- headers={
|
|
|
- "Content-Type": "application/json",
|
|
|
- "Authorization": f"Bearer {key}",
|
|
|
- },
|
|
|
- json={"input": texts, "model": model},
|
|
|
- )
|
|
|
r.raise_for_status()
|
|
|
- data = r.json()
|
|
|
-
|
|
|
- if "embeddings" in data:
|
|
|
- return data["embeddings"]
|
|
|
- else:
|
|
|
- raise "Something went wrong :/"
|
|
|
except Exception as e:
|
|
|
print(e)
|
|
|
return None
|
|
|
+
|
|
|
+ data = r.json()
|
|
|
+
|
|
|
+ if 'embeddings' not in data:
|
|
|
+ raise "Something went wrong :/"
|
|
|
+
|
|
|
+ return data['embeddings']
|
|
|
|
|
|
|
|
|
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|