main.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. from fastapi import (
  2. FastAPI,
  3. Depends,
  4. HTTPException,
  5. status,
  6. UploadFile,
  7. File,
  8. Form,
  9. )
  10. from fastapi.middleware.cors import CORSMiddleware
  11. import os, shutil, logging, re
  12. from pathlib import Path
  13. from typing import List
  14. from chromadb.utils.batch_utils import create_batches
  15. from langchain_community.document_loaders import (
  16. WebBaseLoader,
  17. TextLoader,
  18. PyPDFLoader,
  19. CSVLoader,
  20. BSHTMLLoader,
  21. Docx2txtLoader,
  22. UnstructuredEPubLoader,
  23. UnstructuredWordDocumentLoader,
  24. UnstructuredMarkdownLoader,
  25. UnstructuredXMLLoader,
  26. UnstructuredRSTLoader,
  27. UnstructuredExcelLoader,
  28. )
  29. from langchain.text_splitter import RecursiveCharacterTextSplitter
  30. from pydantic import BaseModel
  31. from typing import Optional
  32. import mimetypes
  33. import uuid
  34. import json
  35. import sentence_transformers
  36. from apps.web.models.documents import (
  37. Documents,
  38. DocumentForm,
  39. DocumentResponse,
  40. )
  41. from apps.rag.utils import (
  42. get_model_path,
  43. get_embedding_function,
  44. query_doc,
  45. query_doc_with_hybrid_search,
  46. query_collection,
  47. query_collection_with_hybrid_search,
  48. )
  49. from utils.misc import (
  50. calculate_sha256,
  51. calculate_sha256_string,
  52. sanitize_filename,
  53. extract_folders_after_data_docs,
  54. )
  55. from utils.utils import get_current_user, get_admin_user
  56. from config import (
  57. SRC_LOG_LEVELS,
  58. UPLOAD_DIR,
  59. DOCS_DIR,
  60. RAG_TOP_K,
  61. RAG_RELEVANCE_THRESHOLD,
  62. RAG_EMBEDDING_ENGINE,
  63. RAG_EMBEDDING_MODEL,
  64. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  65. RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  66. ENABLE_RAG_HYBRID_SEARCH,
  67. RAG_RERANKING_MODEL,
  68. RAG_RERANKING_MODEL_AUTO_UPDATE,
  69. RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  70. RAG_OPENAI_API_BASE_URL,
  71. RAG_OPENAI_API_KEY,
  72. DEVICE_TYPE,
  73. CHROMA_CLIENT,
  74. CHUNK_SIZE,
  75. CHUNK_OVERLAP,
  76. RAG_TEMPLATE,
  77. )
  78. from constants import ERROR_MESSAGES
  79. log = logging.getLogger(__name__)
  80. log.setLevel(SRC_LOG_LEVELS["RAG"])
  81. app = FastAPI()
  82. app.state.TOP_K = RAG_TOP_K
  83. app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
  84. app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
  85. app.state.CHUNK_SIZE = CHUNK_SIZE
  86. app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
  87. app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
  88. app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
  89. app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
  90. app.state.RAG_TEMPLATE = RAG_TEMPLATE
  91. app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
  92. app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
  93. app.state.PDF_EXTRACT_IMAGES = False
  94. def update_embedding_model(
  95. embedding_model: str,
  96. update_model: bool = False,
  97. ):
  98. if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
  99. app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
  100. get_model_path(embedding_model, update_model),
  101. device=DEVICE_TYPE,
  102. trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  103. )
  104. else:
  105. app.state.sentence_transformer_ef = None
  106. def update_reranking_model(
  107. reranking_model: str,
  108. update_model: bool = False,
  109. ):
  110. if reranking_model:
  111. app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
  112. get_model_path(reranking_model, update_model),
  113. device=DEVICE_TYPE,
  114. trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  115. )
  116. else:
  117. app.state.sentence_transformer_rf = None
  118. update_embedding_model(
  119. app.state.RAG_EMBEDDING_MODEL,
  120. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  121. )
  122. update_reranking_model(
  123. app.state.RAG_RERANKING_MODEL,
  124. RAG_RERANKING_MODEL_AUTO_UPDATE,
  125. )
  126. app.state.EMBEDDING_FUNCTION = get_embedding_function(
  127. app.state.RAG_EMBEDDING_ENGINE,
  128. app.state.RAG_EMBEDDING_MODEL,
  129. app.state.sentence_transformer_ef,
  130. app.state.OPENAI_API_KEY,
  131. app.state.OPENAI_API_BASE_URL,
  132. )
  133. origins = ["*"]
  134. app.add_middleware(
  135. CORSMiddleware,
  136. allow_origins=origins,
  137. allow_credentials=True,
  138. allow_methods=["*"],
  139. allow_headers=["*"],
  140. )
  141. class CollectionNameForm(BaseModel):
  142. collection_name: Optional[str] = "test"
  143. class StoreWebForm(CollectionNameForm):
  144. url: str
  145. @app.get("/")
  146. async def get_status():
  147. return {
  148. "status": True,
  149. "chunk_size": app.state.CHUNK_SIZE,
  150. "chunk_overlap": app.state.CHUNK_OVERLAP,
  151. "template": app.state.RAG_TEMPLATE,
  152. "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
  153. "embedding_model": app.state.RAG_EMBEDDING_MODEL,
  154. "reranking_model": app.state.RAG_RERANKING_MODEL,
  155. }
  156. @app.get("/embedding")
  157. async def get_embedding_config(user=Depends(get_admin_user)):
  158. return {
  159. "status": True,
  160. "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
  161. "embedding_model": app.state.RAG_EMBEDDING_MODEL,
  162. "openai_config": {
  163. "url": app.state.OPENAI_API_BASE_URL,
  164. "key": app.state.OPENAI_API_KEY,
  165. },
  166. }
  167. @app.get("/reranking")
  168. async def get_reraanking_config(user=Depends(get_admin_user)):
  169. return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
  170. class OpenAIConfigForm(BaseModel):
  171. url: str
  172. key: str
  173. class EmbeddingModelUpdateForm(BaseModel):
  174. openai_config: Optional[OpenAIConfigForm] = None
  175. embedding_engine: str
  176. embedding_model: str
  177. @app.post("/embedding/update")
  178. async def update_embedding_config(
  179. form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
  180. ):
  181. log.info(
  182. f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
  183. )
  184. try:
  185. app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
  186. app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
  187. if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
  188. if form_data.openai_config != None:
  189. app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
  190. app.state.OPENAI_API_KEY = form_data.openai_config.key
  191. update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
  192. app.state.EMBEDDING_FUNCTION = get_embedding_function(
  193. app.state.RAG_EMBEDDING_ENGINE,
  194. app.state.RAG_EMBEDDING_MODEL,
  195. app.state.sentence_transformer_ef,
  196. app.state.OPENAI_API_KEY,
  197. app.state.OPENAI_API_BASE_URL,
  198. )
  199. return {
  200. "status": True,
  201. "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
  202. "embedding_model": app.state.RAG_EMBEDDING_MODEL,
  203. "openai_config": {
  204. "url": app.state.OPENAI_API_BASE_URL,
  205. "key": app.state.OPENAI_API_KEY,
  206. },
  207. }
  208. except Exception as e:
  209. log.exception(f"Problem updating embedding model: {e}")
  210. raise HTTPException(
  211. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  212. detail=ERROR_MESSAGES.DEFAULT(e),
  213. )
  214. class RerankingModelUpdateForm(BaseModel):
  215. reranking_model: str
  216. @app.post("/reranking/update")
  217. async def update_reranking_config(
  218. form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
  219. ):
  220. log.info(
  221. f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
  222. )
  223. try:
  224. app.state.RAG_RERANKING_MODEL = form_data.reranking_model
  225. update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
  226. return {
  227. "status": True,
  228. "reranking_model": app.state.RAG_RERANKING_MODEL,
  229. }
  230. except Exception as e:
  231. log.exception(f"Problem updating reranking model: {e}")
  232. raise HTTPException(
  233. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  234. detail=ERROR_MESSAGES.DEFAULT(e),
  235. )
  236. @app.get("/config")
  237. async def get_rag_config(user=Depends(get_admin_user)):
  238. return {
  239. "status": True,
  240. "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
  241. "chunk": {
  242. "chunk_size": app.state.CHUNK_SIZE,
  243. "chunk_overlap": app.state.CHUNK_OVERLAP,
  244. },
  245. }
  246. class ChunkParamUpdateForm(BaseModel):
  247. chunk_size: int
  248. chunk_overlap: int
  249. class ConfigUpdateForm(BaseModel):
  250. pdf_extract_images: bool
  251. chunk: ChunkParamUpdateForm
  252. @app.post("/config/update")
  253. async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  254. app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
  255. app.state.CHUNK_SIZE = form_data.chunk.chunk_size
  256. app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
  257. return {
  258. "status": True,
  259. "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
  260. "chunk": {
  261. "chunk_size": app.state.CHUNK_SIZE,
  262. "chunk_overlap": app.state.CHUNK_OVERLAP,
  263. },
  264. }
  265. @app.get("/template")
  266. async def get_rag_template(user=Depends(get_current_user)):
  267. return {
  268. "status": True,
  269. "template": app.state.RAG_TEMPLATE,
  270. }
  271. @app.get("/query/settings")
  272. async def get_query_settings(user=Depends(get_admin_user)):
  273. return {
  274. "status": True,
  275. "template": app.state.RAG_TEMPLATE,
  276. "k": app.state.TOP_K,
  277. "r": app.state.RELEVANCE_THRESHOLD,
  278. "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
  279. }
  280. class QuerySettingsForm(BaseModel):
  281. k: Optional[int] = None
  282. r: Optional[float] = None
  283. template: Optional[str] = None
  284. hybrid: Optional[bool] = None
  285. @app.post("/query/settings/update")
  286. async def update_query_settings(
  287. form_data: QuerySettingsForm, user=Depends(get_admin_user)
  288. ):
  289. app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
  290. app.state.TOP_K = form_data.k if form_data.k else 4
  291. app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
  292. app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
  293. return {
  294. "status": True,
  295. "template": app.state.RAG_TEMPLATE,
  296. "k": app.state.TOP_K,
  297. "r": app.state.RELEVANCE_THRESHOLD,
  298. "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
  299. }
  300. class QueryDocForm(BaseModel):
  301. collection_name: str
  302. query: str
  303. k: Optional[int] = None
  304. r: Optional[float] = None
  305. hybrid: Optional[bool] = None
  306. @app.post("/query/doc")
  307. def query_doc_handler(
  308. form_data: QueryDocForm,
  309. user=Depends(get_current_user),
  310. ):
  311. try:
  312. if app.state.ENABLE_RAG_HYBRID_SEARCH:
  313. return query_doc_with_hybrid_search(
  314. collection_name=form_data.collection_name,
  315. query=form_data.query,
  316. embeddings_function=app.state.EMBEDDING_FUNCTION,
  317. reranking_function=app.state.sentence_transformer_rf,
  318. k=form_data.k if form_data.k else app.state.TOP_K,
  319. r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
  320. )
  321. else:
  322. return query_doc(
  323. collection_name=form_data.collection_name,
  324. query=form_data.query,
  325. embeddings_function=app.state.EMBEDDING_FUNCTION,
  326. k=form_data.k if form_data.k else app.state.TOP_K,
  327. )
  328. except Exception as e:
  329. log.exception(e)
  330. raise HTTPException(
  331. status_code=status.HTTP_400_BAD_REQUEST,
  332. detail=ERROR_MESSAGES.DEFAULT(e),
  333. )
  334. class QueryCollectionsForm(BaseModel):
  335. collection_names: List[str]
  336. query: str
  337. k: Optional[int] = None
  338. r: Optional[float] = None
  339. hybrid: Optional[bool] = None
  340. @app.post("/query/collection")
  341. def query_collection_handler(
  342. form_data: QueryCollectionsForm,
  343. user=Depends(get_current_user),
  344. ):
  345. try:
  346. if app.state.ENABLE_RAG_HYBRID_SEARCH:
  347. return query_collection_with_hybrid_search(
  348. collection_names=form_data.collection_names,
  349. query=form_data.query,
  350. embeddings_function=app.state.EMBEDDING_FUNCTION,
  351. reranking_function=app.state.sentence_transformer_rf,
  352. k=form_data.k if form_data.k else app.state.TOP_K,
  353. r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
  354. )
  355. else:
  356. return query_collection(
  357. collection_names=form_data.collection_names,
  358. query=form_data.query,
  359. embeddings_function=app.state.EMBEDDING_FUNCTION,
  360. k=form_data.k if form_data.k else app.state.TOP_K,
  361. )
  362. except Exception as e:
  363. log.exception(e)
  364. raise HTTPException(
  365. status_code=status.HTTP_400_BAD_REQUEST,
  366. detail=ERROR_MESSAGES.DEFAULT(e),
  367. )
  368. @app.post("/web")
  369. def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
  370. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  371. try:
  372. loader = WebBaseLoader(form_data.url)
  373. data = loader.load()
  374. collection_name = form_data.collection_name
  375. if collection_name == "":
  376. collection_name = calculate_sha256_string(form_data.url)[:63]
  377. store_data_in_vector_db(data, collection_name, overwrite=True)
  378. return {
  379. "status": True,
  380. "collection_name": collection_name,
  381. "filename": form_data.url,
  382. }
  383. except Exception as e:
  384. log.exception(e)
  385. raise HTTPException(
  386. status_code=status.HTTP_400_BAD_REQUEST,
  387. detail=ERROR_MESSAGES.DEFAULT(e),
  388. )
  389. def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
  390. text_splitter = RecursiveCharacterTextSplitter(
  391. chunk_size=app.state.CHUNK_SIZE,
  392. chunk_overlap=app.state.CHUNK_OVERLAP,
  393. add_start_index=True,
  394. )
  395. docs = text_splitter.split_documents(data)
  396. if len(docs) > 0:
  397. log.info(f"store_data_in_vector_db {docs}")
  398. return store_docs_in_vector_db(docs, collection_name, overwrite), None
  399. else:
  400. raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
  401. def store_text_in_vector_db(
  402. text, metadata, collection_name, overwrite: bool = False
  403. ) -> bool:
  404. text_splitter = RecursiveCharacterTextSplitter(
  405. chunk_size=app.state.CHUNK_SIZE,
  406. chunk_overlap=app.state.CHUNK_OVERLAP,
  407. add_start_index=True,
  408. )
  409. docs = text_splitter.create_documents([text], metadatas=[metadata])
  410. return store_docs_in_vector_db(docs, collection_name, overwrite)
  411. def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
  412. log.info(f"store_docs_in_vector_db {docs} {collection_name}")
  413. texts = [doc.page_content for doc in docs]
  414. metadatas = [doc.metadata for doc in docs]
  415. try:
  416. if overwrite:
  417. for collection in CHROMA_CLIENT.list_collections():
  418. if collection_name == collection.name:
  419. log.info(f"deleting existing collection {collection_name}")
  420. CHROMA_CLIENT.delete_collection(name=collection_name)
  421. collection = CHROMA_CLIENT.create_collection(name=collection_name)
  422. embedding_func = get_embedding_function(
  423. app.state.RAG_EMBEDDING_ENGINE,
  424. app.state.RAG_EMBEDDING_MODEL,
  425. app.state.sentence_transformer_ef,
  426. app.state.OPENAI_API_KEY,
  427. app.state.OPENAI_API_BASE_URL,
  428. )
  429. embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
  430. embeddings = embedding_func(embedding_texts)
  431. for batch in create_batches(
  432. api=CHROMA_CLIENT,
  433. ids=[str(uuid.uuid1()) for _ in texts],
  434. metadatas=metadatas,
  435. embeddings=embeddings,
  436. documents=texts,
  437. ):
  438. collection.add(*batch)
  439. return True
  440. except Exception as e:
  441. log.exception(e)
  442. if e.__class__.__name__ == "UniqueConstraintError":
  443. return True
  444. return False
  445. def get_loader(filename: str, file_content_type: str, file_path: str):
  446. file_ext = filename.split(".")[-1].lower()
  447. known_type = True
  448. known_source_ext = [
  449. "go",
  450. "py",
  451. "java",
  452. "sh",
  453. "bat",
  454. "ps1",
  455. "cmd",
  456. "js",
  457. "ts",
  458. "css",
  459. "cpp",
  460. "hpp",
  461. "h",
  462. "c",
  463. "cs",
  464. "sql",
  465. "log",
  466. "ini",
  467. "pl",
  468. "pm",
  469. "r",
  470. "dart",
  471. "dockerfile",
  472. "env",
  473. "php",
  474. "hs",
  475. "hsc",
  476. "lua",
  477. "nginxconf",
  478. "conf",
  479. "m",
  480. "mm",
  481. "plsql",
  482. "perl",
  483. "rb",
  484. "rs",
  485. "db2",
  486. "scala",
  487. "bash",
  488. "swift",
  489. "vue",
  490. "svelte",
  491. ]
  492. if file_ext == "pdf":
  493. loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
  494. elif file_ext == "csv":
  495. loader = CSVLoader(file_path)
  496. elif file_ext == "rst":
  497. loader = UnstructuredRSTLoader(file_path, mode="elements")
  498. elif file_ext == "xml":
  499. loader = UnstructuredXMLLoader(file_path)
  500. elif file_ext in ["htm", "html"]:
  501. loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
  502. elif file_ext == "md":
  503. loader = UnstructuredMarkdownLoader(file_path)
  504. elif file_content_type == "application/epub+zip":
  505. loader = UnstructuredEPubLoader(file_path)
  506. elif (
  507. file_content_type
  508. == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
  509. or file_ext in ["doc", "docx"]
  510. ):
  511. loader = Docx2txtLoader(file_path)
  512. elif file_content_type in [
  513. "application/vnd.ms-excel",
  514. "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
  515. ] or file_ext in ["xls", "xlsx"]:
  516. loader = UnstructuredExcelLoader(file_path)
  517. elif file_ext in known_source_ext or (
  518. file_content_type and file_content_type.find("text/") >= 0
  519. ):
  520. loader = TextLoader(file_path, autodetect_encoding=True)
  521. else:
  522. loader = TextLoader(file_path, autodetect_encoding=True)
  523. known_type = False
  524. return loader, known_type
  525. @app.post("/doc")
  526. def store_doc(
  527. collection_name: Optional[str] = Form(None),
  528. file: UploadFile = File(...),
  529. user=Depends(get_current_user),
  530. ):
  531. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  532. log.info(f"file.content_type: {file.content_type}")
  533. try:
  534. unsanitized_filename = file.filename
  535. filename = os.path.basename(unsanitized_filename)
  536. file_path = f"{UPLOAD_DIR}/{filename}"
  537. contents = file.file.read()
  538. with open(file_path, "wb") as f:
  539. f.write(contents)
  540. f.close()
  541. f = open(file_path, "rb")
  542. if collection_name == None:
  543. collection_name = calculate_sha256(f)[:63]
  544. f.close()
  545. loader, known_type = get_loader(filename, file.content_type, file_path)
  546. data = loader.load()
  547. try:
  548. result = store_data_in_vector_db(data, collection_name)
  549. if result:
  550. return {
  551. "status": True,
  552. "collection_name": collection_name,
  553. "filename": filename,
  554. "known_type": known_type,
  555. }
  556. except Exception as e:
  557. raise HTTPException(
  558. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  559. detail=e,
  560. )
  561. except Exception as e:
  562. log.exception(e)
  563. if "No pandoc was found" in str(e):
  564. raise HTTPException(
  565. status_code=status.HTTP_400_BAD_REQUEST,
  566. detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
  567. )
  568. else:
  569. raise HTTPException(
  570. status_code=status.HTTP_400_BAD_REQUEST,
  571. detail=ERROR_MESSAGES.DEFAULT(e),
  572. )
  573. class TextRAGForm(BaseModel):
  574. name: str
  575. content: str
  576. collection_name: Optional[str] = None
  577. @app.post("/text")
  578. def store_text(
  579. form_data: TextRAGForm,
  580. user=Depends(get_current_user),
  581. ):
  582. collection_name = form_data.collection_name
  583. if collection_name == None:
  584. collection_name = calculate_sha256_string(form_data.content)
  585. result = store_text_in_vector_db(
  586. form_data.content,
  587. metadata={"name": form_data.name, "created_by": user.id},
  588. collection_name=collection_name,
  589. )
  590. if result:
  591. return {"status": True, "collection_name": collection_name}
  592. else:
  593. raise HTTPException(
  594. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  595. detail=ERROR_MESSAGES.DEFAULT(),
  596. )
  597. @app.get("/scan")
  598. def scan_docs_dir(user=Depends(get_admin_user)):
  599. for path in Path(DOCS_DIR).rglob("./**/*"):
  600. try:
  601. if path.is_file() and not path.name.startswith("."):
  602. tags = extract_folders_after_data_docs(path)
  603. filename = path.name
  604. file_content_type = mimetypes.guess_type(path)
  605. f = open(path, "rb")
  606. collection_name = calculate_sha256(f)[:63]
  607. f.close()
  608. loader, known_type = get_loader(
  609. filename, file_content_type[0], str(path)
  610. )
  611. data = loader.load()
  612. try:
  613. result = store_data_in_vector_db(data, collection_name)
  614. if result:
  615. sanitized_filename = sanitize_filename(filename)
  616. doc = Documents.get_doc_by_name(sanitized_filename)
  617. if doc == None:
  618. doc = Documents.insert_new_doc(
  619. user.id,
  620. DocumentForm(
  621. **{
  622. "name": sanitized_filename,
  623. "title": filename,
  624. "collection_name": collection_name,
  625. "filename": filename,
  626. "content": (
  627. json.dumps(
  628. {
  629. "tags": list(
  630. map(
  631. lambda name: {"name": name},
  632. tags,
  633. )
  634. )
  635. }
  636. )
  637. if len(tags)
  638. else "{}"
  639. ),
  640. }
  641. ),
  642. )
  643. except Exception as e:
  644. log.exception(e)
  645. pass
  646. except Exception as e:
  647. log.exception(e)
  648. return True
  649. @app.get("/reset/db")
  650. def reset_vector_db(user=Depends(get_admin_user)):
  651. CHROMA_CLIENT.reset()
  652. @app.get("/reset")
  653. def reset(user=Depends(get_admin_user)) -> bool:
  654. folder = f"{UPLOAD_DIR}"
  655. for filename in os.listdir(folder):
  656. file_path = os.path.join(folder, filename)
  657. try:
  658. if os.path.isfile(file_path) or os.path.islink(file_path):
  659. os.unlink(file_path)
  660. elif os.path.isdir(file_path):
  661. shutil.rmtree(file_path)
  662. except Exception as e:
  663. log.error("Failed to delete %s. Reason: %s" % (file_path, e))
  664. try:
  665. CHROMA_CLIENT.reset()
  666. except Exception as e:
  667. log.exception(e)
  668. return True