main.py 23 KB


  1. from fastapi import (
  2. FastAPI,
  3. Depends,
  4. HTTPException,
  5. status,
  6. UploadFile,
  7. File,
  8. Form,
  9. )
  10. from fastapi.middleware.cors import CORSMiddleware
  11. import os, shutil, logging, re
  12. from pathlib import Path
  13. from typing import List
  14. from chromadb.utils.batch_utils import create_batches
  15. from langchain_community.document_loaders import (
  16. WebBaseLoader,
  17. TextLoader,
  18. PyPDFLoader,
  19. CSVLoader,
  20. BSHTMLLoader,
  21. Docx2txtLoader,
  22. UnstructuredEPubLoader,
  23. UnstructuredWordDocumentLoader,
  24. UnstructuredMarkdownLoader,
  25. UnstructuredXMLLoader,
  26. UnstructuredRSTLoader,
  27. UnstructuredExcelLoader,
  28. )
  29. from langchain.text_splitter import RecursiveCharacterTextSplitter
  30. from pydantic import BaseModel
  31. from typing import Optional
  32. import mimetypes
  33. import uuid
  34. import json
  35. import sentence_transformers
  36. from apps.web.models.documents import (
  37. Documents,
  38. DocumentForm,
  39. DocumentResponse,
  40. )
  41. from apps.rag.utils import (
  42. get_model_path,
  43. query_embeddings_doc,
  44. query_embeddings_function,
  45. query_embeddings_collection,
  46. )
  47. from utils.misc import (
  48. calculate_sha256,
  49. calculate_sha256_string,
  50. sanitize_filename,
  51. extract_folders_after_data_docs,
  52. )
  53. from utils.utils import get_current_user, get_admin_user
  54. from config import (
  55. SRC_LOG_LEVELS,
  56. UPLOAD_DIR,
  57. DOCS_DIR,
  58. RAG_TOP_K,
  59. RAG_RELEVANCE_THRESHOLD,
  60. RAG_EMBEDDING_ENGINE,
  61. RAG_EMBEDDING_MODEL,
  62. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  63. RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  64. RAG_HYBRID,
  65. RAG_RERANKING_MODEL,
  66. RAG_RERANKING_MODEL_AUTO_UPDATE,
  67. RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  68. RAG_OPENAI_API_BASE_URL,
  69. RAG_OPENAI_API_KEY,
  70. DEVICE_TYPE,
  71. CHROMA_CLIENT,
  72. CHUNK_SIZE,
  73. CHUNK_OVERLAP,
  74. RAG_TEMPLATE,
  75. )
  76. from constants import ERROR_MESSAGES
  # Module logger; its level comes from the "RAG" entry of SRC_LOG_LEVELS.
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["RAG"])

  app = FastAPI()

  # Runtime settings are kept on app.state (seeded from the config module) so
  # that the update endpoints in this file can change them while running.

  # Embedding / reranking backends.
  app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
  app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
  app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
  app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
  app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

  # Document ingestion.
  app.state.CHUNK_SIZE = CHUNK_SIZE
  app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
  app.state.PDF_EXTRACT_IMAGES = False

  # Query defaults.
  app.state.TOP_K = RAG_TOP_K
  app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
  app.state.HYBRID = RAG_HYBRID
  app.state.RAG_TEMPLATE = RAG_TEMPLATE
  def update_embedding_model(
      embedding_model: str,
      update_model: bool = False,
  ):
      """Load or clear the local SentenceTransformer used for embeddings.

      A model is loaded only when ``embedding_model`` is non-empty and
      ``app.state.RAG_EMBEDDING_ENGINE`` is the empty string; in every other
      case ``app.state.sentence_transformer_ef`` is set to ``None``.

      Args:
          embedding_model: Model identifier handed to ``get_model_path``.
          update_model: Forwarded to ``get_model_path`` as its second argument.
      """
      use_local_model = bool(embedding_model) and app.state.RAG_EMBEDDING_ENGINE == ""
      if not use_local_model:
          app.state.sentence_transformer_ef = None
          return

      model_path = get_model_path(embedding_model, update_model)
      app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
          model_path,
          device=DEVICE_TYPE,
          trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
      )
  def update_reranking_model(
      reranking_model: str,
      update_model: bool = False,
  ):
      """Load or clear the CrossEncoder used for reranking.

      When ``reranking_model`` is empty, ``app.state.sentence_transformer_rf``
      is set to ``None``; otherwise a ``CrossEncoder`` is built from the path
      returned by ``get_model_path``.

      Args:
          reranking_model: Model identifier handed to ``get_model_path``.
          update_model: Forwarded to ``get_model_path`` as its second argument.
      """
      if not reranking_model:
          app.state.sentence_transformer_rf = None
          return

      model_path = get_model_path(reranking_model, update_model)
      app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
          model_path,
          device=DEVICE_TYPE,
          trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
      )
  # Populate app.state.sentence_transformer_ef / _rf once at import time from
  # the configured model names.
  update_embedding_model(
      embedding_model=app.state.RAG_EMBEDDING_MODEL,
      update_model=RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  )
  update_reranking_model(
      reranking_model=app.state.RAG_RERANKING_MODEL,
      update_model=RAG_RERANKING_MODEL_AUTO_UPDATE,
  )

  # NOTE(review): every origin, method and header is allowed and credentials
  # are enabled -- confirm this wide-open CORS policy is intended.
  origins = ["*"]

  app.add_middleware(
      CORSMiddleware,
      allow_origins=origins,
      allow_methods=["*"],
      allow_headers=["*"],
      allow_credentials=True,
  )
  class CollectionNameForm(BaseModel):
      """Request body that names a target collection."""

      # Defaults to "test" when the client omits the field.
      collection_name: Optional[str] = "test"
  class StoreWebForm(CollectionNameForm):
      """Request body for ``POST /web``: a collection name plus the page URL."""

      # URL handed to WebBaseLoader by store_web.
      url: str
  @app.get("/")
  async def get_status():
      """Return a liveness flag with the current chunking, template and model settings."""
      state = app.state
      payload = {
          "status": True,
          "chunk_size": state.CHUNK_SIZE,
          "chunk_overlap": state.CHUNK_OVERLAP,
          "template": state.RAG_TEMPLATE,
          "embedding_engine": state.RAG_EMBEDDING_ENGINE,
          "embedding_model": state.RAG_EMBEDDING_MODEL,
          "reranking_model": state.RAG_RERANKING_MODEL,
      }
      return payload
  @app.get("/embedding")
  async def get_embedding_config(user=Depends(get_admin_user)):
      """Return the embedding engine/model and the stored OpenAI URL and key.

      Access goes through the ``get_admin_user`` dependency.
      """
      state = app.state
      openai_config = {
          "url": state.OPENAI_API_BASE_URL,
          "key": state.OPENAI_API_KEY,
      }
      return {
          "status": True,
          "embedding_engine": state.RAG_EMBEDDING_ENGINE,
          "embedding_model": state.RAG_EMBEDDING_MODEL,
          "openai_config": openai_config,
      }
  @app.get("/reranking")
  async def get_reraanking_config(user=Depends(get_admin_user)):
      """Return the currently configured reranking model name."""
      current_model = app.state.RAG_RERANKING_MODEL
      return {"status": True, "reranking_model": current_model}
  class OpenAIConfigForm(BaseModel):
      """Base URL and key pair stored on ``app.state`` by ``/embedding/update``."""

      # Stored as app.state.OPENAI_API_BASE_URL.
      url: str
      # Stored as app.state.OPENAI_API_KEY.
      key: str
  class EmbeddingModelUpdateForm(BaseModel):
      """Request body for ``POST /embedding/update``."""

      # Only applied when the engine is "ollama" or "openai".
      openai_config: Optional[OpenAIConfigForm] = None
      # "" selects the local SentenceTransformer path in update_embedding_model.
      embedding_engine: str
      embedding_model: str
  @app.post("/embedding/update")
  async def update_embedding_config(
      form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
  ):
      """Switch the embedding engine/model and reload the local model if needed.

      The new engine and model are written to ``app.state`` first; OpenAI
      settings are overwritten only for the "ollama"/"openai" engines and only
      when the form carries them. ``update_embedding_model`` is then called
      with ``update_model=True``.

      Raises:
          HTTPException: 500 if any step raises.
      """
      state = app.state
      log.info(
          f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
      )

      try:
          state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
          state.RAG_EMBEDDING_MODEL = form_data.embedding_model

          uses_remote_engine = state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]
          if uses_remote_engine and form_data.openai_config is not None:
              state.OPENAI_API_BASE_URL = form_data.openai_config.url
              state.OPENAI_API_KEY = form_data.openai_config.key

          # Runs for every engine: it clears the local model when the engine
          # is not "".
          update_embedding_model(state.RAG_EMBEDDING_MODEL, True)

          openai_config = {
              "url": state.OPENAI_API_BASE_URL,
              "key": state.OPENAI_API_KEY,
          }
          return {
              "status": True,
              "embedding_engine": state.RAG_EMBEDDING_ENGINE,
              "embedding_model": state.RAG_EMBEDDING_MODEL,
              "openai_config": openai_config,
          }
      except Exception as e:
          log.exception(f"Problem updating embedding model: {e}")
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  class RerankingModelUpdateForm(BaseModel):
      """Request body for ``POST /reranking/update``."""

      # An empty string makes update_reranking_model clear the reranker.
      reranking_model: str
  @app.post("/reranking/update")
  async def update_reranking_config(
      form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
  ):
      """Store the new reranking model name and reload the reranker.

      Raises:
          HTTPException: 500 if storing or loading the model raises.
      """
      log.info(
          f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
      )

      try:
          new_model = form_data.reranking_model
          app.state.RAG_RERANKING_MODEL = new_model
          update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

          response = {
              "status": True,
              "reranking_model": app.state.RAG_RERANKING_MODEL,
          }
          return response
      except Exception as e:
          log.exception(f"Problem updating reranking model: {e}")
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  @app.get("/config")
  async def get_rag_config(user=Depends(get_admin_user)):
      """Return the PDF image-extraction flag and the chunking parameters."""
      state = app.state
      chunk = {
          "chunk_size": state.CHUNK_SIZE,
          "chunk_overlap": state.CHUNK_OVERLAP,
      }
      return {
          "status": True,
          "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
          "chunk": chunk,
      }
  class ChunkParamUpdateForm(BaseModel):
      """Chunking parameters accepted by ``POST /config/update``."""

      # Stored as app.state.CHUNK_SIZE.
      chunk_size: int
      # Stored as app.state.CHUNK_OVERLAP.
      chunk_overlap: int
  class ConfigUpdateForm(BaseModel):
      """Request body for ``POST /config/update``."""

      # Stored as app.state.PDF_EXTRACT_IMAGES (read by get_loader for PDFs).
      pdf_extract_images: bool
      chunk: ChunkParamUpdateForm
  @app.post("/config/update")
  async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
      """Overwrite the PDF image flag and chunking parameters, then echo them back."""
      state = app.state
      chunk_params = form_data.chunk

      state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
      state.CHUNK_SIZE = chunk_params.chunk_size
      state.CHUNK_OVERLAP = chunk_params.chunk_overlap

      chunk = {
          "chunk_size": state.CHUNK_SIZE,
          "chunk_overlap": state.CHUNK_OVERLAP,
      }
      return {
          "status": True,
          "pdf_extract_images": state.PDF_EXTRACT_IMAGES,
          "chunk": chunk,
      }
  @app.get("/template")
  async def get_rag_template(user=Depends(get_current_user)):
      """Return the current RAG prompt template."""
      template = app.state.RAG_TEMPLATE
      return {"status": True, "template": template}
  @app.get("/query/settings")
  async def get_query_settings(user=Depends(get_admin_user)):
      """Return the template plus the default ``k``, ``r`` and ``hybrid`` query settings."""
      state = app.state
      settings = {
          "status": True,
          "template": state.RAG_TEMPLATE,
          "k": state.TOP_K,
          "r": state.RELEVANCE_THRESHOLD,
          "hybrid": state.HYBRID,
      }
      return settings
  class QuerySettingsForm(BaseModel):
      """Request body for ``POST /query/settings/update``; every field is optional."""

      # Stored as app.state.TOP_K.
      k: Optional[int] = None
      # Stored as app.state.RELEVANCE_THRESHOLD.
      r: Optional[float] = None
      # Stored as app.state.RAG_TEMPLATE.
      template: Optional[str] = None
      # Stored as app.state.HYBRID.
      hybrid: Optional[bool] = None
  @app.post("/query/settings/update")
  async def update_query_settings(
      form_data: QuerySettingsForm, user=Depends(get_admin_user)
  ):
      """Replace the query defaults, resetting any missing or falsy field.

      Fallbacks: template -> ``RAG_TEMPLATE``, k -> 4, r -> 0.0, hybrid -> False.
      """
      state = app.state

      state.RAG_TEMPLATE = form_data.template or RAG_TEMPLATE
      # NOTE(review): the fallback is the literal 4, not RAG_TOP_K -- confirm.
      state.TOP_K = form_data.k or 4
      state.RELEVANCE_THRESHOLD = form_data.r or 0.0
      state.HYBRID = form_data.hybrid or False

      settings = {
          "status": True,
          "template": state.RAG_TEMPLATE,
          "k": state.TOP_K,
          "r": state.RELEVANCE_THRESHOLD,
          "hybrid": state.HYBRID,
      }
      return settings
  class QueryDocForm(BaseModel):
      """Request body for ``POST /query/doc`` (a single collection)."""

      collection_name: str
      query: str
      # Optional per-request overrides; the handler falls back to app.state.
      k: Optional[int] = None
      r: Optional[float] = None
      hybrid: Optional[bool] = None
  @app.post("/query/doc")
  def query_doc_handler(
      form_data: QueryDocForm,
      user=Depends(get_current_user),
  ):
      """Query one collection, using per-request overrides where given.

      ``r`` and ``hybrid`` fall back to the ``app.state`` defaults only when
      they are omitted (``None``), so an explicit ``0.0`` or ``False`` is
      honoured. ``k`` still falls back when it is missing or 0.

      Raises:
          HTTPException: 400 if building the embedding function or running
              the query raises.
      """
      state = app.state
      try:
          embeddings_function = query_embeddings_function(
              state.RAG_EMBEDDING_ENGINE,
              state.RAG_EMBEDDING_MODEL,
              state.sentence_transformer_ef,
              state.OPENAI_API_KEY,
              state.OPENAI_API_BASE_URL,
          )

          top_k = form_data.k if form_data.k else state.TOP_K
          # `is not None` rather than truthiness: 0.0 and False are valid
          # overrides and must not be replaced by the global defaults.
          threshold = (
              form_data.r if form_data.r is not None else state.RELEVANCE_THRESHOLD
          )
          hybrid = form_data.hybrid if form_data.hybrid is not None else state.HYBRID

          return query_embeddings_doc(
              collection_name=form_data.collection_name,
              query=form_data.query,
              k=top_k,
              r=threshold,
              embeddings_function=embeddings_function,
              reranking_function=state.sentence_transformer_rf,
              hybrid=hybrid,
          )
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  class QueryCollectionsForm(BaseModel):
      """Request body for ``POST /query/collection`` (several collections)."""

      collection_names: List[str]
      query: str
      # Optional per-request overrides; the handler falls back to app.state.
      k: Optional[int] = None
      r: Optional[float] = None
      hybrid: Optional[bool] = None
  @app.post("/query/collection")
  def query_collection_handler(
      form_data: QueryCollectionsForm,
      user=Depends(get_current_user),
  ):
      """Query several collections, using per-request overrides where given.

      ``r`` and ``hybrid`` fall back to the ``app.state`` defaults only when
      they are omitted (``None``), so an explicit ``0.0`` or ``False`` is
      honoured. ``k`` still falls back when it is missing or 0.

      Raises:
          HTTPException: 400 if building the embedding function or running
              the query raises.
      """
      state = app.state
      try:
          embeddings_function = query_embeddings_function(
              state.RAG_EMBEDDING_ENGINE,
              state.RAG_EMBEDDING_MODEL,
              state.sentence_transformer_ef,
              state.OPENAI_API_KEY,
              state.OPENAI_API_BASE_URL,
          )

          top_k = form_data.k if form_data.k else state.TOP_K
          # `is not None` rather than truthiness: 0.0 and False are valid
          # overrides and must not be replaced by the global defaults.
          threshold = (
              form_data.r if form_data.r is not None else state.RELEVANCE_THRESHOLD
          )
          hybrid = form_data.hybrid if form_data.hybrid is not None else state.HYBRID

          return query_embeddings_collection(
              collection_names=form_data.collection_names,
              query=form_data.query,
              k=top_k,
              r=threshold,
              embeddings_function=embeddings_function,
              reranking_function=state.sentence_transformer_rf,
              hybrid=hybrid,
          )
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  @app.post("/web")
  def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
      """Fetch a web page and store it, replacing any same-named collection.

      When the supplied collection name is the empty string, the name becomes
      the first 63 characters of ``calculate_sha256_string(url)``.

      Raises:
          HTTPException: 400 if loading or storing raises.
      """
      # Example URL: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
      url = form_data.url
      try:
          data = WebBaseLoader(url).load()

          collection_name = form_data.collection_name
          if collection_name == "":
              collection_name = calculate_sha256_string(url)[:63]

          # NOTE(review): the return value is not inspected, so a failed store
          # still yields "status": True -- confirm that is intended.
          store_data_in_vector_db(data, collection_name, overwrite=True)

          response = {
              "status": True,
              "collection_name": collection_name,
              "filename": url,
          }
          return response
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
      """Split loaded documents into chunks and store them in a collection.

      Chunk size and overlap are read from ``app.state``.

      Args:
          data: Documents passed to ``split_documents``.
          collection_name: Target collection name.
          overwrite: Forwarded to ``store_docs_in_vector_db``.

      Returns:
          The boolean result of ``store_docs_in_vector_db``. This used to be
          the tuple ``(result, None)``, which is always truthy and hid failed
          stores from callers that test ``if result:``.

      Raises:
          ValueError: If splitting produces no chunks.
      """
      text_splitter = RecursiveCharacterTextSplitter(
          chunk_size=app.state.CHUNK_SIZE,
          chunk_overlap=app.state.CHUNK_OVERLAP,
          add_start_index=True,
      )
      docs = text_splitter.split_documents(data)

      if not docs:
          raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

      log.info(f"store_data_in_vector_db {docs}")
      return store_docs_in_vector_db(docs, collection_name, overwrite)
  def store_text_in_vector_db(
      text, metadata, collection_name, overwrite: bool = False
  ) -> bool:
      """Chunk a raw text string and store the chunks in a collection.

      Args:
          text: Text to split.
          metadata: Metadata passed to the splitter alongside ``text``.
          collection_name: Target collection name.
          overwrite: Forwarded to ``store_docs_in_vector_db``.

      Returns:
          The result of ``store_docs_in_vector_db``.
      """
      splitter_options = {
          "chunk_size": app.state.CHUNK_SIZE,
          "chunk_overlap": app.state.CHUNK_OVERLAP,
          "add_start_index": True,
      }
      splitter = RecursiveCharacterTextSplitter(**splitter_options)
      chunks = splitter.create_documents([text], metadatas=[metadata])
      return store_docs_in_vector_db(chunks, collection_name, overwrite)
  def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
      """Embed document chunks and add them to a newly created collection.

      Args:
          docs: Objects exposing ``page_content`` and ``metadata``.
          collection_name: Name of the collection to create.
          overwrite: When true, an existing collection with the same name is
              deleted before the new one is created.

      Returns:
          True on success, and also when the raised exception's class is named
          "UniqueConstraintError"; False for any other exception.
      """
      log.info(f"store_docs_in_vector_db {docs} {collection_name}")

      texts = [doc.page_content for doc in docs]
      metadatas = [doc.metadata for doc in docs]

      try:
          if overwrite:
              for existing in CHROMA_CLIENT.list_collections():
                  if existing.name != collection_name:
                      continue
                  log.info(f"deleting existing collection {collection_name}")
                  CHROMA_CLIENT.delete_collection(name=collection_name)

          collection = CHROMA_CLIENT.create_collection(name=collection_name)

          embedding_func = query_embeddings_function(
              app.state.RAG_EMBEDDING_ENGINE,
              app.state.RAG_EMBEDDING_MODEL,
              app.state.sentence_transformer_ef,
              app.state.OPENAI_API_KEY,
              app.state.OPENAI_API_BASE_URL,
          )

          # Newlines are flattened only for embedding; the stored documents
          # keep the original text.
          flattened_texts = [text.replace("\n", " ") for text in texts]
          embeddings = embedding_func(flattened_texts)
          ids = [str(uuid.uuid1()) for _ in texts]

          batches = create_batches(
              api=CHROMA_CLIENT,
              ids=ids,
              metadatas=metadatas,
              embeddings=embeddings,
              documents=texts,
          )
          for batch in batches:
              collection.add(*batch)

          return True
      except Exception as e:
          log.exception(e)
          # Matched by class name, so the exception type need not be imported.
          return e.__class__.__name__ == "UniqueConstraintError"
  def get_loader(filename: str, file_content_type: str, file_path: str):
      """Pick a document loader from the file extension and content type.

      Branches are tried in order: extension checks first, then content-type
      checks, then a plain-text fallback.

      Args:
          filename: Name whose text after the last "." is the extension.
          file_content_type: MIME type; may be falsy.
          file_path: Path handed to the chosen loader.

      Returns:
          A ``(loader, known_type)`` tuple. ``known_type`` is False only when
          the file reached the ``TextLoader`` fallback without a recognised
          source extension or a "text/" content type.
      """
      file_ext = filename.split(".")[-1].lower()
      known_source_ext = [
          "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
          "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm",
          "r", "dart", "dockerfile", "env", "php", "hs", "hsc", "lua",
          "nginxconf", "conf", "m", "mm", "plsql", "perl", "rb", "rs",
          "db2", "scala", "bash", "swift", "vue", "svelte",
      ]
      docx_content_type = (
          "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
      )
      excel_content_types = [
          "application/vnd.ms-excel",
          "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
      ]

      known_type = True
      if file_ext == "pdf":
          loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
      elif file_ext == "csv":
          loader = CSVLoader(file_path)
      elif file_ext == "rst":
          loader = UnstructuredRSTLoader(file_path, mode="elements")
      elif file_ext == "xml":
          loader = UnstructuredXMLLoader(file_path)
      elif file_ext in ["htm", "html"]:
          loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
      elif file_ext == "md":
          loader = UnstructuredMarkdownLoader(file_path)
      elif file_content_type == "application/epub+zip":
          loader = UnstructuredEPubLoader(file_path)
      elif file_content_type == docx_content_type or file_ext in ["doc", "docx"]:
          loader = Docx2txtLoader(file_path)
      elif file_content_type in excel_content_types or file_ext in ["xls", "xlsx"]:
          loader = UnstructuredExcelLoader(file_path)
      else:
          # Everything else is read as text; known_type records whether the
          # extension or content type actually identified it as text.
          has_text_content_type = bool(
              file_content_type and "text/" in file_content_type
          )
          known_type = file_ext in known_source_ext or has_text_content_type
          loader = TextLoader(file_path, autodetect_encoding=True)

      return loader, known_type
  @app.post("/doc")
  def store_doc(
      collection_name: Optional[str] = Form(None),
      file: UploadFile = File(...),
      user=Depends(get_current_user),
  ):
      """Save an uploaded file under UPLOAD_DIR and index it in the vector store.

      When no collection name is supplied, the name becomes the first 63
      characters of ``calculate_sha256`` over the saved file.

      Returns:
          A dict with ``status``, ``collection_name``, ``filename`` and
          ``known_type`` (whether ``get_loader`` recognised the file type).

      Raises:
          HTTPException: 400 when saving or loading the file fails (with a
              dedicated message when pandoc is missing); 500 when storing in
              the vector store raises or reports failure.
      """
      log.info(f"file.content_type: {file.content_type}")
      try:
          unsanitized_filename = file.filename
          # Keep only the final path component of the client-supplied name.
          filename = os.path.basename(unsanitized_filename)
          file_path = f"{UPLOAD_DIR}/{filename}"

          contents = file.file.read()
          with open(file_path, "wb") as f:
              f.write(contents)

          if collection_name is None:
              with open(file_path, "rb") as f:
                  collection_name = calculate_sha256(f)[:63]

          loader, known_type = get_loader(filename, file.content_type, file_path)
          data = loader.load()

          try:
              result = store_data_in_vector_db(data, collection_name)
          except Exception as e:
              # detail must be serialisable, so pass the message rather than
              # the exception object.
              raise HTTPException(
                  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                  detail=ERROR_MESSAGES.DEFAULT(e),
              )

          if not result:
              raise HTTPException(
                  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                  detail=ERROR_MESSAGES.DEFAULT(),
              )

          return {
              "status": True,
              "collection_name": collection_name,
              "filename": filename,
              "known_type": known_type,
          }
      except HTTPException:
          # Without this clause the generic handler below would re-wrap the
          # 500 responses raised above as 400.
          raise
      except Exception as e:
          log.exception(e)
          if "No pandoc was found" in str(e):
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
              )
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  class TextRAGForm(BaseModel):
      """Request body for ``POST /text``."""

      # Stored in the chunk metadata under "name".
      name: str
      # Raw text to chunk and embed.
      content: str
      # When omitted, store_text derives a name from a hash of `content`.
      collection_name: Optional[str] = None
  @app.post("/text")
  def store_text(
      form_data: TextRAGForm,
      user=Depends(get_current_user),
  ):
      """Chunk and store raw text, tagging it with its name and creator.

      When no collection name is supplied, the name becomes the first 63
      characters of ``calculate_sha256_string(content)``.

      Returns:
          A dict with ``status`` and the ``collection_name`` used.

      Raises:
          HTTPException: 500 when ``store_text_in_vector_db`` reports failure.
      """
      collection_name = form_data.collection_name
      if collection_name is None:
          # Cut to 63 characters like the other hash-derived names in this
          # file; presumably the vector store caps collection-name length.
          collection_name = calculate_sha256_string(form_data.content)[:63]

      result = store_text_in_vector_db(
          form_data.content,
          metadata={"name": form_data.name, "created_by": user.id},
          collection_name=collection_name,
      )
      if not result:
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )

      return {"status": True, "collection_name": collection_name}
  @app.get("/scan")
  def scan_docs_dir(user=Depends(get_admin_user)):
      """Index every non-hidden file under DOCS_DIR and register new documents.

      Each file is hashed to get its collection name, loaded, stored in the
      vector store and, when no document with the sanitized filename exists
      yet, inserted through ``Documents.insert_new_doc``. Failures are logged
      per file and the scan moves on.

      Returns:
          Always True.
      """
      for path in Path(DOCS_DIR).rglob("./**/*"):
          try:
              if not path.is_file() or path.name.startswith("."):
                  continue

              tags = extract_folders_after_data_docs(path)
              filename = path.name
              content_type, _ = mimetypes.guess_type(path)

              # Context manager so the handle is closed even if hashing raises.
              with open(path, "rb") as f:
                  collection_name = calculate_sha256(f)[:63]

              loader, _known_type = get_loader(filename, content_type, str(path))
              data = loader.load()

              try:
                  result = store_data_in_vector_db(data, collection_name)
                  if not result:
                      continue

                  sanitized_filename = sanitize_filename(filename)
                  if Documents.get_doc_by_name(sanitized_filename) is None:
                      content = (
                          json.dumps({"tags": [{"name": name} for name in tags]})
                          if tags
                          else "{}"
                      )
                      Documents.insert_new_doc(
                          user.id,
                          DocumentForm(
                              **{
                                  "name": sanitized_filename,
                                  "title": filename,
                                  "collection_name": collection_name,
                                  "filename": filename,
                                  "content": content,
                              }
                          ),
                      )
              except Exception as e:
                  # Best effort: one file failing to store or register must
                  # not abort the scan.
                  log.exception(e)
          except Exception as e:
              log.exception(e)

      return True
  @app.get("/reset/db")
  def reset_vector_db(user=Depends(get_admin_user)):
      """Call ``CHROMA_CLIENT.reset()``; the result is discarded."""
      CHROMA_CLIENT.reset()
  @app.get("/reset")
  def reset(user=Depends(get_admin_user)) -> bool:
      """Delete everything inside UPLOAD_DIR and reset the Chroma client.

      Deletion failures are logged per entry, and a failing client reset is
      logged too; neither is propagated.

      Returns:
          Always True.
      """
      folder = f"{UPLOAD_DIR}"

      for entry in os.listdir(folder):
          file_path = os.path.join(folder, entry)
          try:
              # Files and symlinks are unlinked; real directories are removed
              # recursively.
              is_file_or_link = os.path.isfile(file_path) or os.path.islink(file_path)
              if is_file_or_link:
                  os.unlink(file_path)
              elif os.path.isdir(file_path):
                  shutil.rmtree(file_path)
          except Exception as e:
              log.error("Failed to delete %s. Reason: %s" % (file_path, e))

      try:
          CHROMA_CLIENT.reset()
      except Exception as e:
          log.exception(e)

      return True