Timothy Jaeryang Baek 4 hónapja
szülő
commit
fd0170c179
1 módosított fájl, 93 hozzáadás és 79 törlés
  1. 93 79
      backend/open_webui/retrieval/utils.py

+ 93 - 79
backend/open_webui/retrieval/utils.py

@@ -1,8 +1,9 @@
 import logging
 import os
-import heapq
+import uuid
 from typing import Optional, Union
 
+import asyncio
 import requests
 
 from huggingface_hub import snapshot_download
@@ -33,6 +34,8 @@ class VectorSearchRetriever(BaseRetriever):
     def _get_relevant_documents(
         self,
         query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
@@ -44,12 +47,15 @@ class VectorSearchRetriever(BaseRetriever):
         metadatas = result.metadatas[0]
         documents = result.documents[0]
 
-        return [
-            Document(
-                metadata=metadatas[idx],
-                page_content=documents[idx],
-            ) for idx in range(len(ids))
-        ]
+        results = []
+        for idx in range(len(ids)):
+            results.append(
+                Document(
+                    metadata=metadatas[idx],
+                    page_content=documents[idx],
+                )
+            )
+        return results
 
 
 def query_doc(
@@ -58,14 +64,16 @@ def query_doc(
     k: int,
 ):
     try:
-        if result := VECTOR_DB_CLIENT.search(
+        result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
             vectors=[query_embedding],
             limit=k,
-        ):
+        )
+
+        if result:
             log.info(f"query_doc:result {result.ids} {result.metadatas}")
 
-            return result
+        return result
     except Exception as e:
         print(e)
         raise e
@@ -127,38 +135,44 @@ def query_doc_with_hybrid_search(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
 ) -> list[dict]:
-    if not query_results:
-        return {
-            "distances": [[]],
-            "documents": [[]],
-            "metadatas": [[]],
-        }
-        
-    combined = (
-        (data.get("distances", [float('inf')])[0],
-         data.get("documents", [None])[0],
-         data.get("metadatas", [{}])[0])
-        for data in query_results
-    )
-    
-    if reverse:
-        top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
-    else:
-        top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])
-    
-    if not top_k:
-        return {
-            "distances": [[]],
-            "documents": [[]],
-            "metadatas": [[]],
-        }
+    # Initialize lists to store combined data
+    combined_distances = []
+    combined_documents = []
+    combined_metadatas = []
+
+    for data in query_results:
+        combined_distances.extend(data["distances"][0])
+        combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])
+
+    # Create a list of tuples (distance, document, metadata)
+    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
+
+    # Sort the list based on distances
+    combined.sort(key=lambda x: x[0], reverse=reverse)
+
+    # We don't have anything :-(
+    if not combined:
+        sorted_distances = []
+        sorted_documents = []
+        sorted_metadatas = []
     else:
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
-        return {
-            "distances": [sorted_distances],
-            "documents": [sorted_documents],
-            "metadatas": [sorted_metadatas],
-        }
+        # Unzip the sorted list
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
+
+        # Slicing the lists to include only k elements
+        sorted_distances = list(sorted_distances)[:k]
+        sorted_documents = list(sorted_documents)[:k]
+        sorted_metadatas = list(sorted_metadatas)[:k]
+
+    # Create the output dictionary
+    result = {
+        "distances": [sorted_distances],
+        "documents": [sorted_documents],
+        "metadatas": [sorted_metadatas],
+    }
+
+    return result
 
 
 def query_collection(
@@ -171,18 +185,19 @@ def query_collection(
     for query in queries:
         query_embedding = embedding_function(query)
         for collection_name in collection_names:
-            if not collection_name:
-                continue
-            
-            try:
-                if result := query_doc(
-                    collection_name=collection_name,
-                    k=k,
-                    query_embedding=query_embedding,
-                ):
-                    results.append(result.model_dump())
-            except Exception as e:
-                log.exception(f"Error when querying the collection: {e}")
+            if collection_name:
+                try:
+                    result = query_doc(
+                        collection_name=collection_name,
+                        k=k,
+                        query_embedding=query_embedding,
+                    )
+                    if result is not None:
+                        results.append(result.model_dump())
+                except Exception as e:
+                    log.exception(f"Error when querying the collection: {e}")
+            else:
+                pass
 
     return merge_and_sort_query_results(results, k=k)
 
@@ -198,8 +213,8 @@ def query_collection_with_hybrid_search(
     results = []
     error = False
     for collection_name in collection_names:
-        for query in queries:
-            try:
+        try:
+            for query in queries:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
                     query=query,
@@ -209,11 +224,11 @@ def query_collection_with_hybrid_search(
                     r=r,
                 )
                 results.append(result)
-            except Exception as e:
-                log.exception(
-                    "Error when querying the collection with " f"hybrid_search: {e}"
-                )
-                error = True
+        except Exception as e:
+            log.exception(
+                "Error when querying the collection with " f"hybrid_search: {e}"
+            )
+            error = True
 
     if error:
         raise Exception(
@@ -244,10 +259,10 @@ def get_embedding_function(
 
         def generate_multiple(query, func):
             if isinstance(query, list):
-                return [
-                    func(query[i : i + embedding_batch_size])
-                    for i in range(0, len(query), embedding_batch_size)
-                ]
+                embeddings = []
+                for i in range(0, len(query), embedding_batch_size):
+                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                return embeddings
             else:
                 return func(query)
 
@@ -421,26 +436,25 @@ def generate_openai_batch_embeddings(
 def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
-    r = requests.post(
-        f"{url}/api/embed",
-        headers={
-            "Content-Type": "application/json",
-            "Authorization": f"Bearer {key}",
-        },
-        json={"input": texts, "model": model},
-    )
     try:
+        r = requests.post(
+            f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": texts, "model": model},
+        )
         r.raise_for_status()
+        data = r.json()
+
+        if "embeddings" in data:
+            return data["embeddings"]
+        else:
+            raise "Something went wrong :/"
     except Exception as e:
         print(e)
         return None
-    
-    data = r.json()
-    
-    if 'embeddings' not in data:
-        raise "Something went wrong :/"
-    
-    return data['embeddings']
 
 
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):