main.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102
  1. from fastapi import (
  2. FastAPI,
  3. Depends,
  4. HTTPException,
  5. status,
  6. UploadFile,
  7. File,
  8. Form,
  9. )
  10. from fastapi.middleware.cors import CORSMiddleware
  11. import os, shutil, logging, re
  12. from pathlib import Path
  13. from typing import List, Union, Sequence
  14. from chromadb.utils.batch_utils import create_batches
  15. from langchain_community.document_loaders import (
  16. WebBaseLoader,
  17. TextLoader,
  18. PyPDFLoader,
  19. CSVLoader,
  20. BSHTMLLoader,
  21. Docx2txtLoader,
  22. UnstructuredEPubLoader,
  23. UnstructuredWordDocumentLoader,
  24. UnstructuredMarkdownLoader,
  25. UnstructuredXMLLoader,
  26. UnstructuredRSTLoader,
  27. UnstructuredExcelLoader,
  28. UnstructuredPowerPointLoader,
  29. YoutubeLoader,
  30. )
  31. from langchain.text_splitter import RecursiveCharacterTextSplitter
  32. import validators
  33. import urllib.parse
  34. import socket
  35. from pydantic import BaseModel
  36. from typing import Optional
  37. import mimetypes
  38. import uuid
  39. import json
  40. import sentence_transformers
  41. from apps.webui.models.documents import (
  42. Documents,
  43. DocumentForm,
  44. DocumentResponse,
  45. )
  46. from apps.rag.utils import (
  47. get_model_path,
  48. get_embedding_function,
  49. query_doc,
  50. query_doc_with_hybrid_search,
  51. query_collection,
  52. query_collection_with_hybrid_search,
  53. search_web,
  54. )
  55. from utils.misc import (
  56. calculate_sha256,
  57. calculate_sha256_string,
  58. sanitize_filename,
  59. extract_folders_after_data_docs,
  60. )
  61. from utils.utils import get_current_user, get_admin_user
  62. from config import (
  63. ENV,
  64. SRC_LOG_LEVELS,
  65. UPLOAD_DIR,
  66. DOCS_DIR,
  67. RAG_TOP_K,
  68. RAG_RELEVANCE_THRESHOLD,
  69. RAG_EMBEDDING_ENGINE,
  70. RAG_EMBEDDING_MODEL,
  71. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  72. RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  73. ENABLE_RAG_HYBRID_SEARCH,
  74. ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  75. RAG_RERANKING_MODEL,
  76. PDF_EXTRACT_IMAGES,
  77. RAG_RERANKING_MODEL_AUTO_UPDATE,
  78. RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  79. RAG_OPENAI_API_BASE_URL,
  80. RAG_OPENAI_API_KEY,
  81. DEVICE_TYPE,
  82. CHROMA_CLIENT,
  83. CHUNK_SIZE,
  84. CHUNK_OVERLAP,
  85. RAG_TEMPLATE,
  86. ENABLE_RAG_LOCAL_WEB_FETCH,
  87. YOUTUBE_LOADER_LANGUAGE,
  88. ENABLE_RAG_WEB_SEARCH,
  89. RAG_WEB_SEARCH_ENGINE,
  90. SEARXNG_QUERY_URL,
  91. GOOGLE_PSE_API_KEY,
  92. GOOGLE_PSE_ENGINE_ID,
  93. BRAVE_SEARCH_API_KEY,
  94. SERPSTACK_API_KEY,
  95. SERPSTACK_HTTPS,
  96. SERPER_API_KEY,
  97. RAG_WEB_SEARCH_RESULT_COUNT,
  98. RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  99. AppConfig,
  100. )
  101. from constants import ERROR_MESSAGES
  102. log = logging.getLogger(__name__)
  103. log.setLevel(SRC_LOG_LEVELS["RAG"])
  104. app = FastAPI()
  105. app.state.config = AppConfig()
  106. app.state.config.TOP_K = RAG_TOP_K
  107. app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
  108. app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
  109. app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
  110. ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
  111. )
  112. app.state.config.CHUNK_SIZE = CHUNK_SIZE
  113. app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
  114. app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
  115. app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
  116. app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
  117. app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
  118. app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
  119. app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
  120. app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
  121. app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
  122. app.state.YOUTUBE_LOADER_TRANSLATION = None
  123. app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
  124. app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
  125. app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
  126. app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
  127. app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
  128. app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
  129. app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
  130. app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
  131. app.state.config.SERPER_API_KEY = SERPER_API_KEY
  132. app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
  133. app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
  134. def update_embedding_model(
  135. embedding_model: str,
  136. update_model: bool = False,
  137. ):
  138. if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
  139. app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
  140. get_model_path(embedding_model, update_model),
  141. device=DEVICE_TYPE,
  142. trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  143. )
  144. else:
  145. app.state.sentence_transformer_ef = None
  146. def update_reranking_model(
  147. reranking_model: str,
  148. update_model: bool = False,
  149. ):
  150. if reranking_model:
  151. app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
  152. get_model_path(reranking_model, update_model),
  153. device=DEVICE_TYPE,
  154. trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  155. )
  156. else:
  157. app.state.sentence_transformer_rf = None
  158. update_embedding_model(
  159. app.state.config.RAG_EMBEDDING_MODEL,
  160. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  161. )
  162. update_reranking_model(
  163. app.state.config.RAG_RERANKING_MODEL,
  164. RAG_RERANKING_MODEL_AUTO_UPDATE,
  165. )
  166. app.state.EMBEDDING_FUNCTION = get_embedding_function(
  167. app.state.config.RAG_EMBEDDING_ENGINE,
  168. app.state.config.RAG_EMBEDDING_MODEL,
  169. app.state.sentence_transformer_ef,
  170. app.state.config.OPENAI_API_KEY,
  171. app.state.config.OPENAI_API_BASE_URL,
  172. )
  173. origins = ["*"]
  174. app.add_middleware(
  175. CORSMiddleware,
  176. allow_origins=origins,
  177. allow_credentials=True,
  178. allow_methods=["*"],
  179. allow_headers=["*"],
  180. )
class CollectionNameForm(BaseModel):
    """Request body carrying the target vector collection name."""

    # Defaults to "test" when omitted. The URL/search handlers derive a
    # hash-based name when this is "".
    collection_name: Optional[str] = "test"


class UrlForm(CollectionNameForm):
    """Request body for ingesting a single URL (web page or YouTube video)."""

    url: str


class SearchForm(CollectionNameForm):
    """Request body for running a web search and ingesting the result pages."""

    query: str
  187. @app.get("/")
  188. async def get_status():
  189. return {
  190. "status": True,
  191. "chunk_size": app.state.config.CHUNK_SIZE,
  192. "chunk_overlap": app.state.config.CHUNK_OVERLAP,
  193. "template": app.state.config.RAG_TEMPLATE,
  194. "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
  195. "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
  196. "reranking_model": app.state.config.RAG_RERANKING_MODEL,
  197. }
  198. @app.get("/embedding")
  199. async def get_embedding_config(user=Depends(get_admin_user)):
  200. return {
  201. "status": True,
  202. "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
  203. "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
  204. "openai_config": {
  205. "url": app.state.config.OPENAI_API_BASE_URL,
  206. "key": app.state.config.OPENAI_API_KEY,
  207. },
  208. }
  209. @app.get("/reranking")
  210. async def get_reraanking_config(user=Depends(get_admin_user)):
  211. return {
  212. "status": True,
  213. "reranking_model": app.state.config.RAG_RERANKING_MODEL,
  214. }
class OpenAIConfigForm(BaseModel):
    """Connection settings for an OpenAI-compatible embedding endpoint."""

    # Stored as app.state.config.OPENAI_API_BASE_URL.
    url: str
    # Stored as app.state.config.OPENAI_API_KEY.
    key: str


class EmbeddingModelUpdateForm(BaseModel):
    """Request body for ``POST /embedding/update``."""

    # Only applied when embedding_engine is "ollama" or "openai".
    openai_config: Optional[OpenAIConfigForm] = None
    embedding_engine: str
    embedding_model: str
  222. @app.post("/embedding/update")
  223. async def update_embedding_config(
  224. form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
  225. ):
  226. log.info(
  227. f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
  228. )
  229. try:
  230. app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
  231. app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
  232. if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
  233. if form_data.openai_config != None:
  234. app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
  235. app.state.config.OPENAI_API_KEY = form_data.openai_config.key
  236. update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
  237. app.state.EMBEDDING_FUNCTION = get_embedding_function(
  238. app.state.config.RAG_EMBEDDING_ENGINE,
  239. app.state.config.RAG_EMBEDDING_MODEL,
  240. app.state.sentence_transformer_ef,
  241. app.state.config.OPENAI_API_KEY,
  242. app.state.config.OPENAI_API_BASE_URL,
  243. )
  244. return {
  245. "status": True,
  246. "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
  247. "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
  248. "openai_config": {
  249. "url": app.state.config.OPENAI_API_BASE_URL,
  250. "key": app.state.config.OPENAI_API_KEY,
  251. },
  252. }
  253. except Exception as e:
  254. log.exception(f"Problem updating embedding model: {e}")
  255. raise HTTPException(
  256. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  257. detail=ERROR_MESSAGES.DEFAULT(e),
  258. )
class RerankingModelUpdateForm(BaseModel):
    """Request body for ``POST /reranking/update``."""

    # An empty string clears the reranking model (update_reranking_model sets it to None).
    reranking_model: str
  261. @app.post("/reranking/update")
  262. async def update_reranking_config(
  263. form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
  264. ):
  265. log.info(
  266. f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
  267. )
  268. try:
  269. app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
  270. update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
  271. return {
  272. "status": True,
  273. "reranking_model": app.state.config.RAG_RERANKING_MODEL,
  274. }
  275. except Exception as e:
  276. log.exception(f"Problem updating reranking model: {e}")
  277. raise HTTPException(
  278. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  279. detail=ERROR_MESSAGES.DEFAULT(e),
  280. )
  281. @app.get("/config")
  282. async def get_rag_config(user=Depends(get_admin_user)):
  283. return {
  284. "status": True,
  285. "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
  286. "chunk": {
  287. "chunk_size": app.state.config.CHUNK_SIZE,
  288. "chunk_overlap": app.state.config.CHUNK_OVERLAP,
  289. },
  290. "youtube": {
  291. "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
  292. "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
  293. },
  294. "web": {
  295. "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  296. "search": {
  297. "enable": app.state.config.ENABLE_RAG_WEB_SEARCH,
  298. "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
  299. "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
  300. "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
  301. "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
  302. "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
  303. "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
  304. "serpstack_https": app.state.config.SERPSTACK_HTTPS,
  305. "serper_api_key": app.state.config.SERPER_API_KEY,
  306. "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
  307. "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  308. },
  309. },
  310. }
class ChunkParamUpdateForm(BaseModel):
    """Text-splitter settings; stored as CHUNK_SIZE / CHUNK_OVERLAP."""

    chunk_size: int
    chunk_overlap: int


class YoutubeLoaderConfig(BaseModel):
    """YouTube transcript loader settings."""

    # Stored as YOUTUBE_LOADER_LANGUAGE and passed as `language=` to YoutubeLoader.
    language: List[str]
    # Stored on app.state and passed as `translation=` to YoutubeLoader.
    translation: Optional[str] = None


class WebSearchConfig(BaseModel):
    """Web-search settings; keys mirror the "web.search" section of GET /config."""

    enable: bool
    engine: Optional[str] = None
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


class WebConfig(BaseModel):
    """Web loader and web-search settings."""

    search: WebSearchConfig
    # NOTE: GET /config reports this value under the key "ssl_verification",
    # not "web_loader_ssl_verification".
    web_loader_ssl_verification: Optional[bool] = None


class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``; omitted sections are left unchanged."""

    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
  337. @app.post("/config/update")
  338. async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
  339. app.state.config.PDF_EXTRACT_IMAGES = (
  340. form_data.pdf_extract_images
  341. if form_data.pdf_extract_images is not None
  342. else app.state.config.PDF_EXTRACT_IMAGES
  343. )
  344. if form_data.chunk is not None:
  345. app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
  346. app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
  347. if form_data.youtube is not None:
  348. app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
  349. app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
  350. if form_data.web is not None:
  351. app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
  352. form_data.web.web_loader_ssl_verification
  353. )
  354. app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enable
  355. app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
  356. app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
  357. app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
  358. app.state.config.GOOGLE_PSE_ENGINE_ID = (
  359. form_data.web.search.google_pse_engine_id
  360. )
  361. app.state.config.BRAVE_SEARCH_API_KEY = (
  362. form_data.web.search.brave_search_api_key
  363. )
  364. app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
  365. app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
  366. app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
  367. app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
  368. app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
  369. form_data.web.search.concurrent_requests
  370. )
  371. return {
  372. "status": True,
  373. "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
  374. "chunk": {
  375. "chunk_size": app.state.config.CHUNK_SIZE,
  376. "chunk_overlap": app.state.config.CHUNK_OVERLAP,
  377. },
  378. "youtube": {
  379. "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
  380. "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
  381. },
  382. "web": {
  383. "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  384. "search": {
  385. "enable": app.state.config.ENABLE_RAG_WEB_SEARCH,
  386. "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
  387. "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
  388. "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
  389. "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
  390. "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
  391. "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
  392. "serpstack_https": app.state.config.SERPSTACK_HTTPS,
  393. "serper_api_key": app.state.config.SERPER_API_KEY,
  394. "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
  395. "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  396. },
  397. },
  398. }
  399. @app.get("/template")
  400. async def get_rag_template(user=Depends(get_current_user)):
  401. return {
  402. "status": True,
  403. "template": app.state.config.RAG_TEMPLATE,
  404. }
  405. @app.get("/query/settings")
  406. async def get_query_settings(user=Depends(get_admin_user)):
  407. return {
  408. "status": True,
  409. "template": app.state.config.RAG_TEMPLATE,
  410. "k": app.state.config.TOP_K,
  411. "r": app.state.config.RELEVANCE_THRESHOLD,
  412. "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  413. }
class QuerySettingsForm(BaseModel):
    """Request body for ``POST /query/settings/update``.

    The handler resets omitted or falsy fields to fixed defaults
    (k -> 4, r -> 0.0, template -> RAG_TEMPLATE, hybrid -> False).
    """

    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
  419. @app.post("/query/settings/update")
  420. async def update_query_settings(
  421. form_data: QuerySettingsForm, user=Depends(get_admin_user)
  422. ):
  423. app.state.config.RAG_TEMPLATE = (
  424. form_data.template if form_data.template else RAG_TEMPLATE
  425. )
  426. app.state.config.TOP_K = form_data.k if form_data.k else 4
  427. app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
  428. app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
  429. form_data.hybrid if form_data.hybrid else False
  430. )
  431. return {
  432. "status": True,
  433. "template": app.state.config.RAG_TEMPLATE,
  434. "k": app.state.config.TOP_K,
  435. "r": app.state.config.RELEVANCE_THRESHOLD,
  436. "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  437. }
class QueryDocForm(BaseModel):
    """Request body for ``POST /query/doc`` (query a single collection)."""

    collection_name: str
    query: str
    # Overrides for top-k and relevance threshold; None uses the configured values.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): not read by query_doc_handler; the global
    # ENABLE_RAG_HYBRID_SEARCH setting decides the search mode.
    hybrid: Optional[bool] = None
  444. @app.post("/query/doc")
  445. def query_doc_handler(
  446. form_data: QueryDocForm,
  447. user=Depends(get_current_user),
  448. ):
  449. try:
  450. if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
  451. return query_doc_with_hybrid_search(
  452. collection_name=form_data.collection_name,
  453. query=form_data.query,
  454. embedding_function=app.state.EMBEDDING_FUNCTION,
  455. k=form_data.k if form_data.k else app.state.config.TOP_K,
  456. reranking_function=app.state.sentence_transformer_rf,
  457. r=(
  458. form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
  459. ),
  460. )
  461. else:
  462. return query_doc(
  463. collection_name=form_data.collection_name,
  464. query=form_data.query,
  465. embedding_function=app.state.EMBEDDING_FUNCTION,
  466. k=form_data.k if form_data.k else app.state.config.TOP_K,
  467. )
  468. except Exception as e:
  469. log.exception(e)
  470. raise HTTPException(
  471. status_code=status.HTTP_400_BAD_REQUEST,
  472. detail=ERROR_MESSAGES.DEFAULT(e),
  473. )
class QueryCollectionsForm(BaseModel):
    """Request body for ``POST /query/collection`` (query several collections)."""

    collection_names: List[str]
    query: str
    # Overrides for top-k and relevance threshold; None uses the configured values.
    k: Optional[int] = None
    r: Optional[float] = None
    # NOTE(review): not read by query_collection_handler; the global
    # ENABLE_RAG_HYBRID_SEARCH setting decides the search mode.
    hybrid: Optional[bool] = None
  480. @app.post("/query/collection")
  481. def query_collection_handler(
  482. form_data: QueryCollectionsForm,
  483. user=Depends(get_current_user),
  484. ):
  485. try:
  486. if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
  487. return query_collection_with_hybrid_search(
  488. collection_names=form_data.collection_names,
  489. query=form_data.query,
  490. embedding_function=app.state.EMBEDDING_FUNCTION,
  491. k=form_data.k if form_data.k else app.state.config.TOP_K,
  492. reranking_function=app.state.sentence_transformer_rf,
  493. r=(
  494. form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
  495. ),
  496. )
  497. else:
  498. return query_collection(
  499. collection_names=form_data.collection_names,
  500. query=form_data.query,
  501. embedding_function=app.state.EMBEDDING_FUNCTION,
  502. k=form_data.k if form_data.k else app.state.config.TOP_K,
  503. )
  504. except Exception as e:
  505. log.exception(e)
  506. raise HTTPException(
  507. status_code=status.HTTP_400_BAD_REQUEST,
  508. detail=ERROR_MESSAGES.DEFAULT(e),
  509. )
  510. @app.post("/youtube")
  511. def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
  512. try:
  513. loader = YoutubeLoader.from_youtube_url(
  514. form_data.url,
  515. add_video_info=True,
  516. language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
  517. translation=app.state.YOUTUBE_LOADER_TRANSLATION,
  518. )
  519. data = loader.load()
  520. collection_name = form_data.collection_name
  521. if collection_name == "":
  522. collection_name = calculate_sha256_string(form_data.url)[:63]
  523. store_data_in_vector_db(data, collection_name, overwrite=True)
  524. return {
  525. "status": True,
  526. "collection_name": collection_name,
  527. "filename": form_data.url,
  528. }
  529. except Exception as e:
  530. log.exception(e)
  531. raise HTTPException(
  532. status_code=status.HTTP_400_BAD_REQUEST,
  533. detail=ERROR_MESSAGES.DEFAULT(e),
  534. )
  535. @app.post("/web")
  536. def store_web(form_data: UrlForm, user=Depends(get_current_user)):
  537. # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
  538. try:
  539. loader = get_web_loader(
  540. form_data.url,
  541. verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  542. )
  543. data = loader.load()
  544. collection_name = form_data.collection_name
  545. if collection_name == "":
  546. collection_name = calculate_sha256_string(form_data.url)[:63]
  547. store_data_in_vector_db(data, collection_name, overwrite=True)
  548. return {
  549. "status": True,
  550. "collection_name": collection_name,
  551. "filename": form_data.url,
  552. }
  553. except Exception as e:
  554. log.exception(e)
  555. raise HTTPException(
  556. status_code=status.HTTP_400_BAD_REQUEST,
  557. detail=ERROR_MESSAGES.DEFAULT(e),
  558. )
  559. def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
  560. # Check if the URL is valid
  561. if not validate_url(url):
  562. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  563. return WebBaseLoader(
  564. url,
  565. verify_ssl=verify_ssl,
  566. requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  567. continue_on_failure=True,
  568. )
  569. def validate_url(url: Union[str, Sequence[str]]):
  570. if isinstance(url, str):
  571. if isinstance(validators.url(url), validators.ValidationError):
  572. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  573. if not ENABLE_RAG_LOCAL_WEB_FETCH:
  574. # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
  575. parsed_url = urllib.parse.urlparse(url)
  576. # Get IPv4 and IPv6 addresses
  577. ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
  578. # Check if any of the resolved addresses are private
  579. # This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
  580. for ip in ipv4_addresses:
  581. if validators.ipv4(ip, private=True):
  582. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  583. for ip in ipv6_addresses:
  584. if validators.ipv6(ip, private=True):
  585. raise ValueError(ERROR_MESSAGES.INVALID_URL)
  586. return True
  587. elif isinstance(url, Sequence):
  588. return all(validate_url(u) for u in url)
  589. else:
  590. return False
  591. def resolve_hostname(hostname):
  592. # Get address information
  593. addr_info = socket.getaddrinfo(hostname, None)
  594. # Extract IP addresses from address information
  595. ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
  596. ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
  597. return ipv4_addresses, ipv6_addresses
  598. @app.post("/web/search")
  599. def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
  600. try:
  601. try:
  602. web_results = search_web(
  603. app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
  604. )
  605. except Exception as e:
  606. log.exception(e)
  607. raise HTTPException(
  608. status_code=status.HTTP_400_BAD_REQUEST,
  609. detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
  610. )
  611. urls = [result.link for result in web_results]
  612. loader = get_web_loader(urls)
  613. data = loader.load()
  614. collection_name = form_data.collection_name
  615. if collection_name == "":
  616. collection_name = calculate_sha256_string(form_data.query)[:63]
  617. store_data_in_vector_db(data, collection_name, overwrite=True)
  618. return {
  619. "status": True,
  620. "collection_name": collection_name,
  621. "filenames": urls,
  622. }
  623. except Exception as e:
  624. log.exception(e)
  625. raise HTTPException(
  626. status_code=status.HTTP_400_BAD_REQUEST,
  627. detail=ERROR_MESSAGES.DEFAULT(e),
  628. )
  629. def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
  630. text_splitter = RecursiveCharacterTextSplitter(
  631. chunk_size=app.state.config.CHUNK_SIZE,
  632. chunk_overlap=app.state.config.CHUNK_OVERLAP,
  633. add_start_index=True,
  634. )
  635. docs = text_splitter.split_documents(data)
  636. if len(docs) > 0:
  637. log.info(f"store_data_in_vector_db {docs}")
  638. return store_docs_in_vector_db(docs, collection_name, overwrite), None
  639. else:
  640. raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
  641. def store_text_in_vector_db(
  642. text, metadata, collection_name, overwrite: bool = False
  643. ) -> bool:
  644. text_splitter = RecursiveCharacterTextSplitter(
  645. chunk_size=app.state.config.CHUNK_SIZE,
  646. chunk_overlap=app.state.config.CHUNK_OVERLAP,
  647. add_start_index=True,
  648. )
  649. docs = text_splitter.create_documents([text], metadatas=[metadata])
  650. return store_docs_in_vector_db(docs, collection_name, overwrite)
  651. def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
  652. log.info(f"store_docs_in_vector_db {docs} {collection_name}")
  653. texts = [doc.page_content for doc in docs]
  654. metadatas = [doc.metadata for doc in docs]
  655. try:
  656. if overwrite:
  657. for collection in CHROMA_CLIENT.list_collections():
  658. if collection_name == collection.name:
  659. log.info(f"deleting existing collection {collection_name}")
  660. CHROMA_CLIENT.delete_collection(name=collection_name)
  661. collection = CHROMA_CLIENT.create_collection(name=collection_name)
  662. embedding_func = get_embedding_function(
  663. app.state.config.RAG_EMBEDDING_ENGINE,
  664. app.state.config.RAG_EMBEDDING_MODEL,
  665. app.state.sentence_transformer_ef,
  666. app.state.config.OPENAI_API_KEY,
  667. app.state.config.OPENAI_API_BASE_URL,
  668. )
  669. embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
  670. embeddings = embedding_func(embedding_texts)
  671. for batch in create_batches(
  672. api=CHROMA_CLIENT,
  673. ids=[str(uuid.uuid4()) for _ in texts],
  674. metadatas=metadatas,
  675. embeddings=embeddings,
  676. documents=texts,
  677. ):
  678. collection.add(*batch)
  679. return True
  680. except Exception as e:
  681. log.exception(e)
  682. if e.__class__.__name__ == "UniqueConstraintError":
  683. return True
  684. return False
  def get_loader(filename: str, file_content_type: str, file_path: str):
      """Pick a document loader for a file based on its extension / MIME type.

      Args:
          filename: Name of the file. The text after its last "." (lower-cased)
              is used as the extension; a name without "." is used whole, which
              is how e.g. "Dockerfile" matches the "dockerfile" entry below.
          file_content_type: MIME type reported for the file; may be None
              (e.g. when ``mimetypes.guess_type`` finds nothing).
          file_path: Path of the file on disk, passed to the chosen loader.

      Returns:
          A ``(loader, known_type)`` tuple. ``known_type`` is False only when
          nothing matched and the file is read as text as a last resort.
      """
      file_ext = filename.split(".")[-1].lower()
      # Extensions (or whole extension-less names) of plain-text source/config
      # files; a set gives O(1) membership tests.
      known_source_ext = {
          "go", "py", "java", "sh", "bat", "ps1", "cmd", "js", "ts", "css",
          "cpp", "hpp", "h", "c", "cs", "sql", "log", "ini", "pl", "pm", "r",
          "dart", "dockerfile", "env", "php", "hs", "hsc", "lua", "nginxconf",
          "conf", "m", "mm", "plsql", "perl", "rb", "rs", "db2", "scala",
          "bash", "swift", "vue", "svelte",
      }

      known_type = True
      if file_ext == "pdf":
          loader = PyPDFLoader(
              file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
          )
      elif file_ext == "csv":
          loader = CSVLoader(file_path)
      elif file_ext == "rst":
          loader = UnstructuredRSTLoader(file_path, mode="elements")
      elif file_ext == "xml":
          loader = UnstructuredXMLLoader(file_path)
      elif file_ext in ("htm", "html"):
          loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
      elif file_ext == "md":
          loader = UnstructuredMarkdownLoader(file_path)
      elif file_content_type == "application/epub+zip":
          loader = UnstructuredEPubLoader(file_path)
      elif (
          file_content_type
          == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
          or file_ext in ("doc", "docx")
      ):
          # NOTE(review): Docx2txtLoader targets the OOXML (.docx) format; legacy
          # binary .doc files routed here may fail to load — confirm.
          loader = Docx2txtLoader(file_path)
      elif file_content_type in (
          "application/vnd.ms-excel",
          "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
      ) or file_ext in ("xls", "xlsx"):
          loader = UnstructuredExcelLoader(file_path)
      elif file_content_type in (
          "application/vnd.ms-powerpoint",
          "application/vnd.openxmlformats-officedocument.presentationml.presentation",
      ) or file_ext in ("ppt", "pptx"):
          loader = UnstructuredPowerPointLoader(file_path)
      else:
          # Known source files, any text/* MIME type, and unrecognised files are
          # all read as text; only the unrecognised case reports known_type=False.
          loader = TextLoader(file_path, autodetect_encoding=True)
          known_type = file_ext in known_source_ext or bool(
              file_content_type and "text/" in file_content_type
          )
      return loader, known_type
  @app.post("/doc")
  def store_doc(
      collection_name: Optional[str] = Form(None),
      file: UploadFile = File(...),
      user=Depends(get_current_user),
  ):
      """Save an uploaded file under UPLOAD_DIR and index it into the vector DB.

      Args:
          collection_name: Target collection. When omitted, the first 63
              characters of ``calculate_sha256`` over the saved file are used.
          file: The uploaded file.
          user: Caller resolved by the ``get_current_user`` dependency (not
              otherwise used here).

      Returns:
          ``{"status": True, "collection_name", "filename", "known_type"}``.

      Raises:
          HTTPException: 400 if the file cannot be saved or parsed (with a
              dedicated message when pandoc is missing); 500 if storing the
              parsed data in the vector DB raises or reports failure.
      """
      # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
      log.info(f"file.content_type: {file.content_type}")
      try:
          # basename() strips any directory parts supplied by the client so the
          # write below stays inside UPLOAD_DIR.
          filename = os.path.basename(file.filename)
          file_path = f"{UPLOAD_DIR}/{filename}"
          contents = file.file.read()
          with open(file_path, "wb") as f:
              f.write(contents)

          if collection_name is None:
              with open(file_path, "rb") as f:
                  collection_name = calculate_sha256(f)[:63]

          loader, known_type = get_loader(filename, file.content_type, file_path)
          data = loader.load()
      except Exception as e:
          log.exception(e)
          if "No pandoc was found" in str(e):
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
              ) from e
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          ) from e

      # Kept outside the try above so the 500 raised here is not caught and
      # re-wrapped as a 400.
      try:
          result = store_data_in_vector_db(data, collection_name)
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(e),
          ) from e
      if not result:
          # Previously this case fell through and returned None with a 200.
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )

      return {
          "status": True,
          "collection_name": collection_name,
          "filename": filename,
          "known_type": known_type,
      }
  class TextRAGForm(BaseModel):
      # Request body accepted by the POST /text endpoint (store_text).
      # Display name; stored as the "name" metadata alongside the text.
      name: str
      # Raw text to embed and store.
      content: str
      # Target collection; when None, store_text derives one from a hash of `content`.
      collection_name: Optional[str] = None
  @app.post("/text")
  def store_text(
      form_data: TextRAGForm,
      user=Depends(get_current_user),
  ):
      """Embed a raw text snippet into a vector collection.

      When ``form_data.collection_name`` is None, the name is derived from
      ``calculate_sha256_string(content)`` and truncated to 63 characters, the
      same limit ``store_doc`` and ``scan_docs_dir`` apply to their hash-based
      names (Chroma rejects collection names longer than 63 characters;
      presumably the helper returns a 64-char hex digest — confirm).

      Returns:
          ``{"status": True, "collection_name": ...}`` on success.

      Raises:
          HTTPException: 500 when ``store_text_in_vector_db`` reports failure.
      """
      collection_name = form_data.collection_name
      if collection_name is None:
          collection_name = calculate_sha256_string(form_data.content)[:63]

      result = store_text_in_vector_db(
          form_data.content,
          metadata={"name": form_data.name, "created_by": user.id},
          collection_name=collection_name,
      )
      if not result:
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
              detail=ERROR_MESSAGES.DEFAULT(),
          )
      return {"status": True, "collection_name": collection_name}
  def _tags_to_doc_content(tags) -> str:
      """Serialize folder tags as the document ``content`` JSON ("{}" if none)."""
      if not tags:
          return "{}"
      return json.dumps({"tags": [{"name": name} for name in tags]})


  @app.get("/scan")
  def scan_docs_dir(user=Depends(get_admin_user)):
      """Index every non-hidden file under DOCS_DIR and register it as a document.

      Best effort: any failure on a file is logged and the scan moves on.
      Each file goes into a collection named after the first 63 characters of
      ``calculate_sha256`` over its content. A document record owned by the
      calling user is inserted unless one with the same sanitized name exists.
      Files whose vector-DB store reports failure are skipped. Always returns True.
      """
      # rglob("*") already recurses into every subdirectory.
      for path in Path(DOCS_DIR).rglob("*"):
          try:
              # Only the file's own name is checked, so files inside
              # dot-directories are still scanned.
              if not path.is_file() or path.name.startswith("."):
                  continue

              tags = extract_folders_after_data_docs(path)
              filename = path.name
              file_content_type, _encoding = mimetypes.guess_type(path)
              with open(path, "rb") as f:
                  collection_name = calculate_sha256(f)[:63]

              loader, _known_type = get_loader(
                  filename, file_content_type, str(path)
              )
              data = loader.load()

              if not store_data_in_vector_db(data, collection_name):
                  continue

              sanitized_filename = sanitize_filename(filename)
              if Documents.get_doc_by_name(sanitized_filename) is None:
                  Documents.insert_new_doc(
                      user.id,
                      DocumentForm(
                          **{
                              "name": sanitized_filename,
                              "title": filename,
                              "collection_name": collection_name,
                              "filename": filename,
                              "content": _tags_to_doc_content(tags),
                          }
                      ),
                  )
          except Exception as e:
              # Best effort: log and continue with the next file.
              log.exception(e)
      return True
  @app.get("/reset/db")
  def reset_vector_db(user=Depends(get_admin_user)):
      """Reset the vector store via ``CHROMA_CLIENT.reset()``.

      Unlike ``/reset``, uploaded files are left in place, errors from the
      client are not caught, and the handler returns nothing.
      """
      # NOTE(review): `user` is unused; presumably the get_admin_user dependency
      # exists only to restrict this endpoint to admins — confirm.
      client = CHROMA_CLIENT
      client.reset()
      return None
  @app.get("/reset")
  def reset(user=Depends(get_admin_user)) -> bool:
      """Delete everything inside UPLOAD_DIR, then reset the vector store.

      Best effort: a missing upload directory is treated as empty, each entry
      that cannot be removed is logged and skipped, and a failure of
      ``CHROMA_CLIENT.reset()`` is logged. Always returns True.
      """
      folder = UPLOAD_DIR
      try:
          entries = os.listdir(folder)
      except FileNotFoundError:
          # Nothing has been uploaded (or the directory is gone): nothing to delete.
          log.warning("Upload directory %s does not exist; nothing to delete", folder)
          entries = []

      for filename in entries:
          file_path = os.path.join(folder, filename)
          try:
              # islink is checked before isdir, so a symlink to a directory is
              # unlinked rather than having its target tree removed.
              if os.path.isfile(file_path) or os.path.islink(file_path):
                  os.unlink(file_path)
              elif os.path.isdir(file_path):
                  shutil.rmtree(file_path)
          except Exception as e:
              log.error("Failed to delete %s. Reason: %s", file_path, e)

      try:
          CHROMA_CLIENT.reset()
      except Exception as e:
          log.exception(e)
      return True
  if ENV == "dev":
      # Development-only endpoints that return the raw embedding of a string.
      # They are plain `def` (not `async def`): EMBEDDING_FUNCTION is called
      # synchronously, and FastAPI runs sync handlers in its threadpool instead
      # of on (and blocking) the event loop.

      @app.get("/ef")
      def get_embeddings():
          """Return the embedding of the fixed string "hello world"."""
          return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

      @app.get("/ef/{text}")
      def get_embeddings_text(text: str):
          """Return the embedding of ``text`` taken from the URL path."""
          return {"result": app.state.EMBEDDING_FUNCTION(text)}