# utils.py
  1. import os
  2. import logging
  3. import requests
  4. from typing import List
  5. from apps.ollama.main import (
  6. generate_ollama_embeddings,
  7. GenerateEmbeddingsForm,
  8. )
  9. from huggingface_hub import snapshot_download
  10. from langchain_core.documents import Document
  11. from langchain_community.retrievers import BM25Retriever
  12. from langchain.retrievers import (
  13. ContextualCompressionRetriever,
  14. EnsembleRetriever,
  15. )
  16. from sentence_transformers import CrossEncoder
  17. from typing import Optional
  18. from config import SRC_LOG_LEVELS, CHROMA_CLIENT
# Module-level logger, filtered to the project-wide RAG log level.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
  21. def query_embeddings_doc(
  22. collection_name: str,
  23. query: str,
  24. embeddings_function,
  25. k: int,
  26. reranking_function: Optional[CrossEncoder] = None,
  27. r: Optional[float] = None,
  28. ):
  29. try:
  30. if reranking_function:
  31. # if you use docker use the model from the environment variable
  32. collection = CHROMA_CLIENT.get_collection(name=collection_name)
  33. documents = collection.get() # get all documents
  34. bm25_retriever = BM25Retriever.from_texts(
  35. texts=documents.get("documents"),
  36. metadatas=documents.get("metadatas"),
  37. )
  38. bm25_retriever.k = k
  39. chroma_retriever = ChromaRetriever(
  40. collection=collection,
  41. embeddings_function=embeddings_function,
  42. top_n=k,
  43. )
  44. ensemble_retriever = EnsembleRetriever(
  45. retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
  46. )
  47. compressor = RerankCompressor(
  48. embeddings_function=embeddings_function,
  49. reranking_function=reranking_function,
  50. r_score=r,
  51. top_n=k,
  52. )
  53. compression_retriever = ContextualCompressionRetriever(
  54. base_compressor=compressor, base_retriever=ensemble_retriever
  55. )
  56. result = compression_retriever.invoke(query)
  57. result = {
  58. "distances": [[d.metadata.get("score") for d in result]],
  59. "documents": [[d.page_content for d in result]],
  60. "metadatas": [[d.metadata for d in result]],
  61. }
  62. else:
  63. # if you use docker use the model from the environment variable
  64. query_embeddings = embeddings_function(query)
  65. log.info(f"query_embeddings_doc {query_embeddings}")
  66. collection = CHROMA_CLIENT.get_collection(name=collection_name)
  67. result = collection.query(
  68. query_embeddings=[query_embeddings],
  69. n_results=k,
  70. )
  71. log.info(f"query_embeddings_doc:result {result}")
  72. return result
  73. except Exception as e:
  74. raise e
  75. def merge_and_sort_query_results(query_results, k):
  76. # Initialize lists to store combined data
  77. combined_distances = []
  78. combined_documents = []
  79. combined_metadatas = []
  80. for data in query_results:
  81. combined_distances.extend(data["distances"][0])
  82. combined_documents.extend(data["documents"][0])
  83. combined_metadatas.extend(data["metadatas"][0])
  84. # Create a list of tuples (distance, document, metadata)
  85. combined = list(zip(combined_distances, combined_documents, combined_metadatas))
  86. # Sort the list based on distances
  87. combined.sort(key=lambda x: x[0])
  88. # We don't have anything :-(
  89. if not combined:
  90. sorted_distances = []
  91. sorted_documents = []
  92. sorted_metadatas = []
  93. else:
  94. # Unzip the sorted list
  95. sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)
  96. # Slicing the lists to include only k elements
  97. sorted_distances = list(sorted_distances)[:k]
  98. sorted_documents = list(sorted_documents)[:k]
  99. sorted_metadatas = list(sorted_metadatas)[:k]
  100. # Create the output dictionary
  101. result = {
  102. "distances": [sorted_distances],
  103. "documents": [sorted_documents],
  104. "metadatas": [sorted_metadatas],
  105. }
  106. return result
  107. def query_embeddings_collection(
  108. collection_names: List[str],
  109. query: str,
  110. k: int,
  111. r: float,
  112. embeddings_function,
  113. reranking_function,
  114. ):
  115. results = []
  116. for collection_name in collection_names:
  117. try:
  118. result = query_embeddings_doc(
  119. collection_name=collection_name,
  120. query=query,
  121. k=k,
  122. r=r,
  123. embeddings_function=embeddings_function,
  124. reranking_function=reranking_function,
  125. )
  126. results.append(result)
  127. except:
  128. pass
  129. return merge_and_sort_query_results(results, k)
  130. def rag_template(template: str, context: str, query: str):
  131. template = template.replace("[context]", context)
  132. template = template.replace("[query]", query)
  133. return template
  134. def query_embeddings_function(
  135. embedding_engine,
  136. embedding_model,
  137. embedding_function,
  138. openai_key,
  139. openai_url,
  140. ):
  141. if embedding_engine == "":
  142. return lambda query: embedding_function.encode(query).tolist()
  143. elif embedding_engine in ["ollama", "openai"]:
  144. if embedding_engine == "ollama":
  145. func = lambda query: generate_ollama_embeddings(
  146. GenerateEmbeddingsForm(
  147. **{
  148. "model": embedding_model,
  149. "prompt": query,
  150. }
  151. )
  152. )
  153. elif embedding_engine == "openai":
  154. func = lambda query: generate_openai_embeddings(
  155. model=embedding_model,
  156. text=query,
  157. key=openai_key,
  158. url=openai_url,
  159. )
  160. def generate_multiple(query, f):
  161. if isinstance(query, list):
  162. return [f(q) for q in query]
  163. else:
  164. return f(query)
  165. return lambda query: generate_multiple(query, func)
  166. def rag_messages(
  167. docs,
  168. messages,
  169. template,
  170. k,
  171. r,
  172. embedding_engine,
  173. embedding_model,
  174. embedding_function,
  175. reranking_function,
  176. openai_key,
  177. openai_url,
  178. ):
  179. log.debug(
  180. f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
  181. )
  182. last_user_message_idx = None
  183. for i in range(len(messages) - 1, -1, -1):
  184. if messages[i]["role"] == "user":
  185. last_user_message_idx = i
  186. break
  187. user_message = messages[last_user_message_idx]
  188. if isinstance(user_message["content"], list):
  189. # Handle list content input
  190. content_type = "list"
  191. query = ""
  192. for content_item in user_message["content"]:
  193. if content_item["type"] == "text":
  194. query = content_item["text"]
  195. break
  196. elif isinstance(user_message["content"], str):
  197. # Handle text content input
  198. content_type = "text"
  199. query = user_message["content"]
  200. else:
  201. # Fallback in case the input does not match expected types
  202. content_type = None
  203. query = ""
  204. embeddings_function = query_embeddings_function(
  205. embedding_engine,
  206. embedding_model,
  207. embedding_function,
  208. openai_key,
  209. openai_url,
  210. )
  211. extracted_collections = []
  212. relevant_contexts = []
  213. for doc in docs:
  214. context = None
  215. collection = doc.get("collection_name")
  216. if collection:
  217. collection = [collection]
  218. else:
  219. collection = doc.get("collection_names", [])
  220. collection = set(collection).difference(extracted_collections)
  221. if not collection:
  222. log.debug(f"skipping {doc} as it has already been extracted")
  223. continue
  224. try:
  225. if doc["type"] == "text":
  226. context = doc["content"]
  227. elif doc["type"] == "collection":
  228. context = query_embeddings_collection(
  229. collection_names=doc["collection_names"],
  230. query=query,
  231. k=k,
  232. r=r,
  233. embeddings_function=embeddings_function,
  234. reranking_function=reranking_function,
  235. )
  236. else:
  237. context = query_embeddings_doc(
  238. collection_name=doc["collection_name"],
  239. query=query,
  240. k=k,
  241. r=r,
  242. embeddings_function=embeddings_function,
  243. reranking_function=reranking_function,
  244. )
  245. except Exception as e:
  246. log.exception(e)
  247. context = None
  248. if context:
  249. relevant_contexts.append(context)
  250. extracted_collections.extend(collection)
  251. context_string = ""
  252. for context in relevant_contexts:
  253. items = context["documents"][0]
  254. context_string += "\n\n".join(items)
  255. context_string = context_string.strip()
  256. ra_content = rag_template(
  257. template=template,
  258. context=context_string,
  259. query=query,
  260. )
  261. log.debug(f"ra_content: {ra_content}")
  262. if content_type == "list":
  263. new_content = []
  264. for content_item in user_message["content"]:
  265. if content_item["type"] == "text":
  266. # Update the text item's content with ra_content
  267. new_content.append({"type": "text", "text": ra_content})
  268. else:
  269. # Keep other types of content as they are
  270. new_content.append(content_item)
  271. new_user_message = {**user_message, "content": new_content}
  272. else:
  273. new_user_message = {
  274. **user_message,
  275. "content": ra_content,
  276. }
  277. messages[last_user_message_idx] = new_user_message
  278. return messages
  279. def get_model_path(model: str, update_model: bool = False):
  280. # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
  281. cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
  282. local_files_only = not update_model
  283. snapshot_kwargs = {
  284. "cache_dir": cache_dir,
  285. "local_files_only": local_files_only,
  286. }
  287. log.debug(f"model: {model}")
  288. log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
  289. # Inspiration from upstream sentence_transformers
  290. if (
  291. os.path.exists(model)
  292. or ("\\" in model or model.count("/") > 1)
  293. and local_files_only
  294. ):
  295. # If fully qualified path exists, return input, else set repo_id
  296. return model
  297. elif "/" not in model:
  298. # Set valid repo_id for model short-name
  299. model = "sentence-transformers" + "/" + model
  300. snapshot_kwargs["repo_id"] = model
  301. # Attempt to query the huggingface_hub library to determine the local path and/or to update
  302. try:
  303. model_repo_path = snapshot_download(**snapshot_kwargs)
  304. log.debug(f"model_repo_path: {model_repo_path}")
  305. return model_repo_path
  306. except Exception as e:
  307. log.exception(f"Cannot determine model snapshot path: {e}")
  308. return model
  309. def generate_openai_embeddings(
  310. model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
  311. ):
  312. try:
  313. r = requests.post(
  314. f"{url}/embeddings",
  315. headers={
  316. "Content-Type": "application/json",
  317. "Authorization": f"Bearer {key}",
  318. },
  319. json={"input": text, "model": model},
  320. )
  321. r.raise_for_status()
  322. data = r.json()
  323. if "data" in data:
  324. return data["data"][0]["embedding"]
  325. else:
  326. raise "Something went wrong :/"
  327. except Exception as e:
  328. print(e)
  329. return None
  330. from typing import Any
  331. from langchain_core.retrievers import BaseRetriever
  332. from langchain_core.callbacks import CallbackManagerForRetrieverRun
  333. class ChromaRetriever(BaseRetriever):
  334. collection: Any
  335. embeddings_function: Any
  336. top_n: int
  337. def _get_relevant_documents(
  338. self,
  339. query: str,
  340. *,
  341. run_manager: CallbackManagerForRetrieverRun,
  342. ) -> List[Document]:
  343. query_embeddings = self.embeddings_function(query)
  344. results = self.collection.query(
  345. query_embeddings=[query_embeddings],
  346. n_results=self.top_n,
  347. )
  348. ids = results["ids"][0]
  349. metadatas = results["metadatas"][0]
  350. documents = results["documents"][0]
  351. return [
  352. Document(
  353. metadata=metadatas[idx],
  354. page_content=documents[idx],
  355. )
  356. for idx in range(len(ids))
  357. ]
  358. import operator
  359. from typing import Optional, Sequence
  360. from langchain_core.documents import BaseDocumentCompressor, Document
  361. from langchain_core.callbacks import Callbacks
  362. from langchain_core.pydantic_v1 import Extra
  363. from sentence_transformers import util
  364. class RerankCompressor(BaseDocumentCompressor):
  365. embeddings_function: Any
  366. reranking_function: Any
  367. r_score: float
  368. top_n: int
  369. class Config:
  370. extra = Extra.forbid
  371. arbitrary_types_allowed = True
  372. def compress_documents(
  373. self,
  374. documents: Sequence[Document],
  375. query: str,
  376. callbacks: Optional[Callbacks] = None,
  377. ) -> Sequence[Document]:
  378. if self.reranking_function:
  379. scores = self.reranking_function.predict(
  380. [(query, doc.page_content) for doc in documents]
  381. )
  382. else:
  383. query_embedding = self.embeddings_function(query)
  384. document_embedding = self.embeddings_function(
  385. [doc.page_content for doc in documents]
  386. )
  387. scores = util.cos_sim(query_embedding, document_embedding)[0]
  388. docs_with_scores = list(zip(documents, scores.tolist()))
  389. if self.r_score:
  390. docs_with_scores = [
  391. (d, s) for d, s in docs_with_scores if s >= self.r_score
  392. ]
  393. result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
  394. final_results = []
  395. for doc, doc_score in result[: self.top_n]:
  396. metadata = doc.metadata
  397. metadata["score"] = doc_score
  398. doc = Document(
  399. page_content=doc.page_content,
  400. metadata=metadata,
  401. )
  402. final_results.append(doc)
  403. return final_results