main.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. from fastapi import (
  2. FastAPI,
  3. Depends,
  4. HTTPException,
  5. status,
  6. UploadFile,
  7. File,
  8. Form,
  9. )
  10. from fastapi.middleware.cors import CORSMiddleware
  11. import os, shutil, logging, re
  12. from pathlib import Path
  13. from typing import List
  14. from chromadb.utils.batch_utils import create_batches
  15. from langchain_community.document_loaders import (
  16. WebBaseLoader,
  17. TextLoader,
  18. PyPDFLoader,
  19. CSVLoader,
  20. BSHTMLLoader,
  21. Docx2txtLoader,
  22. UnstructuredEPubLoader,
  23. UnstructuredWordDocumentLoader,
  24. UnstructuredMarkdownLoader,
  25. UnstructuredXMLLoader,
  26. UnstructuredRSTLoader,
  27. UnstructuredExcelLoader,
  28. )
  29. from langchain.text_splitter import RecursiveCharacterTextSplitter
  30. from pydantic import BaseModel
  31. from typing import Optional
  32. import mimetypes
  33. import uuid
  34. import json
  35. import sentence_transformers
  36. from apps.web.models.documents import (
  37. Documents,
  38. DocumentForm,
  39. DocumentResponse,
  40. )
  41. from apps.rag.utils import (
  42. get_model_path,
  43. query_embeddings_doc,
  44. get_embeddings_function,
  45. query_embeddings_collection,
  46. )
  47. from utils.misc import (
  48. calculate_sha256,
  49. calculate_sha256_string,
  50. sanitize_filename,
  51. extract_folders_after_data_docs,
  52. )
  53. from utils.utils import get_current_user, get_admin_user
  54. from config import (
  55. SRC_LOG_LEVELS,
  56. UPLOAD_DIR,
  57. DOCS_DIR,
  58. RAG_TOP_K,
  59. RAG_RELEVANCE_THRESHOLD,
  60. RAG_EMBEDDING_ENGINE,
  61. RAG_EMBEDDING_MODEL,
  62. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  63. RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  64. ENABLE_RAG_HYBRID_SEARCH,
  65. RAG_RERANKING_MODEL,
  66. RAG_RERANKING_MODEL_AUTO_UPDATE,
  67. RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  68. RAG_OPENAI_API_BASE_URL,
  69. RAG_OPENAI_API_KEY,
  70. DEVICE_TYPE,
  71. CHROMA_CLIENT,
  72. CHUNK_SIZE,
  73. CHUNK_OVERLAP,
  74. RAG_TEMPLATE,
  75. )
  76. from constants import ERROR_MESSAGES
  77. log = logging.getLogger(__name__)
  78. log.setLevel(SRC_LOG_LEVELS["RAG"])
  79. app = FastAPI()
  80. app.state.TOP_K = RAG_TOP_K
  81. app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
  82. app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
  83. app.state.CHUNK_SIZE = CHUNK_SIZE
  84. app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
  85. app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
  86. app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
  87. app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
  88. app.state.RAG_TEMPLATE = RAG_TEMPLATE
  89. app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
  90. app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
  91. app.state.PDF_EXTRACT_IMAGES = False
  92. def update_embedding_model(
  93. embedding_model: str,
  94. update_model: bool = False,
  95. ):
  96. if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
  97. app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
  98. get_model_path(embedding_model, update_model),
  99. device=DEVICE_TYPE,
  100. trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  101. )
  102. else:
  103. app.state.sentence_transformer_ef = None
  104. def update_reranking_model(
  105. reranking_model: str,
  106. update_model: bool = False,
  107. ):
  108. if reranking_model:
  109. app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
  110. get_model_path(reranking_model, update_model),
  111. device=DEVICE_TYPE,
  112. trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  113. )
  114. else:
  115. app.state.sentence_transformer_rf = None
  116. update_embedding_model(
  117. app.state.RAG_EMBEDDING_MODEL,
  118. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  119. )
  120. update_reranking_model(
  121. app.state.RAG_RERANKING_MODEL,
  122. RAG_RERANKING_MODEL_AUTO_UPDATE,
  123. )
  124. origins = ["*"]
  125. app.add_middleware(
  126. CORSMiddleware,
  127. allow_origins=origins,
  128. allow_credentials=True,
  129. allow_methods=["*"],
  130. allow_headers=["*"],
  131. )
  132. class CollectionNameForm(BaseModel):
  133. collection_name: Optional[str] = "test"
  134. class StoreWebForm(CollectionNameForm):
  135. url: str
  136. @app.get("/")
  137. async def get_status():
  138. return {
  139. "status": True,
  140. "chunk_size": app.state.CHUNK_SIZE,
  141. "chunk_overlap": app.state.CHUNK_OVERLAP,
  142. "template": app.state.RAG_TEMPLATE,
  143. "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
  144. "embedding_model": app.state.RAG_EMBEDDING_MODEL,
  145. "reranking_model": app.state.RAG_RERANKING_MODEL,
  146. }
  147. @app.get("/embedding")
  148. async def get_embedding_config(user=Depends(get_admin_user)):
  149. return {
  150. "status": True,
  151. "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
  152. "embedding_model": app.state.RAG_EMBEDDING_MODEL,
  153. "openai_config": {
  154. "url": app.state.OPENAI_API_BASE_URL,
  155. "key": app.state.OPENAI_API_KEY,
  156. },
  157. }
  158. @app.get("/reranking")
  159. async def get_reraanking_config(user=Depends(get_admin_user)):
  160. return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
  161. class OpenAIConfigForm(BaseModel):
  162. url: str
  163. key: str
  164. class EmbeddingModelUpdateForm(BaseModel):
  165. openai_config: Optional[OpenAIConfigForm] = None
  166. embedding_engine: str
  167. embedding_model: str
  168. @app.post("/embedding/update")
  169. async def update_embedding_config(
  170. form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
  171. ):
  172. log.info(
  173. f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
  174. )
  175. try:
  176. app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
  177. app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
  178. if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
  179. if form_data.openai_config != None:
  180. app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
  181. app.state.OPENAI_API_KEY = form_data.openai_config.key
  182. update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
  183. return {
  184. "status": True,
  185. "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
  186. "embedding_model": app.state.RAG_EMBEDDING_MODEL,
  187. "openai_config": {
  188. "url": app.state.OPENAI_API_BASE_URL,
  189. "key": app.state.OPENAI_API_KEY,
  190. },
  191. }
  192. except Exception as e:
  193. log.exception(f"Problem updating embedding model: {e}")
  194. raise HTTPException(
  195. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  196. detail=ERROR_MESSAGES.DEFAULT(e),
  197. )
  198. class RerankingModelUpdateForm(BaseModel):
  199. reranking_model: str
  200. @app.post("/reranking/update")
  201. async def update_reranking_config(
  202. form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
  203. ):
  204. log.info(
  205. f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
  206. )
  207. try:
  208. app.state.RAG_RERANKING_MODEL = form_data.reranking_model
  209. update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
  210. return {
  211. "status": True,
  212. "reranking_model": app.state.RAG_RERANKING_MODEL,
  213. }
  214. except Exception as e:
  215. log.exception(f"Problem updating reranking model: {e}")
  216. raise HTTPException(
  217. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  218. detail=ERROR_MESSAGES.DEFAULT(e),
  219. )
  220. @app.get("/config")
  221. async def get_rag_config(user=Depends(get_admin_user)):
  222. return {
  223. "status": True,
  224. "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
  225. "chunk": {
  226. "chunk_size": app.state.CHUNK_SIZE,
  227. "chunk_overlap": app.state.CHUNK_OVERLAP,
  228. },
  229. }
  230. class ChunkParamUpdateForm(BaseModel):
  231. chunk_size: int
  232. chunk_overlap: int
  233. class ConfigUpdateForm(BaseModel):
  234. pdf_extract_images: bool
  235. chunk: ChunkParamUpdateForm
  236. @app.post("/config/update")
  237. async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  238. app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
  239. app.state.CHUNK_SIZE = form_data.chunk.chunk_size
  240. app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
  241. return {
  242. "status": True,
  243. "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
  244. "chunk": {
  245. "chunk_size": app.state.CHUNK_SIZE,
  246. "chunk_overlap": app.state.CHUNK_OVERLAP,
  247. },
  248. }
  249. @app.get("/template")
  250. async def get_rag_template(user=Depends(get_current_user)):
  251. return {
  252. "status": True,
  253. "template": app.state.RAG_TEMPLATE,
  254. }
  255. @app.get("/query/settings")
  256. async def get_query_settings(user=Depends(get_admin_user)):
  257. return {
  258. "status": True,
  259. "template": app.state.RAG_TEMPLATE,
  260. "k": app.state.TOP_K,
  261. "r": app.state.RELEVANCE_THRESHOLD,
  262. "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
  263. }
  264. class QuerySettingsForm(BaseModel):
  265. k: Optional[int] = None
  266. r: Optional[float] = None
  267. template: Optional[str] = None
  268. hybrid: Optional[bool] = None
  269. @app.post("/query/settings/update")
  270. async def update_query_settings(
  271. form_data: QuerySettingsForm, user=Depends(get_admin_user)
  272. ):
  273. app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
  274. app.state.TOP_K = form_data.k if form_data.k else 4
  275. app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
  276. app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
  277. return {
  278. "status": True,
  279. "template": app.state.RAG_TEMPLATE,
  280. "k": app.state.TOP_K,
  281. "r": app.state.RELEVANCE_THRESHOLD,
  282. "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
  283. }
  284. class QueryDocForm(BaseModel):
  285. collection_name: str
  286. query: str
  287. k: Optional[int] = None
  288. r: Optional[float] = None
  289. hybrid: Optional[bool] = None
  290. @app.post("/query/doc")
  291. def query_doc_handler(
  292. form_data: QueryDocForm,
  293. user=Depends(get_current_user),
  294. ):
  295. try:
  296. embeddings_function = get_embeddings_function(
  297. app.state.RAG_EMBEDDING_ENGINE,
  298. app.state.RAG_EMBEDDING_MODEL,
  299. app.state.sentence_transformer_ef,
  300. app.state.OPENAI_API_KEY,
  301. app.state.OPENAI_API_BASE_URL,
  302. )
  303. return query_embeddings_doc(
  304. collection_name=form_data.collection_name,
  305. query=form_data.query,
  306. k=form_data.k if form_data.k else app.state.TOP_K,
  307. r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
  308. embeddings_function=embeddings_function,
  309. reranking_function=app.state.sentence_transformer_rf,
  310. hybrid_search=(
  311. form_data.hybrid
  312. if form_data.hybrid
  313. else app.state.ENABLE_RAG_HYBRID_SEARCH
  314. ),
  315. )
  316. except Exception as e:
  317. log.exception(e)
  318. raise HTTPException(
  319. status_code=status.HTTP_400_BAD_REQUEST,
  320. detail=ERROR_MESSAGES.DEFAULT(e),
  321. )
  322. class QueryCollectionsForm(BaseModel):
  323. collection_names: List[str]
  324. query: str
  325. k: Optional[int] = None
  326. r: Optional[float] = None
  327. hybrid: Optional[bool] = None
  328. @app.post("/query/collection")
  329. def query_collection_handler(
  330. form_data: QueryCollectionsForm,
  331. user=Depends(get_current_user),
  332. ):
  333. try:
  334. embeddings_function = get_embeddings_function(
  335. app.state.RAG_EMBEDDING_ENGINE,
  336. app.state.RAG_EMBEDDING_MODEL,
  337. app.state.sentence_transformer_ef,
  338. app.state.OPENAI_API_KEY,
  339. app.state.OPENAI_API_BASE_URL,
  340. )
  341. return query_embeddings_collection(
  342. collection_names=form_data.collection_names,
  343. query=form_data.query,
  344. k=form_data.k if form_data.k else app.state.TOP_K,
  345. r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
  346. embeddings_function=embeddings_function,
  347. reranking_function=app.state.sentence_transformer_rf,
  348. hybrid_search=(
  349. form_data.hybrid
  350. if form_data.hybrid
  351. else app.state.ENABLE_RAG_HYBRID_SEARCH
  352. ),
  353. )
  354. except Exception as e:
  355. log.exception(e)
  356. raise HTTPException(
  357. status_code=status.HTTP_400_BAD_REQUEST,
  358. detail=ERROR_MESSAGES.DEFAULT(e),
  359. )
  360. @app.post("/web")
  361. def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
  362. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  363. try:
  364. loader = WebBaseLoader(form_data.url)
  365. data = loader.load()
  366. collection_name = form_data.collection_name
  367. if collection_name == "":
  368. collection_name = calculate_sha256_string(form_data.url)[:63]
  369. store_data_in_vector_db(data, collection_name, overwrite=True)
  370. return {
  371. "status": True,
  372. "collection_name": collection_name,
  373. "filename": form_data.url,
  374. }
  375. except Exception as e:
  376. log.exception(e)
  377. raise HTTPException(
  378. status_code=status.HTTP_400_BAD_REQUEST,
  379. detail=ERROR_MESSAGES.DEFAULT(e),
  380. )
  381. def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
  382. text_splitter = RecursiveCharacterTextSplitter(
  383. chunk_size=app.state.CHUNK_SIZE,
  384. chunk_overlap=app.state.CHUNK_OVERLAP,
  385. add_start_index=True,
  386. )
  387. docs = text_splitter.split_documents(data)
  388. if len(docs) > 0:
  389. log.info(f"store_data_in_vector_db {docs}")
  390. return store_docs_in_vector_db(docs, collection_name, overwrite), None
  391. else:
  392. raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
  393. def store_text_in_vector_db(
  394. text, metadata, collection_name, overwrite: bool = False
  395. ) -> bool:
  396. text_splitter = RecursiveCharacterTextSplitter(
  397. chunk_size=app.state.CHUNK_SIZE,
  398. chunk_overlap=app.state.CHUNK_OVERLAP,
  399. add_start_index=True,
  400. )
  401. docs = text_splitter.create_documents([text], metadatas=[metadata])
  402. return store_docs_in_vector_db(docs, collection_name, overwrite)
  403. def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
  404. log.info(f"store_docs_in_vector_db {docs} {collection_name}")
  405. texts = [doc.page_content for doc in docs]
  406. metadatas = [doc.metadata for doc in docs]
  407. try:
  408. if overwrite:
  409. for collection in CHROMA_CLIENT.list_collections():
  410. if collection_name == collection.name:
  411. log.info(f"deleting existing collection {collection_name}")
  412. CHROMA_CLIENT.delete_collection(name=collection_name)
  413. collection = CHROMA_CLIENT.create_collection(name=collection_name)
  414. embedding_func = get_embeddings_function(
  415. app.state.RAG_EMBEDDING_ENGINE,
  416. app.state.RAG_EMBEDDING_MODEL,
  417. app.state.sentence_transformer_ef,
  418. app.state.OPENAI_API_KEY,
  419. app.state.OPENAI_API_BASE_URL,
  420. )
  421. embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
  422. embeddings = embedding_func(embedding_texts)
  423. for batch in create_batches(
  424. api=CHROMA_CLIENT,
  425. ids=[str(uuid.uuid1()) for _ in texts],
  426. metadatas=metadatas,
  427. embeddings=embeddings,
  428. documents=texts,
  429. ):
  430. collection.add(*batch)
  431. return True
  432. except Exception as e:
  433. log.exception(e)
  434. if e.__class__.__name__ == "UniqueConstraintError":
  435. return True
  436. return False
  437. def get_loader(filename: str, file_content_type: str, file_path: str):
  438. file_ext = filename.split(".")[-1].lower()
  439. known_type = True
  440. known_source_ext = [
  441. "go",
  442. "py",
  443. "java",
  444. "sh",
  445. "bat",
  446. "ps1",
  447. "cmd",
  448. "js",
  449. "ts",
  450. "css",
  451. "cpp",
  452. "hpp",
  453. "h",
  454. "c",
  455. "cs",
  456. "sql",
  457. "log",
  458. "ini",
  459. "pl",
  460. "pm",
  461. "r",
  462. "dart",
  463. "dockerfile",
  464. "env",
  465. "php",
  466. "hs",
  467. "hsc",
  468. "lua",
  469. "nginxconf",
  470. "conf",
  471. "m",
  472. "mm",
  473. "plsql",
  474. "perl",
  475. "rb",
  476. "rs",
  477. "db2",
  478. "scala",
  479. "bash",
  480. "swift",
  481. "vue",
  482. "svelte",
  483. ]
  484. if file_ext == "pdf":
  485. loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
  486. elif file_ext == "csv":
  487. loader = CSVLoader(file_path)
  488. elif file_ext == "rst":
  489. loader = UnstructuredRSTLoader(file_path, mode="elements")
  490. elif file_ext == "xml":
  491. loader = UnstructuredXMLLoader(file_path)
  492. elif file_ext in ["htm", "html"]:
  493. loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
  494. elif file_ext == "md":
  495. loader = UnstructuredMarkdownLoader(file_path)
  496. elif file_content_type == "application/epub+zip":
  497. loader = UnstructuredEPubLoader(file_path)
  498. elif (
  499. file_content_type
  500. == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
  501. or file_ext in ["doc", "docx"]
  502. ):
  503. loader = Docx2txtLoader(file_path)
  504. elif file_content_type in [
  505. "application/vnd.ms-excel",
  506. "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  507. ] or file_ext in ["xls", "xlsx"]:
  508. loader = UnstructuredExcelLoader(file_path)
  509. elif file_ext in known_source_ext or (
  510. file_content_type and file_content_type.find("text/") >= 0
  511. ):
  512. loader = TextLoader(file_path, autodetect_encoding=True)
  513. else:
  514. loader = TextLoader(file_path, autodetect_encoding=True)
  515. known_type = False
  516. return loader, known_type
  517. @app.post("/doc")
  518. def store_doc(
  519. collection_name: Optional[str] = Form(None),
  520. file: UploadFile = File(...),
  521. user=Depends(get_current_user),
  522. ):
  523. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  524. log.info(f"file.content_type: {file.content_type}")
  525. try:
  526. unsanitized_filename = file.filename
  527. filename = os.path.basename(unsanitized_filename)
  528. file_path = f"{UPLOAD_DIR}/{filename}"
  529. contents = file.file.read()
  530. with open(file_path, "wb") as f:
  531. f.write(contents)
  532. f.close()
  533. f = open(file_path, "rb")
  534. if collection_name == None:
  535. collection_name = calculate_sha256(f)[:63]
  536. f.close()
  537. loader, known_type = get_loader(filename, file.content_type, file_path)
  538. data = loader.load()
  539. try:
  540. result = store_data_in_vector_db(data, collection_name)
  541. if result:
  542. return {
  543. "status": True,
  544. "collection_name": collection_name,
  545. "filename": filename,
  546. "known_type": known_type,
  547. }
  548. except Exception as e:
  549. raise HTTPException(
  550. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  551. detail=e,
  552. )
  553. except Exception as e:
  554. log.exception(e)
  555. if "No pandoc was found" in str(e):
  556. raise HTTPException(
  557. status_code=status.HTTP_400_BAD_REQUEST,
  558. detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
  559. )
  560. else:
  561. raise HTTPException(
  562. status_code=status.HTTP_400_BAD_REQUEST,
  563. detail=ERROR_MESSAGES.DEFAULT(e),
  564. )
  565. class TextRAGForm(BaseModel):
  566. name: str
  567. content: str
  568. collection_name: Optional[str] = None
  569. @app.post("/text")
  570. def store_text(
  571. form_data: TextRAGForm,
  572. user=Depends(get_current_user),
  573. ):
  574. collection_name = form_data.collection_name
  575. if collection_name == None:
  576. collection_name = calculate_sha256_string(form_data.content)
  577. result = store_text_in_vector_db(
  578. form_data.content,
  579. metadata={"name": form_data.name, "created_by": user.id},
  580. collection_name=collection_name,
  581. )
  582. if result:
  583. return {"status": True, "collection_name": collection_name}
  584. else:
  585. raise HTTPException(
  586. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  587. detail=ERROR_MESSAGES.DEFAULT(),
  588. )
  589. @app.get("/scan")
  590. def scan_docs_dir(user=Depends(get_admin_user)):
  591. for path in Path(DOCS_DIR).rglob("./**/*"):
  592. try:
  593. if path.is_file() and not path.name.startswith("."):
  594. tags = extract_folders_after_data_docs(path)
  595. filename = path.name
  596. file_content_type = mimetypes.guess_type(path)
  597. f = open(path, "rb")
  598. collection_name = calculate_sha256(f)[:63]
  599. f.close()
  600. loader, known_type = get_loader(
  601. filename, file_content_type[0], str(path)
  602. )
  603. data = loader.load()
  604. try:
  605. result = store_data_in_vector_db(data, collection_name)
  606. if result:
  607. sanitized_filename = sanitize_filename(filename)
  608. doc = Documents.get_doc_by_name(sanitized_filename)
  609. if doc == None:
  610. doc = Documents.insert_new_doc(
  611. user.id,
  612. DocumentForm(
  613. **{
  614. "name": sanitized_filename,
  615. "title": filename,
  616. "collection_name": collection_name,
  617. "filename": filename,
  618. "content": (
  619. json.dumps(
  620. {
  621. "tags": list(
  622. map(
  623. lambda name: {"name": name},
  624. tags,
  625. )
  626. )
  627. }
  628. )
  629. if len(tags)
  630. else "{}"
  631. ),
  632. }
  633. ),
  634. )
  635. except Exception as e:
  636. log.exception(e)
  637. pass
  638. except Exception as e:
  639. log.exception(e)
  640. return True
  641. @app.get("/reset/db")
  642. def reset_vector_db(user=Depends(get_admin_user)):
  643. CHROMA_CLIENT.reset()
  644. @app.get("/reset")
  645. def reset(user=Depends(get_admin_user)) -> bool:
  646. folder = f"{UPLOAD_DIR}"
  647. for filename in os.listdir(folder):
  648. file_path = os.path.join(folder, filename)
  649. try:
  650. if os.path.isfile(file_path) or os.path.islink(file_path):
  651. os.unlink(file_path)
  652. elif os.path.isdir(file_path):
  653. shutil.rmtree(file_path)
  654. except Exception as e:
  655. log.error("Failed to delete %s. Reason: %s" % (file_path, e))
  656. try:
  657. CHROMA_CLIENT.reset()
  658. except Exception as e:
  659. log.exception(e)
  660. return True