utils.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from typing import List
  2. from config import CHROMA_CLIENT
  3. def query_doc(collection_name: str, query: str, k: int, embedding_function):
  4. try:
  5. # if you use docker use the model from the environment variable
  6. collection = CHROMA_CLIENT.get_collection(
  7. name=collection_name,
  8. embedding_function=embedding_function,
  9. )
  10. result = collection.query(
  11. query_texts=[query],
  12. n_results=k,
  13. )
  14. return result
  15. except Exception as e:
  16. raise e
  17. def merge_and_sort_query_results(query_results, k):
  18. # Initialize lists to store combined data
  19. combined_ids = []
  20. combined_distances = []
  21. combined_metadatas = []
  22. combined_documents = []
  23. # Combine data from each dictionary
  24. for data in query_results:
  25. combined_ids.extend(data["ids"][0])
  26. combined_distances.extend(data["distances"][0])
  27. combined_metadatas.extend(data["metadatas"][0])
  28. combined_documents.extend(data["documents"][0])
  29. # Create a list of tuples (distance, id, metadata, document)
  30. combined = list(
  31. zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
  32. )
  33. # Sort the list based on distances
  34. combined.sort(key=lambda x: x[0])
  35. # Unzip the sorted list
  36. sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
  37. # Slicing the lists to include only k elements
  38. sorted_distances = list(sorted_distances)[:k]
  39. sorted_ids = list(sorted_ids)[:k]
  40. sorted_metadatas = list(sorted_metadatas)[:k]
  41. sorted_documents = list(sorted_documents)[:k]
  42. # Create the output dictionary
  43. merged_query_results = {
  44. "ids": [sorted_ids],
  45. "distances": [sorted_distances],
  46. "metadatas": [sorted_metadatas],
  47. "documents": [sorted_documents],
  48. "embeddings": None,
  49. "uris": None,
  50. "data": None,
  51. }
  52. return merged_query_results
  53. def query_collection(
  54. collection_names: List[str], query: str, k: int, embedding_function
  55. ):
  56. results = []
  57. for collection_name in collection_names:
  58. try:
  59. # if you use docker use the model from the environment variable
  60. collection = CHROMA_CLIENT.get_collection(
  61. name=collection_name,
  62. embedding_function=embedding_function,
  63. )
  64. result = collection.query(
  65. query_texts=[query],
  66. n_results=k,
  67. )
  68. results.append(result)
  69. except:
  70. pass
  71. return merge_and_sort_query_results(results, k)