utils.py

import logging
import requests
from typing import List

from apps.ollama.main import (
    generate_ollama_embeddings,
    GenerateEmbeddingsForm,
)

from config import SRC_LOG_LEVELS, CHROMA_CLIENT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
    # Query a single Chroma collection for the k nearest neighbours of a
    # pre-computed query embedding
    try:
        log.info(f"query_embeddings_doc {query_embeddings}")
        collection = CHROMA_CLIENT.get_collection(name=collection_name)
        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )
        log.info(f"query_embeddings_doc:result {result}")
        return result
    except Exception:
        # Re-raise with the original traceback; callers decide how to handle it
        raise
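
# Usage sketch (hypothetical collection name; assumes a populated Chroma
# collection and an embedding of matching dimensionality):
#
#     embeddings = embedding_function.encode("what is RAG?").tolist()
#     result = query_embeddings_doc("my-docs", "what is RAG?", embeddings, k=4)
#     # result["documents"][0] then holds the k closest chunks
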
def merge_and_sort_query_results(query_results, k):
    # Initialize lists to store combined data
    combined_ids = []
    combined_distances = []
    combined_metadatas = []
    combined_documents = []

    # Combine data from each result dictionary
    for data in query_results:
        combined_ids.extend(data["ids"][0])
        combined_distances.extend(data["distances"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_documents.extend(data["documents"][0])

    # Create a list of tuples (distance, id, metadata, document)
    combined = list(
        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
    )

    # Sort the list by ascending distance (closest matches first)
    combined.sort(key=lambda x: x[0])

    # Unzip the sorted list; guard against the case where every query failed
    # and there is nothing to merge, which would make zip(*combined) raise
    if combined:
        sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
    else:
        sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = [], [], [], []

    # Slice the lists to include only the top k elements
    sorted_distances = list(sorted_distances)[:k]
    sorted_ids = list(sorted_ids)[:k]
    sorted_metadatas = list(sorted_metadatas)[:k]
    sorted_documents = list(sorted_documents)[:k]

    # Create the output dictionary in the same shape Chroma returns
    merged_query_results = {
        "ids": [sorted_ids],
        "distances": [sorted_distances],
        "metadatas": [sorted_metadatas],
        "documents": [sorted_documents],
        "embeddings": None,
        "uris": None,
        "data": None,
    }

    return merged_query_results
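
# Usage sketch (toy values): merging two single-hit results keeps the closer
# match first, so the snippet below yields ids ["b", "a"]:
#
#     r1 = {"ids": [["a"]], "distances": [[0.9]], "metadatas": [[{}]], "documents": [["A"]]}
#     r2 = {"ids": [["b"]], "distances": [[0.1]], "metadatas": [[{}]], "documents": [["B"]]}
#     merged = merge_and_sort_query_results([r1, r2], k=2)
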
def query_embeddings_collection(
    collection_names: List[str], query: str, query_embeddings, k: int
):
    results = []
    log.info(f"query_embeddings_collection {query_embeddings}")

    for collection_name in collection_names:
        try:
            result = query_embeddings_doc(
                collection_name=collection_name,
                query=query,
                query_embeddings=query_embeddings,
                k=k,
            )
            results.append(result)
        except Exception as e:
            # Skip collections that are missing or fail to query, but log the
            # reason instead of silently swallowing the error
            log.exception(e)

    return merge_and_sort_query_results(results, k)
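
# Usage sketch (hypothetical collection names): queries each collection with
# the same embedding and returns one merged, distance-sorted result set:
#
#     merged = query_embeddings_collection(
#         collection_names=["docs-a", "docs-b"],
#         query="what is RAG?",
#         query_embeddings=embeddings,
#         k=4,
#     )
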
def rag_template(template: str, context: str, query: str):
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
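
# Usage sketch: with the template "Use [context] to answer: [query]",
# rag_template(template, "Paris is in France.", "Where is Paris?") returns
# "Use Paris is in France. to answer: Where is Paris?".
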
def rag_messages(
    docs,
    messages,
    template,
    k,
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
):
    log.debug(
        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
    )

    # Find the most recent user message; that is the query we retrieve context for
    last_user_message_idx = None
    for i in range(len(messages) - 1, -1, -1):
        if messages[i]["role"] == "user":
            last_user_message_idx = i
            break

    # Nothing to augment if there is no user message
    if last_user_message_idx is None:
        return messages

    user_message = messages[last_user_message_idx]

    if isinstance(user_message["content"], list):
        # Handle list content input
        content_type = "list"
        query = ""
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                query = content_item["text"]
                break
    elif isinstance(user_message["content"], str):
        # Handle text content input
        content_type = "text"
        query = user_message["content"]
    else:
        # Fallback in case the input does not match expected types
        content_type = None
        query = ""

    relevant_contexts = []

    for doc in docs:
        context = None

        try:
            if doc["type"] == "text":
                context = doc["content"]
            else:
                # Embed the query with the configured engine, then search
                if embedding_engine == "":
                    query_embeddings = embedding_function.encode(query).tolist()
                elif embedding_engine == "ollama":
                    query_embeddings = generate_ollama_embeddings(
                        GenerateEmbeddingsForm(
                            **{
                                "model": embedding_model,
                                "prompt": query,
                            }
                        )
                    )
                elif embedding_engine == "openai":
                    query_embeddings = generate_openai_embeddings(
                        model=embedding_model,
                        text=query,
                        key=openai_key,
                        url=openai_url,
                    )

                if doc["type"] == "collection":
                    context = query_embeddings_collection(
                        collection_names=doc["collection_names"],
                        query=query,
                        query_embeddings=query_embeddings,
                        k=k,
                    )
                else:
                    context = query_embeddings_doc(
                        collection_name=doc["collection_name"],
                        query=query,
                        query_embeddings=query_embeddings,
                        k=k,
                    )
        except Exception as e:
            log.exception(e)
            context = None

        relevant_contexts.append(context)

    log.debug(f"relevant_contexts: {relevant_contexts}")

    context_string = ""
    for context in relevant_contexts:
        if context:
            if isinstance(context, str):
                # "text" docs contribute their raw content directly
                context_string += context + "\n"
            else:
                # Query results carry their chunks under "documents"
                context_string += " ".join(context["documents"][0]) + "\n"

    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                # Update the text item's content with ra_content
                new_content.append({"type": "text", "text": ra_content})
            else:
                # Keep other types of content as they are
                new_content.append(content_item)

        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {
            **user_message,
            "content": ra_content,
        }

    messages[last_user_message_idx] = new_user_message
    return messages
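
# Usage sketch (hypothetical values; sentence_transformer_model stands in for a
# loaded SentenceTransformer instance): injects retrieved context into the last
# user message via the template before the chat request is sent:
#
#     messages = rag_messages(
#         docs=[{"type": "collection", "collection_names": ["my-docs"]}],
#         messages=[{"role": "user", "content": "Where is Paris?"}],
#         template="Use [context] to answer: [query]",
#         k=4,
#         embedding_engine="",
#         embedding_model="all-MiniLM-L6-v2",
#         embedding_function=sentence_transformer_model,
#         openai_key="",
#         openai_url="",
#     )
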
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return data["data"][0]["embedding"]
        else:
            # Raising a bare string is a TypeError in Python 3; raise a real exception
            raise ValueError("Something went wrong :/")
    except Exception as e:
        log.exception(e)
        return None
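
# Usage sketch (requires a valid API key; the key below is a placeholder):
#
#     embedding = generate_openai_embeddings(
#         model="text-embedding-ada-002",
#         text="what is RAG?",
#         key="sk-...",
#     )
#     # Returns a list of floats on success, or None on failure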