utils.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import re
  2. from typing import List
  3. from config import CHROMA_CLIENT
  4. def query_doc(collection_name: str, query: str, k: int, embedding_function):
  5. try:
  6. # if you use docker use the model from the environment variable
  7. collection = CHROMA_CLIENT.get_collection(
  8. name=collection_name,
  9. embedding_function=embedding_function,
  10. )
  11. result = collection.query(
  12. query_texts=[query],
  13. n_results=k,
  14. )
  15. return result
  16. except Exception as e:
  17. raise e
  18. def merge_and_sort_query_results(query_results, k):
  19. # Initialize lists to store combined data
  20. combined_ids = []
  21. combined_distances = []
  22. combined_metadatas = []
  23. combined_documents = []
  24. # Combine data from each dictionary
  25. for data in query_results:
  26. combined_ids.extend(data["ids"][0])
  27. combined_distances.extend(data["distances"][0])
  28. combined_metadatas.extend(data["metadatas"][0])
  29. combined_documents.extend(data["documents"][0])
  30. # Create a list of tuples (distance, id, metadata, document)
  31. combined = list(
  32. zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
  33. )
  34. # Sort the list based on distances
  35. combined.sort(key=lambda x: x[0])
  36. # Unzip the sorted list
  37. sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
  38. # Slicing the lists to include only k elements
  39. sorted_distances = list(sorted_distances)[:k]
  40. sorted_ids = list(sorted_ids)[:k]
  41. sorted_metadatas = list(sorted_metadatas)[:k]
  42. sorted_documents = list(sorted_documents)[:k]
  43. # Create the output dictionary
  44. merged_query_results = {
  45. "ids": [sorted_ids],
  46. "distances": [sorted_distances],
  47. "metadatas": [sorted_metadatas],
  48. "documents": [sorted_documents],
  49. "embeddings": None,
  50. "uris": None,
  51. "data": None,
  52. }
  53. return merged_query_results
  54. def query_collection(
  55. collection_names: List[str], query: str, k: int, embedding_function
  56. ):
  57. results = []
  58. for collection_name in collection_names:
  59. try:
  60. # if you use docker use the model from the environment variable
  61. collection = CHROMA_CLIENT.get_collection(
  62. name=collection_name,
  63. embedding_function=embedding_function,
  64. )
  65. result = collection.query(
  66. query_texts=[query],
  67. n_results=k,
  68. )
  69. results.append(result)
  70. except:
  71. pass
  72. return merge_and_sort_query_results(results, k)
  73. def rag_template(template: str, context: str, query: str):
  74. template = re.sub(r"\[context\]", context, template)
  75. template = re.sub(r"\[query\]", query, template)
  76. return template