utils.py

import os
import logging
import operator

import requests

from typing import Any, List, Optional, Sequence

from apps.ollama.main import (
    generate_ollama_embeddings,
    GenerateEmbeddingsForm,
)

from huggingface_hub import snapshot_download

from langchain.retrievers import (
    ContextualCompressionRetriever,
    EnsembleRetriever,
)
from langchain_community.retrievers import BM25Retriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra
from langchain_core.retrievers import BaseRetriever

from sentence_transformers import util

from config import SRC_LOG_LEVELS, CHROMA_CLIENT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

def query_embeddings_doc(
    collection_name: str,
    query: str,
    embeddings_function,
    reranking_function,
    k: int,
    r: float,
    hybrid: bool,
):
    try:
        collection = CHROMA_CLIENT.get_collection(name=collection_name)

        if hybrid:
            documents = collection.get()  # get all documents

            bm25_retriever = BM25Retriever.from_texts(
                texts=documents.get("documents"),
                metadatas=documents.get("metadatas"),
            )
            bm25_retriever.k = k

            # ChromaRetriever and RerankCompressor are defined at the bottom of this module
            chroma_retriever = ChromaRetriever(
                collection=collection,
                embeddings_function=embeddings_function,
                top_n=k,
            )

            ensemble_retriever = EnsembleRetriever(
                retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
            )

            compressor = RerankCompressor(
                embeddings_function=embeddings_function,
                reranking_function=reranking_function,
                r_score=r,
                top_n=k,
            )

            compression_retriever = ContextualCompressionRetriever(
                base_compressor=compressor, base_retriever=ensemble_retriever
            )

            result = compression_retriever.invoke(query)
            result = {
                "distances": [[d.metadata.get("score") for d in result]],
                "documents": [[d.page_content for d in result]],
                "metadatas": [[d.metadata for d in result]],
            }
        else:
            query_embeddings = embeddings_function(query)
            result = collection.query(
                query_embeddings=[query_embeddings],
                n_results=k,
            )

        log.info(f"query_embeddings_doc:result {result}")
        return result
    except Exception:
        raise
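
# A call sketch for both retrieval modes; the collection name, `embed` function,
# and parameter values below are illustrative, not from the codebase:
#
#   query_embeddings_doc(
#       collection_name="docs-abc",
#       query="refund policy",
#       embeddings_function=embed,
#       reranking_function=None,
#       k=4,
#       r=0.0,
#       hybrid=True,  # BM25 + Chroma ensemble, then rerank/compress
#   )
#   # -> {"distances": [[...]], "documents": [[...]], "metadatas": [[...]]}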

def merge_and_sort_query_results(query_results, k, reverse=False):
    # Initialize lists to store combined data
    combined_distances = []
    combined_documents = []
    combined_metadatas = []

    for data in query_results:
        combined_distances.extend(data["distances"][0])
        combined_documents.extend(data["documents"][0])
        combined_metadatas.extend(data["metadatas"][0])

    # Create a list of tuples (distance, document, metadata)
    combined = list(zip(combined_distances, combined_documents, combined_metadatas))

    # Sort the list based on distances
    combined.sort(key=lambda x: x[0], reverse=reverse)

    # We don't have anything :-(
    if not combined:
        sorted_distances = []
        sorted_documents = []
        sorted_metadatas = []
    else:
        # Unzip the sorted list
        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)

        # Slice the lists to include only the top k elements
        sorted_distances = list(sorted_distances)[:k]
        sorted_documents = list(sorted_documents)[:k]
        sorted_metadatas = list(sorted_metadatas)[:k]

    # Create the output dictionary
    result = {
        "distances": [sorted_distances],
        "documents": [sorted_documents],
        "metadatas": [sorted_metadatas],
    }

    return result
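
# A minimal sketch of the merge (hypothetical data): two single-collection
# results are combined, and the k smallest distances win.
#
#   a = {"distances": [[0.1, 0.4]], "documents": [["d1", "d2"]], "metadatas": [[{}, {}]]}
#   b = {"distances": [[0.2]], "documents": [["d3"]], "metadatas": [[{}]]}
#   merge_and_sort_query_results([a, b], k=2)
#   # -> {"distances": [[0.1, 0.2]], "documents": [["d1", "d3"]], "metadatas": [[{}, {}]]}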

def query_embeddings_collection(
    collection_names: List[str],
    query: str,
    k: int,
    r: float,
    embeddings_function,
    reranking_function,
    hybrid: bool,
):
    results = []

    for collection_name in collection_names:
        try:
            result = query_embeddings_doc(
                collection_name=collection_name,
                query=query,
                k=k,
                r=r,
                embeddings_function=embeddings_function,
                reranking_function=reranking_function,
                hybrid=hybrid,
            )
            results.append(result)
        except Exception as e:
            # Skip collections that cannot be queried, but keep a trace in the log
            log.exception(e)

    reverse = hybrid and reranking_function is not None
    return merge_and_sort_query_results(results, k=k, reverse=reverse)

def rag_template(template: str, context: str, query: str):
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
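
# A quick illustration (the template string and values are hypothetical):
#
#   rag_template(
#       "Context:\n[context]\n\nQuestion:\n[query]",
#       context="Paris is the capital of France.",
#       query="What is the capital of France?",
#   )
#   # -> "Context:\nParis is the capital of France.\n\nQuestion:\nWhat is the capital of France?"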

def query_embeddings_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
):
    if embedding_engine == "":
        return lambda query: embedding_function.encode(query).tolist()
    elif embedding_engine in ["ollama", "openai"]:
        if embedding_engine == "ollama":
            func = lambda query: generate_ollama_embeddings(
                GenerateEmbeddingsForm(
                    **{
                        "model": embedding_model,
                        "prompt": query,
                    }
                )
            )
        elif embedding_engine == "openai":
            func = lambda query: generate_openai_embeddings(
                model=embedding_model,
                text=query,
                key=openai_key,
                url=openai_url,
            )

        def generate_multiple(query, f):
            # A batch of queries maps to a list of embeddings; a single query to one embedding
            if isinstance(query, list):
                return [f(q) for q in query]
            else:
                return f(query)

        return lambda query: generate_multiple(query, func)
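
# The returned callable accepts a single query string or a list of them
# (the engine and model names below are illustrative):
#
#   embed = query_embeddings_function("ollama", "nomic-embed-text", None, None, None)
#   embed("hello")              # -> one embedding vector
#   embed(["hello", "world"])   # -> a list of embedding vectors, one per query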

def rag_messages(
    docs,
    messages,
    template,
    k,
    r,
    hybrid,
    embedding_engine,
    embedding_model,
    embedding_function,
    reranking_function,
    openai_key,
    openai_url,
):
    log.debug(
        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
    )

    # Find the most recent user message; that is the query we retrieve context for
    last_user_message_idx = None
    for i in range(len(messages) - 1, -1, -1):
        if messages[i]["role"] == "user":
            last_user_message_idx = i
            break

    user_message = messages[last_user_message_idx]

    if isinstance(user_message["content"], list):
        # Handle list content input
        content_type = "list"
        query = ""
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                query = content_item["text"]
                break
    elif isinstance(user_message["content"], str):
        # Handle text content input
        content_type = "text"
        query = user_message["content"]
    else:
        # Fallback in case the input does not match expected types
        content_type = None
        query = ""

    embeddings_function = query_embeddings_function(
        embedding_engine,
        embedding_model,
        embedding_function,
        openai_key,
        openai_url,
    )

    extracted_collections = []
    relevant_contexts = []

    for doc in docs:
        context = None

        collection = doc.get("collection_name")
        if collection:
            collection = [collection]
        else:
            collection = doc.get("collection_names", [])

        collection = set(collection).difference(extracted_collections)
        if not collection:
            log.debug(f"skipping {doc} as it has already been extracted")
            continue

        try:
            if doc["type"] == "text":
                context = doc["content"]
            elif doc["type"] == "collection":
                context = query_embeddings_collection(
                    collection_names=doc["collection_names"],
                    query=query,
                    k=k,
                    r=r,
                    embeddings_function=embeddings_function,
                    reranking_function=reranking_function,
                    hybrid=hybrid,
                )
            else:
                context = query_embeddings_doc(
                    collection_name=doc["collection_name"],
                    query=query,
                    k=k,
                    r=r,
                    embeddings_function=embeddings_function,
                    reranking_function=reranking_function,
                    hybrid=hybrid,
                )
        except Exception as e:
            log.exception(e)
            context = None

        if context:
            relevant_contexts.append(context)

        extracted_collections.extend(collection)

    context_string = ""
    for context in relevant_contexts:
        # "text" type docs store their content directly as a plain string,
        # while query results carry a {"documents": [[...]]} dict
        if isinstance(context, str):
            items = [context]
        else:
            items = context["documents"][0]
        context_string += "\n\n".join(items)
    context_string = context_string.strip()

    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    log.debug(f"ra_content: {ra_content}")

    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                # Update the text item's content with ra_content
                new_content.append({"type": "text", "text": ra_content})
            else:
                # Keep other types of content as they are
                new_content.append(content_item)
        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {
            **user_message,
            "content": ra_content,
        }

    messages[last_user_message_idx] = new_user_message
    return messages
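
# A sketch of the expected inputs (field values and the `sentence_transformer`
# object are hypothetical):
#
#   docs = [{"type": "collection", "collection_names": ["docs-abc"]}]
#   messages = [{"role": "user", "content": "What does the doc say?"}]
#   rag_messages(
#       docs, messages, "Use [context] to answer: [query]",
#       k=4, r=0.0, hybrid=False,
#       embedding_engine="", embedding_model=None,
#       embedding_function=sentence_transformer, reranking_function=None,
#       openai_key=None, openai_url=None,
#   )
#   # -> messages, with the last user message rewritten to the filled template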

def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if os.path.exists(model) or (
        ("\\" in model or model.count("/") > 1) and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
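
# How the lookup resolves (the model names are illustrative):
#
#   get_model_path("all-MiniLM-L6-v2")
#   # -> cached snapshot path of "sentence-transformers/all-MiniLM-L6-v2"
#   get_model_path("BAAI/bge-reranker-base", update_model=True)
#   # -> downloads/refreshes the snapshot, then returns its local path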

def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    try:
        r = requests.post(
            f"{url}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return data["data"][0]["embedding"]
        else:
            # Raising a bare string is invalid in Python 3; raise a real exception
            raise Exception("Something went wrong :/")
    except Exception as e:
        log.exception(e)
        return None
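
# Example call (the key is a placeholder; the model name is a standard OpenAI
# embedding model):
#
#   generate_openai_embeddings("text-embedding-ada-002", "hello", key="sk-...")
#   # -> a list of floats, or None on failure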

class ChromaRetriever(BaseRetriever):
    collection: Any
    embeddings_function: Any
    top_n: int

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        query_embeddings = self.embeddings_function(query)

        results = self.collection.query(
            query_embeddings=[query_embeddings],
            n_results=self.top_n,
        )

        ids = results["ids"][0]
        metadatas = results["metadatas"][0]
        documents = results["documents"][0]

        return [
            Document(
                metadata=metadatas[idx],
                page_content=documents[idx],
            )
            for idx in range(len(ids))
        ]
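
# A minimal sketch of using the retriever on its own (the collection name and
# `embed` function are hypothetical):
#
#   retriever = ChromaRetriever(
#       collection=CHROMA_CLIENT.get_collection(name="docs-abc"),
#       embeddings_function=embed,
#       top_n=4,
#   )
#   docs = retriever.invoke("what is the refund policy?")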

class RerankCompressor(BaseDocumentCompressor):
    embeddings_function: Any
    reranking_function: Any
    r_score: float
    top_n: int

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        if self.reranking_function:
            # Cross-encoder reranking: score each (query, document) pair directly
            scores = self.reranking_function.predict(
                [(query, doc.page_content) for doc in documents]
            )
        else:
            # Fall back to cosine similarity between query and document embeddings
            query_embedding = self.embeddings_function(query)
            document_embedding = self.embeddings_function(
                [doc.page_content for doc in documents]
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        docs_with_scores = list(zip(documents, scores.tolist()))
        if self.r_score:
            # Drop documents below the relevance-score threshold
            docs_with_scores = [(d, s) for d, s in docs_with_scores if s >= self.r_score]

        # With a reranker, sort descending (higher scores are better); otherwise keep
        # the ascending order that merge_and_sort_query_results also applies to
        # non-reranked results
        reverse = self.reranking_function is not None
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=reverse)

        final_results = []
        for doc, doc_score in result[: self.top_n]:
            metadata = doc.metadata
            metadata["score"] = doc_score
            doc = Document(
                page_content=doc.page_content,
                metadata=metadata,
            )
            final_results.append(doc)
        return final_results
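
# A usage sketch (the CrossEncoder model name is a real public checkpoint chosen
# for illustration; `embed` and `candidate_docs` are hypothetical):
#
#   from sentence_transformers import CrossEncoder
#
#   compressor = RerankCompressor(
#       embeddings_function=embed,
#       reranking_function=CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2"),
#       r_score=0.0,
#       top_n=4,
#   )
#   top_docs = compressor.compress_documents(candidate_docs, "refund policy")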