Forráskód Böngészése

Merge pull request #8212 from ashm-dev/main

feat: Small optimization
Timothy Jaeryang Baek 4 hónapja
szülő
commit
9b56b64cfa
1 módosított fájl, 79 hozzáadás és 93 törlés
  1. backend/open_webui/retrieval/utils.py (+79 −93)

+ 79 - 93
backend/open_webui/retrieval/utils.py

@@ -1,9 +1,8 @@
 import logging
 import os
-import uuid
+import heapq
 from typing import Optional, Union
 
-import asyncio
 import requests
 
 from huggingface_hub import snapshot_download
@@ -34,8 +33,6 @@ class VectorSearchRetriever(BaseRetriever):
     def _get_relevant_documents(
         self,
         query: str,
-        *,
-        run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
@@ -47,15 +44,12 @@ class VectorSearchRetriever(BaseRetriever):
         metadatas = result.metadatas[0]
         documents = result.documents[0]
 
-        results = []
-        for idx in range(len(ids)):
-            results.append(
-                Document(
-                    metadata=metadatas[idx],
-                    page_content=documents[idx],
-                )
-            )
-        return results
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            ) for idx in range(len(ids))
+        ]
 
 
 def query_doc(
@@ -64,16 +58,14 @@ def query_doc(
     k: int,
 ):
     try:
-        result = VECTOR_DB_CLIENT.search(
+        if result := VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
             vectors=[query_embedding],
             limit=k,
-        )
-
-        if result:
+        ):
             log.info(f"query_doc:result {result.ids} {result.metadatas}")
 
-        return result
+            return result
     except Exception as e:
         print(e)
         raise e
@@ -135,44 +127,38 @@ def query_doc_with_hybrid_search(
 def merge_and_sort_query_results(
     query_results: list[dict], k: int, reverse: bool = False
 ) -> list[dict]:
-    # Initialize lists to store combined data
-    combined_distances = []
-    combined_documents = []
-    combined_metadatas = []
-
-    for data in query_results:
-        combined_distances.extend(data["distances"][0])
-        combined_documents.extend(data["documents"][0])
-        combined_metadatas.extend(data["metadatas"][0])
-
-    # Create a list of tuples (distance, document, metadata)
-    combined = list(zip(combined_distances, combined_documents, combined_metadatas))
-
-    # Sort the list based on distances
-    combined.sort(key=lambda x: x[0], reverse=reverse)
-
-    # We don't have anything :-(
-    if not combined:
-        sorted_distances = []
-        sorted_documents = []
-        sorted_metadatas = []
+    if not query_results:
+        return {
+            "distances": [[]],
+            "documents": [[]],
+            "metadatas": [[]],
+        }
+        
+    combined = (
+        (data.get("distances", [float('inf')])[0],
+         data.get("documents", [None])[0],
+         data.get("metadatas", [{}])[0])
+        for data in query_results
+    )
+    
+    if reverse:
+        top_k = heapq.nlargest(k, combined, key=lambda x: x[0])
     else:
-        # Unzip the sorted list
-        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
-
-        # Slicing the lists to include only k elements
-        sorted_distances = list(sorted_distances)[:k]
-        sorted_documents = list(sorted_documents)[:k]
-        sorted_metadatas = list(sorted_metadatas)[:k]
-
-    # Create the output dictionary
-    result = {
-        "distances": [sorted_distances],
-        "documents": [sorted_documents],
-        "metadatas": [sorted_metadatas],
-    }
-
-    return result
+        top_k = heapq.nsmallest(k, combined, key=lambda x: x[0])
+    
+    if not top_k:
+        return {
+            "distances": [[]],
+            "documents": [[]],
+            "metadatas": [[]],
+        }
+    else:
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*top_k)
+        return {
+            "distances": [sorted_distances],
+            "documents": [sorted_documents],
+            "metadatas": [sorted_metadatas],
+        }
 
 
 def query_collection(
@@ -185,19 +171,18 @@ def query_collection(
     for query in queries:
         query_embedding = embedding_function(query)
         for collection_name in collection_names:
-            if collection_name:
-                try:
-                    result = query_doc(
-                        collection_name=collection_name,
-                        k=k,
-                        query_embedding=query_embedding,
-                    )
-                    if result is not None:
-                        results.append(result.model_dump())
-                except Exception as e:
-                    log.exception(f"Error when querying the collection: {e}")
-            else:
-                pass
+            if not collection_name:
+                continue
+            
+            try:
+                if result := query_doc(
+                    collection_name=collection_name,
+                    k=k,
+                    query_embedding=query_embedding,
+                ):
+                    results.append(result.model_dump())
+            except Exception as e:
+                log.exception(f"Error when querying the collection: {e}")
 
     return merge_and_sort_query_results(results, k=k)
 
@@ -213,8 +198,8 @@ def query_collection_with_hybrid_search(
     results = []
     error = False
     for collection_name in collection_names:
-        try:
-            for query in queries:
+        for query in queries:
+            try:
                 result = query_doc_with_hybrid_search(
                     collection_name=collection_name,
                     query=query,
@@ -224,11 +209,11 @@ def query_collection_with_hybrid_search(
                     r=r,
                 )
                 results.append(result)
-        except Exception as e:
-            log.exception(
-                "Error when querying the collection with " f"hybrid_search: {e}"
-            )
-            error = True
+            except Exception as e:
+                log.exception(
+                    "Error when querying the collection with " f"hybrid_search: {e}"
+                )
+                error = True
 
     if error:
         raise Exception(
@@ -259,10 +244,10 @@ def get_embedding_function(
 
         def generate_multiple(query, func):
             if isinstance(query, list):
-                embeddings = []
-                for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
-                return embeddings
+                return [
+                    func(query[i : i + embedding_batch_size])
+                    for i in range(0, len(query), embedding_batch_size)
+                ]
             else:
                 return func(query)
 
@@ -436,25 +421,26 @@ def generate_openai_batch_embeddings(
 def generate_ollama_batch_embeddings(
     model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
+    r = requests.post(
+        f"{url}/api/embed",
+        headers={
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        },
+        json={"input": texts, "model": model},
+    )
     try:
-        r = requests.post(
-            f"{url}/api/embed",
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {key}",
-            },
-            json={"input": texts, "model": model},
-        )
         r.raise_for_status()
-        data = r.json()
-
-        if "embeddings" in data:
-            return data["embeddings"]
-        else:
-            raise "Something went wrong :/"
     except Exception as e:
         print(e)
         return None
+    
+    data = r.json()
+    
+    if 'embeddings' not in data:
+        raise "Something went wrong :/"
+    
+    return data['embeddings']
 
 
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):