main.py 29 KB


  1. from fastapi import (
  2. FastAPI,
  3. Depends,
  4. HTTPException,
  5. status,
  6. UploadFile,
  7. File,
  8. Form,
  9. )
  10. from fastapi.middleware.cors import CORSMiddleware
  11. import os, shutil, logging, re
  12. from pathlib import Path
  13. from typing import List
  14. from chromadb.utils.batch_utils import create_batches
  15. from langchain_community.document_loaders import (
  16. WebBaseLoader,
  17. TextLoader,
  18. PyPDFLoader,
  19. CSVLoader,
  20. BSHTMLLoader,
  21. Docx2txtLoader,
  22. UnstructuredEPubLoader,
  23. UnstructuredWordDocumentLoader,
  24. UnstructuredMarkdownLoader,
  25. UnstructuredXMLLoader,
  26. UnstructuredRSTLoader,
  27. UnstructuredExcelLoader,
  28. UnstructuredPowerPointLoader,
  29. YoutubeLoader,
  30. )
  31. from langchain.text_splitter import RecursiveCharacterTextSplitter
  32. import validators
  33. import urllib.parse
  34. import socket
  35. from pydantic import BaseModel
  36. from typing import Optional
  37. import mimetypes
  38. import uuid
  39. import json
  40. import sentence_transformers
  41. from apps.web.models.documents import (
  42. Documents,
  43. DocumentForm,
  44. DocumentResponse,
  45. )
  46. from apps.rag.utils import (
  47. get_model_path,
  48. get_embedding_function,
  49. query_doc,
  50. query_doc_with_hybrid_search,
  51. query_collection,
  52. query_collection_with_hybrid_search,
  53. )
  54. from utils.misc import (
  55. calculate_sha256,
  56. calculate_sha256_string,
  57. sanitize_filename,
  58. extract_folders_after_data_docs,
  59. )
  60. from utils.utils import get_current_user, get_admin_user
  61. from config import (
  62. ENV,
  63. SRC_LOG_LEVELS,
  64. UPLOAD_DIR,
  65. DOCS_DIR,
  66. RAG_TOP_K,
  67. RAG_RELEVANCE_THRESHOLD,
  68. RAG_EMBEDDING_ENGINE,
  69. RAG_EMBEDDING_MODEL,
  70. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  71. RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  72. ENABLE_RAG_HYBRID_SEARCH,
  73. ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  74. RAG_RERANKING_MODEL,
  75. PDF_EXTRACT_IMAGES,
  76. RAG_RERANKING_MODEL_AUTO_UPDATE,
  77. RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  78. RAG_OPENAI_API_BASE_URL,
  79. RAG_OPENAI_API_KEY,
  80. DEVICE_TYPE,
  81. CHROMA_CLIENT,
  82. CHUNK_SIZE,
  83. CHUNK_OVERLAP,
  84. RAG_TEMPLATE,
  85. ENABLE_RAG_LOCAL_WEB_FETCH,
  86. YOUTUBE_LOADER_LANGUAGE,
  87. AppConfig,
  88. )
  89. from constants import ERROR_MESSAGES
# Module logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()

# Runtime-adjustable settings live on app.state.config and are seeded here
# from the values imported from `config`. The update endpoints below mutate
# these attributes in place.
app.state.config = AppConfig()
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
# NOTE(review): the YouTube translation setting is stored directly on
# app.state rather than on app.state.config like the others — confirm intended.
app.state.YOUTUBE_LOADER_TRANSLATION = None
  111. def update_embedding_model(
  112. embedding_model: str,
  113. update_model: bool = False,
  114. ):
  115. if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
  116. app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
  117. get_model_path(embedding_model, update_model),
  118. device=DEVICE_TYPE,
  119. trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  120. )
  121. else:
  122. app.state.sentence_transformer_ef = None
  123. def update_reranking_model(
  124. reranking_model: str,
  125. update_model: bool = False,
  126. ):
  127. if reranking_model:
  128. app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
  129. get_model_path(reranking_model, update_model),
  130. device=DEVICE_TYPE,
  131. trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  132. )
  133. else:
  134. app.state.sentence_transformer_rf = None
# Load the configured embedding and reranking models at import time. The
# *_AUTO_UPDATE flags are passed as the `update_model` argument.
update_embedding_model(
    app.state.config.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
    app.state.config.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)

# Build the embedding callable used by the query endpoints; it must be created
# after update_embedding_model() has set app.state.sentence_transformer_ef.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
)

# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): wildcard origins combined with credentials is very permissive —
# confirm this is acceptable for the deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body that names the target vector-store collection.
class CollectionNameForm(BaseModel):
    # Falls back to the literal "test" when the client omits the field.
    collection_name: Optional[str] = "test"
# Request body for the URL-based ingestion endpoints (/youtube and /web);
# inherits the optional collection_name from CollectionNameForm.
class UrlForm(CollectionNameForm):
    # Address of the page or video to ingest.
    url: str
  162. @app.get("/")
  163. async def get_status():
  164. return {
  165. "status": True,
  166. "chunk_size": app.state.config.CHUNK_SIZE,
  167. "chunk_overlap": app.state.config.CHUNK_OVERLAP,
  168. "template": app.state.config.RAG_TEMPLATE,
  169. "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
  170. "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
  171. "reranking_model": app.state.config.RAG_RERANKING_MODEL,
  172. }
  173. @app.get("/embedding")
  174. async def get_embedding_config(user=Depends(get_admin_user)):
  175. return {
  176. "status": True,
  177. "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
  178. "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
  179. "openai_config": {
  180. "url": app.state.config.OPENAI_API_BASE_URL,
  181. "key": app.state.config.OPENAI_API_KEY,
  182. },
  183. }
  184. @app.get("/reranking")
  185. async def get_reraanking_config(user=Depends(get_admin_user)):
  186. return {
  187. "status": True,
  188. "reranking_model": app.state.config.RAG_RERANKING_MODEL,
  189. }
# Connection settings for an OpenAI-compatible embedding endpoint.
class OpenAIConfigForm(BaseModel):
    # Stored into app.state.config.OPENAI_API_BASE_URL by /embedding/update.
    url: str
    # Stored into app.state.config.OPENAI_API_KEY by /embedding/update.
    key: str
# Request body for POST /embedding/update.
class EmbeddingModelUpdateForm(BaseModel):
    # Only applied when the engine is "ollama" or "openai" (see the handler).
    openai_config: Optional[OpenAIConfigForm] = None
    # Engine identifier; the empty string selects the local
    # SentenceTransformer path in update_embedding_model().
    embedding_engine: str
    embedding_model: str
  197. @app.post("/embedding/update")
  198. async def update_embedding_config(
  199. form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
  200. ):
  201. log.info(
  202. f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
  203. )
  204. try:
  205. app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
  206. app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
  207. if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
  208. if form_data.openai_config != None:
  209. app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
  210. app.state.config.OPENAI_API_KEY = form_data.openai_config.key
  211. update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
  212. app.state.EMBEDDING_FUNCTION = get_embedding_function(
  213. app.state.config.RAG_EMBEDDING_ENGINE,
  214. app.state.config.RAG_EMBEDDING_MODEL,
  215. app.state.sentence_transformer_ef,
  216. app.state.config.OPENAI_API_KEY,
  217. app.state.config.OPENAI_API_BASE_URL,
  218. )
  219. return {
  220. "status": True,
  221. "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
  222. "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
  223. "openai_config": {
  224. "url": app.state.config.OPENAI_API_BASE_URL,
  225. "key": app.state.config.OPENAI_API_KEY,
  226. },
  227. }
  228. except Exception as e:
  229. log.exception(f"Problem updating embedding model: {e}")
  230. raise HTTPException(
  231. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  232. detail=ERROR_MESSAGES.DEFAULT(e),
  233. )
# Request body for POST /reranking/update.
class RerankingModelUpdateForm(BaseModel):
    # An empty string clears the reranker (see update_reranking_model()).
    reranking_model: str
  236. @app.post("/reranking/update")
  237. async def update_reranking_config(
  238. form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
  239. ):
  240. log.info(
  241. f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
  242. )
  243. try:
  244. app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
  245. update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
  246. return {
  247. "status": True,
  248. "reranking_model": app.state.config.RAG_RERANKING_MODEL,
  249. }
  250. except Exception as e:
  251. log.exception(f"Problem updating reranking model: {e}")
  252. raise HTTPException(
  253. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  254. detail=ERROR_MESSAGES.DEFAULT(e),
  255. )
  256. @app.get("/config")
  257. async def get_rag_config(user=Depends(get_admin_user)):
  258. return {
  259. "status": True,
  260. "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
  261. "chunk": {
  262. "chunk_size": app.state.config.CHUNK_SIZE,
  263. "chunk_overlap": app.state.config.CHUNK_OVERLAP,
  264. },
  265. "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  266. "youtube": {
  267. "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
  268. "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
  269. },
  270. }
# Text-splitter parameters; both values are applied together by /config/update.
class ChunkParamUpdateForm(BaseModel):
    # Passed as chunk_size to RecursiveCharacterTextSplitter.
    chunk_size: int
    # Passed as chunk_overlap to RecursiveCharacterTextSplitter.
    chunk_overlap: int
# YouTube transcript loader options, forwarded to
# YoutubeLoader.from_youtube_url() by the /youtube endpoint.
class YoutubeLoaderConfig(BaseModel):
    # Passed as the loader's `language` argument.
    language: List[str]
    # Passed as the loader's `translation` argument; None leaves it unset.
    translation: Optional[str] = None
# Request body for POST /config/update. Every section is optional; a section
# left as None keeps the corresponding current setting.
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
  282. @app.post("/config/update")
  283. async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  284. app.state.config.PDF_EXTRACT_IMAGES = (
  285. form_data.pdf_extract_images
  286. if form_data.pdf_extract_images is not None
  287. else app.state.config.PDF_EXTRACT_IMAGES
  288. )
  289. app.state.config.CHUNK_SIZE = (
  290. form_data.chunk.chunk_size
  291. if form_data.chunk is not None
  292. else app.state.config.CHUNK_SIZE
  293. )
  294. app.state.config.CHUNK_OVERLAP = (
  295. form_data.chunk.chunk_overlap
  296. if form_data.chunk is not None
  297. else app.state.config.CHUNK_OVERLAP
  298. )
  299. app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
  300. form_data.web_loader_ssl_verification
  301. if form_data.web_loader_ssl_verification != None
  302. else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
  303. )
  304. app.state.config.YOUTUBE_LOADER_LANGUAGE = (
  305. form_data.youtube.language
  306. if form_data.youtube is not None
  307. else app.state.config.YOUTUBE_LOADER_LANGUAGE
  308. )
  309. app.state.YOUTUBE_LOADER_TRANSLATION = (
  310. form_data.youtube.translation
  311. if form_data.youtube is not None
  312. else app.state.YOUTUBE_LOADER_TRANSLATION
  313. )
  314. return {
  315. "status": True,
  316. "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
  317. "chunk": {
  318. "chunk_size": app.state.config.CHUNK_SIZE,
  319. "chunk_overlap": app.state.config.CHUNK_OVERLAP,
  320. },
  321. "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  322. "youtube": {
  323. "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
  324. "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
  325. },
  326. }
  327. @app.get("/template")
  328. async def get_rag_template(user=Depends(get_current_user)):
  329. return {
  330. "status": True,
  331. "template": app.state.config.RAG_TEMPLATE,
  332. }
  333. @app.get("/query/settings")
  334. async def get_query_settings(user=Depends(get_admin_user)):
  335. return {
  336. "status": True,
  337. "template": app.state.config.RAG_TEMPLATE,
  338. "k": app.state.config.TOP_K,
  339. "r": app.state.config.RELEVANCE_THRESHOLD,
  340. "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  341. }
# Request body for POST /query/settings/update. A missing or falsy field makes
# the handler fall back to a default rather than keep the current value.
class QuerySettingsForm(BaseModel):
    # Number of results to return (top-k).
    k: Optional[int] = None
    # Relevance threshold.
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
  347. @app.post("/query/settings/update")
  348. async def update_query_settings(
  349. form_data: QuerySettingsForm, user=Depends(get_admin_user)
  350. ):
  351. app.state.config.RAG_TEMPLATE = (
  352. form_data.template if form_data.template else RAG_TEMPLATE
  353. )
  354. app.state.config.TOP_K = form_data.k if form_data.k else 4
  355. app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
  356. app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
  357. form_data.hybrid if form_data.hybrid else False
  358. )
  359. return {
  360. "status": True,
  361. "template": app.state.config.RAG_TEMPLATE,
  362. "k": app.state.config.TOP_K,
  363. "r": app.state.config.RELEVANCE_THRESHOLD,
  364. "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  365. }
# Request body for POST /query/doc (query against a single collection).
class QueryDocForm(BaseModel):
    collection_name: str
    query: str
    # Per-request top-k override; the configured TOP_K is used when absent.
    k: Optional[int] = None
    # Per-request relevance threshold override (hybrid search only).
    r: Optional[float] = None
    # NOTE(review): accepted but not read by query_doc_handler, which consults
    # app.state.config.ENABLE_RAG_HYBRID_SEARCH instead.
    hybrid: Optional[bool] = None
  372. @app.post("/query/doc")
  373. def query_doc_handler(
  374. form_data: QueryDocForm,
  375. user=Depends(get_current_user),
  376. ):
  377. try:
  378. if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
  379. return query_doc_with_hybrid_search(
  380. collection_name=form_data.collection_name,
  381. query=form_data.query,
  382. embedding_function=app.state.EMBEDDING_FUNCTION,
  383. k=form_data.k if form_data.k else app.state.config.TOP_K,
  384. reranking_function=app.state.sentence_transformer_rf,
  385. r=(
  386. form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
  387. ),
  388. )
  389. else:
  390. return query_doc(
  391. collection_name=form_data.collection_name,
  392. query=form_data.query,
  393. embedding_function=app.state.EMBEDDING_FUNCTION,
  394. k=form_data.k if form_data.k else app.state.config.TOP_K,
  395. )
  396. except Exception as e:
  397. log.exception(e)
  398. raise HTTPException(
  399. status_code=status.HTTP_400_BAD_REQUEST,
  400. detail=ERROR_MESSAGES.DEFAULT(e),
  401. )
# Request body for POST /query/collection (query across several collections).
class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    # Per-request top-k override; the configured TOP_K is used when absent.
    k: Optional[int] = None
    # Per-request relevance threshold override (hybrid search only).
    r: Optional[float] = None
    # NOTE(review): accepted but not read by query_collection_handler, which
    # consults app.state.config.ENABLE_RAG_HYBRID_SEARCH instead.
    hybrid: Optional[bool] = None
  408. @app.post("/query/collection")
  409. def query_collection_handler(
  410. form_data: QueryCollectionsForm,
  411. user=Depends(get_current_user),
  412. ):
  413. try:
  414. if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
  415. return query_collection_with_hybrid_search(
  416. collection_names=form_data.collection_names,
  417. query=form_data.query,
  418. embedding_function=app.state.EMBEDDING_FUNCTION,
  419. k=form_data.k if form_data.k else app.state.config.TOP_K,
  420. reranking_function=app.state.sentence_transformer_rf,
  421. r=(
  422. form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
  423. ),
  424. )
  425. else:
  426. return query_collection(
  427. collection_names=form_data.collection_names,
  428. query=form_data.query,
  429. embedding_function=app.state.EMBEDDING_FUNCTION,
  430. k=form_data.k if form_data.k else app.state.config.TOP_K,
  431. )
  432. except Exception as e:
  433. log.exception(e)
  434. raise HTTPException(
  435. status_code=status.HTTP_400_BAD_REQUEST,
  436. detail=ERROR_MESSAGES.DEFAULT(e),
  437. )
  438. @app.post("/youtube")
  439. def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
  440. try:
  441. loader = YoutubeLoader.from_youtube_url(
  442. form_data.url,
  443. add_video_info=True,
  444. language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
  445. translation=app.state.YOUTUBE_LOADER_TRANSLATION,
  446. )
  447. data = loader.load()
  448. collection_name = form_data.collection_name
  449. if collection_name == "":
  450. collection_name = calculate_sha256_string(form_data.url)[:63]
  451. store_data_in_vector_db(data, collection_name, overwrite=True)
  452. return {
  453. "status": True,
  454. "collection_name": collection_name,
  455. "filename": form_data.url,
  456. }
  457. except Exception as e:
  458. log.exception(e)
  459. raise HTTPException(
  460. status_code=status.HTTP_400_BAD_REQUEST,
  461. detail=ERROR_MESSAGES.DEFAULT(e),
  462. )
  463. @app.post("/web")
  464. def store_web(form_data: UrlForm, user=Depends(get_current_user)):
  465. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  466. try:
  467. loader = get_web_loader(
  468. form_data.url,
  469. verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  470. )
  471. data = loader.load()
  472. collection_name = form_data.collection_name
  473. if collection_name == "":
  474. collection_name = calculate_sha256_string(form_data.url)[:63]
  475. store_data_in_vector_db(data, collection_name, overwrite=True)
  476. return {
  477. "status": True,
  478. "collection_name": collection_name,
  479. "filename": form_data.url,
  480. }
  481. except Exception as e:
  482. log.exception(e)
  483. raise HTTPException(
  484. status_code=status.HTTP_400_BAD_REQUEST,
  485. detail=ERROR_MESSAGES.DEFAULT(e),
  486. )
  487. def get_web_loader(url: str, verify_ssl: bool = True):
  488. # Check if the URL is valid
  489. if isinstance(validators.url(url), validators.ValidationError):
  490. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  491. if not ENABLE_RAG_LOCAL_WEB_FETCH:
  492. # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
  493. parsed_url = urllib.parse.urlparse(url)
  494. # Get IPv4 and IPv6 addresses
  495. ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
  496. # Check if any of the resolved addresses are private
  497. # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
  498. for ip in ipv4_addresses:
  499. if validators.ipv4(ip, private=True):
  500. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  501. for ip in ipv6_addresses:
  502. if validators.ipv6(ip, private=True):
  503. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  504. return WebBaseLoader(url, verify_ssl=verify_ssl)
  505. def resolve_hostname(hostname):
  506. # Get address information
  507. addr_info = socket.getaddrinfo(hostname, None)
  508. # Extract IP addresses from address information
  509. ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
  510. ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
  511. return ipv4_addresses, ipv6_addresses
  512. def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
  513. text_splitter = RecursiveCharacterTextSplitter(
  514. chunk_size=app.state.config.CHUNK_SIZE,
  515. chunk_overlap=app.state.config.CHUNK_OVERLAP,
  516. add_start_index=True,
  517. )
  518. docs = text_splitter.split_documents(data)
  519. if len(docs) > 0:
  520. log.info(f"store_data_in_vector_db {docs}")
  521. return store_docs_in_vector_db(docs, collection_name, overwrite), None
  522. else:
  523. raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
  524. def store_text_in_vector_db(
  525. text, metadata, collection_name, overwrite: bool = False
  526. ) -> bool:
  527. text_splitter = RecursiveCharacterTextSplitter(
  528. chunk_size=app.state.config.CHUNK_SIZE,
  529. chunk_overlap=app.state.config.CHUNK_OVERLAP,
  530. add_start_index=True,
  531. )
  532. docs = text_splitter.create_documents([text], metadatas=[metadata])
  533. return store_docs_in_vector_db(docs, collection_name, overwrite)
  534. def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
  535. log.info(f"store_docs_in_vector_db {docs} {collection_name}")
  536. texts = [doc.page_content for doc in docs]
  537. metadatas = [doc.metadata for doc in docs]
  538. try:
  539. if overwrite:
  540. for collection in CHROMA_CLIENT.list_collections():
  541. if collection_name == collection.name:
  542. log.info(f"deleting existing collection {collection_name}")
  543. CHROMA_CLIENT.delete_collection(name=collection_name)
  544. collection = CHROMA_CLIENT.create_collection(name=collection_name)
  545. embedding_func = get_embedding_function(
  546. app.state.config.RAG_EMBEDDING_ENGINE,
  547. app.state.config.RAG_EMBEDDING_MODEL,
  548. app.state.sentence_transformer_ef,
  549. app.state.config.OPENAI_API_KEY,
  550. app.state.config.OPENAI_API_BASE_URL,
  551. )
  552. embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
  553. embeddings = embedding_func(embedding_texts)
  554. for batch in create_batches(
  555. api=CHROMA_CLIENT,
  556. ids=[str(uuid.uuid4()) for _ in texts],
  557. metadatas=metadatas,
  558. embeddings=embeddings,
  559. documents=texts,
  560. ):
  561. collection.add(*batch)
  562. return True
  563. except Exception as e:
  564. log.exception(e)
  565. if e.__class__.__name__ == "UniqueConstraintError":
  566. return True
  567. return False
  568. def get_loader(filename: str, file_content_type: str, file_path: str):
  569. file_ext = filename.split(".")[-1].lower()
  570. known_type = True
  571. known_source_ext = [
  572. "go",
  573. "py",
  574. "java",
  575. "sh",
  576. "bat",
  577. "ps1",
  578. "cmd",
  579. "js",
  580. "ts",
  581. "css",
  582. "cpp",
  583. "hpp",
  584. "h",
  585. "c",
  586. "cs",
  587. "sql",
  588. "log",
  589. "ini",
  590. "pl",
  591. "pm",
  592. "r",
  593. "dart",
  594. "dockerfile",
  595. "env",
  596. "php",
  597. "hs",
  598. "hsc",
  599. "lua",
  600. "nginxconf",
  601. "conf",
  602. "m",
  603. "mm",
  604. "plsql",
  605. "perl",
  606. "rb",
  607. "rs",
  608. "db2",
  609. "scala",
  610. "bash",
  611. "swift",
  612. "vue",
  613. "svelte",
  614. ]
  615. if file_ext == "pdf":
  616. loader = PyPDFLoader(
  617. file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
  618. )
  619. elif file_ext == "csv":
  620. loader = CSVLoader(file_path)
  621. elif file_ext == "rst":
  622. loader = UnstructuredRSTLoader(file_path, mode="elements")
  623. elif file_ext == "xml":
  624. loader = UnstructuredXMLLoader(file_path)
  625. elif file_ext in ["htm", "html"]:
  626. loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
  627. elif file_ext == "md":
  628. loader = UnstructuredMarkdownLoader(file_path)
  629. elif file_content_type == "application/epub+zip":
  630. loader = UnstructuredEPubLoader(file_path)
  631. elif (
  632. file_content_type
  633. == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
  634. or file_ext in ["doc", "docx"]
  635. ):
  636. loader = Docx2txtLoader(file_path)
  637. elif file_content_type in [
  638. "application/vnd.ms-excel",
  639. "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  640. ] or file_ext in ["xls", "xlsx"]:
  641. loader = UnstructuredExcelLoader(file_path)
  642. elif file_content_type in [
  643. "application/vnd.ms-powerpoint",
  644. "application/vnd.openxmlformats-officedocument.presentationml.presentation",
  645. ] or file_ext in ["ppt", "pptx"]:
  646. loader = UnstructuredPowerPointLoader(file_path)
  647. elif file_ext in known_source_ext or (
  648. file_content_type and file_content_type.find("text/") >= 0
  649. ):
  650. loader = TextLoader(file_path, autodetect_encoding=True)
  651. else:
  652. loader = TextLoader(file_path, autodetect_encoding=True)
  653. known_type = False
  654. return loader, known_type
  655. @app.post("/doc")
  656. def store_doc(
  657. collection_name: Optional[str] = Form(None),
  658. file: UploadFile = File(...),
  659. user=Depends(get_current_user),
  660. ):
  661. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  662. log.info(f"file.content_type: {file.content_type}")
  663. try:
  664. unsanitized_filename = file.filename
  665. filename = os.path.basename(unsanitized_filename)
  666. file_path = f"{UPLOAD_DIR}/{filename}"
  667. contents = file.file.read()
  668. with open(file_path, "wb") as f:
  669. f.write(contents)
  670. f.close()
  671. f = open(file_path, "rb")
  672. if collection_name == None:
  673. collection_name = calculate_sha256(f)[:63]
  674. f.close()
  675. loader, known_type = get_loader(filename, file.content_type, file_path)
  676. data = loader.load()
  677. try:
  678. result = store_data_in_vector_db(data, collection_name)
  679. if result:
  680. return {
  681. "status": True,
  682. "collection_name": collection_name,
  683. "filename": filename,
  684. "known_type": known_type,
  685. }
  686. except Exception as e:
  687. raise HTTPException(
  688. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  689. detail=e,
  690. )
  691. except Exception as e:
  692. log.exception(e)
  693. if "No pandoc was found" in str(e):
  694. raise HTTPException(
  695. status_code=status.HTTP_400_BAD_REQUEST,
  696. detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
  697. )
  698. else:
  699. raise HTTPException(
  700. status_code=status.HTTP_400_BAD_REQUEST,
  701. detail=ERROR_MESSAGES.DEFAULT(e),
  702. )
# Request body for POST /text (store a raw text snippet).
class TextRAGForm(BaseModel):
    # Saved as the "name" entry of the stored chunks' metadata.
    name: str
    # The text to chunk and embed.
    content: str
    # When None, the handler derives a name from the hash of `content`.
    collection_name: Optional[str] = None
  707. @app.post("/text")
  708. def store_text(
  709. form_data: TextRAGForm,
  710. user=Depends(get_current_user),
  711. ):
  712. collection_name = form_data.collection_name
  713. if collection_name == None:
  714. collection_name = calculate_sha256_string(form_data.content)
  715. result = store_text_in_vector_db(
  716. form_data.content,
  717. metadata={"name": form_data.name, "created_by": user.id},
  718. collection_name=collection_name,
  719. )
  720. if result:
  721. return {"status": True, "collection_name": collection_name}
  722. else:
  723. raise HTTPException(
  724. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  725. detail=ERROR_MESSAGES.DEFAULT(),
  726. )
  727. @app.get("/scan")
  728. def scan_docs_dir(user=Depends(get_admin_user)):
  729. for path in Path(DOCS_DIR).rglob("./**/*"):
  730. try:
  731. if path.is_file() and not path.name.startswith("."):
  732. tags = extract_folders_after_data_docs(path)
  733. filename = path.name
  734. file_content_type = mimetypes.guess_type(path)
  735. f = open(path, "rb")
  736. collection_name = calculate_sha256(f)[:63]
  737. f.close()
  738. loader, known_type = get_loader(
  739. filename, file_content_type[0], str(path)
  740. )
  741. data = loader.load()
  742. try:
  743. result = store_data_in_vector_db(data, collection_name)
  744. if result:
  745. sanitized_filename = sanitize_filename(filename)
  746. doc = Documents.get_doc_by_name(sanitized_filename)
  747. if doc == None:
  748. doc = Documents.insert_new_doc(
  749. user.id,
  750. DocumentForm(
  751. **{
  752. "name": sanitized_filename,
  753. "title": filename,
  754. "collection_name": collection_name,
  755. "filename": filename,
  756. "content": (
  757. json.dumps(
  758. {
  759. "tags": list(
  760. map(
  761. lambda name: {"name": name},
  762. tags,
  763. )
  764. )
  765. }
  766. )
  767. if len(tags)
  768. else "{}"
  769. ),
  770. }
  771. ),
  772. )
  773. except Exception as e:
  774. log.exception(e)
  775. pass
  776. except Exception as e:
  777. log.exception(e)
  778. return True
# Admin-only: calls reset() on the Chroma client. The handler returns nothing,
# so the response body is null.
# NOTE(review): a state-changing operation exposed via GET — confirm intended.
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    CHROMA_CLIENT.reset()
  782. @app.get("/reset")
  783. def reset(user=Depends(get_admin_user)) -> bool:
  784. folder = f"{UPLOAD_DIR}"
  785. for filename in os.listdir(folder):
  786. file_path = os.path.join(folder, filename)
  787. try:
  788. if os.path.isfile(file_path) or os.path.islink(file_path):
  789. os.unlink(file_path)
  790. elif os.path.isdir(file_path):
  791. shutil.rmtree(file_path)
  792. except Exception as e:
  793. log.error("Failed to delete %s. Reason: %s" % (file_path, e))
  794. try:
  795. CHROMA_CLIENT.reset()
  796. except Exception as e:
  797. log.exception(e)
  798. return True
  799. if ENV == "dev":
  800. @app.get("/ef")
  801. async def get_embeddings():
  802. return {"result": app.state.EMBEDDING_FUNCTION("hello world")}
  803. @app.get("/ef/{text}")
  804. async def get_embeddings_text(text: str):
  805. return {"result": app.state.EMBEDDING_FUNCTION(text)}