utils.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import re
  2. from typing import List
  3. from config import CHROMA_CLIENT
  4. def query_doc(collection_name: str, query: str, k: int, embedding_function):
  5. try:
  6. # if you use docker use the model from the environment variable
  7. collection = CHROMA_CLIENT.get_collection(
  8. name=collection_name,
  9. embedding_function=embedding_function,
  10. )
  11. result = collection.query(
  12. query_texts=[query],
  13. n_results=k,
  14. )
  15. return result
  16. except Exception as e:
  17. raise e
  18. def merge_and_sort_query_results(query_results, k):
  19. # Initialize lists to store combined data
  20. combined_ids = []
  21. combined_distances = []
  22. combined_metadatas = []
  23. combined_documents = []
  24. # Combine data from each dictionary
  25. for data in query_results:
  26. combined_ids.extend(data["ids"][0])
  27. combined_distances.extend(data["distances"][0])
  28. combined_metadatas.extend(data["metadatas"][0])
  29. combined_documents.extend(data["documents"][0])
  30. # Create a list of tuples (distance, id, metadata, document)
  31. combined = list(
  32. zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
  33. )
  34. # Sort the list based on distances
  35. combined.sort(key=lambda x: x[0])
  36. # Unzip the sorted list
  37. sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
  38. # Slicing the lists to include only k elements
  39. sorted_distances = list(sorted_distances)[:k]
  40. sorted_ids = list(sorted_ids)[:k]
  41. sorted_metadatas = list(sorted_metadatas)[:k]
  42. sorted_documents = list(sorted_documents)[:k]
  43. # Create the output dictionary
  44. merged_query_results = {
  45. "ids": [sorted_ids],
  46. "distances": [sorted_distances],
  47. "metadatas": [sorted_metadatas],
  48. "documents": [sorted_documents],
  49. "embeddings": None,
  50. "uris": None,
  51. "data": None,
  52. }
  53. return merged_query_results
  54. def query_collection(
  55. collection_names: List[str], query: str, k: int, embedding_function
  56. ):
  57. results = []
  58. for collection_name in collection_names:
  59. try:
  60. # if you use docker use the model from the environment variable
  61. collection = CHROMA_CLIENT.get_collection(
  62. name=collection_name,
  63. embedding_function=embedding_function,
  64. )
  65. result = collection.query(
  66. query_texts=[query],
  67. n_results=k,
  68. )
  69. results.append(result)
  70. except:
  71. pass
  72. return merge_and_sort_query_results(results, k)
  73. def rag_template(template: str, context: str, query: str):
  74. template = re.sub(r"\[context\]", context, template)
  75. template = re.sub(r"\[query\]", query, template)
  76. return template
  77. def rag_messages(docs, messages, template, k, embedding_function):
  78. print(docs)
  79. last_user_message_idx = None
  80. for i in range(len(messages) - 1, -1, -1):
  81. if messages[i]["role"] == "user":
  82. last_user_message_idx = i
  83. break
  84. user_message = messages[last_user_message_idx]
  85. if isinstance(user_message["content"], list):
  86. # Handle list content input
  87. content_type = "list"
  88. query = ""
  89. for content_item in user_message["content"]:
  90. if content_item["type"] == "text":
  91. query = content_item["text"]
  92. break
  93. elif isinstance(user_message["content"], str):
  94. # Handle text content input
  95. content_type = "text"
  96. query = user_message["content"]
  97. else:
  98. # Fallback in case the input does not match expected types
  99. content_type = None
  100. query = ""
  101. relevant_contexts = []
  102. for doc in docs:
  103. context = None
  104. try:
  105. if doc["type"] == "collection":
  106. context = query_collection(
  107. collection_names=doc["collection_names"],
  108. query=query,
  109. k=k,
  110. embedding_function=embedding_function,
  111. )
  112. else:
  113. context = query_doc(
  114. collection_name=doc["collection_name"],
  115. query=query,
  116. k=k,
  117. embedding_function=embedding_function,
  118. )
  119. except Exception as e:
  120. print(e)
  121. context = None
  122. relevant_contexts.append(context)
  123. context_string = ""
  124. for context in relevant_contexts:
  125. if context:
  126. context_string += " ".join(context["documents"][0]) + "\n"
  127. ra_content = rag_template(
  128. template=template,
  129. context=context_string,
  130. query=query,
  131. )
  132. if content_type == "list":
  133. new_content = []
  134. for content_item in user_message["content"]:
  135. if content_item["type"] == "text":
  136. # Update the text item's content with ra_content
  137. new_content.append({"type": "text", "text": ra_content})
  138. else:
  139. # Keep other types of content as they are
  140. new_content.append(content_item)
  141. new_user_message = {**user_message, "content": new_content}
  142. else:
  143. new_user_message = {
  144. **user_message,
  145. "content": ra_content,
  146. }
  147. messages[last_user_message_idx] = new_user_message
  148. return messages