# retrieval.py — RAG/retrieval API routes (open-webui).
# (Original header was web-scrape residue: file size and concatenated viewer line numbers.)
  1. import json
  2. import logging
  3. import mimetypes
  4. import os
  5. import shutil
  6. import uuid
  7. from datetime import datetime
  8. from pathlib import Path
  9. from typing import Iterator, Optional, Sequence, Union
  10. from fastapi import (
  11. Depends,
  12. FastAPI,
  13. File,
  14. Form,
  15. HTTPException,
  16. UploadFile,
  17. Request,
  18. status,
  19. APIRouter,
  20. )
  21. from fastapi.middleware.cors import CORSMiddleware
  22. from pydantic import BaseModel
  23. import tiktoken
  24. from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
  25. from langchain_core.documents import Document
  26. from open_webui.models.files import Files
  27. from open_webui.models.knowledge import Knowledges
  28. from open_webui.storage.provider import Storage
  29. from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
  30. # Document loaders
  31. from open_webui.retrieval.loaders.main import Loader
  32. from open_webui.retrieval.loaders.youtube import YoutubeLoader
  33. # Web search engines
  34. from open_webui.retrieval.web.main import SearchResult
  35. from open_webui.retrieval.web.utils import get_web_loader
  36. from open_webui.retrieval.web.brave import search_brave
  37. from open_webui.retrieval.web.kagi import search_kagi
  38. from open_webui.retrieval.web.mojeek import search_mojeek
  39. from open_webui.retrieval.web.duckduckgo import search_duckduckgo
  40. from open_webui.retrieval.web.google_pse import search_google_pse
  41. from open_webui.retrieval.web.jina_search import search_jina
  42. from open_webui.retrieval.web.searchapi import search_searchapi
  43. from open_webui.retrieval.web.searxng import search_searxng
  44. from open_webui.retrieval.web.serper import search_serper
  45. from open_webui.retrieval.web.serply import search_serply
  46. from open_webui.retrieval.web.serpstack import search_serpstack
  47. from open_webui.retrieval.web.tavily import search_tavily
  48. from open_webui.retrieval.web.bing import search_bing
  49. from open_webui.retrieval.utils import (
  50. get_embedding_function,
  51. get_model_path,
  52. query_collection,
  53. query_collection_with_hybrid_search,
  54. query_doc,
  55. query_doc_with_hybrid_search,
  56. )
  57. from open_webui.utils.misc import (
  58. calculate_sha256_string,
  59. )
  60. from open_webui.utils.auth import get_admin_user, get_verified_user
  61. from open_webui.config import (
  62. ENV,
  63. RAG_EMBEDDING_MODEL_AUTO_UPDATE,
  64. RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  65. RAG_RERANKING_MODEL_AUTO_UPDATE,
  66. RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  67. UPLOAD_DIR,
  68. DEFAULT_LOCALE,
  69. )
  70. from open_webui.env import (
  71. SRC_LOG_LEVELS,
  72. DEVICE_TYPE,
  73. DOCKER,
  74. )
  75. from open_webui.constants import ERROR_MESSAGES
# Module-level logger for the retrieval subsystem, tuned to the project's RAG log level.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])

##########################################
#
# Utility functions
#
##########################################
  83. def update_embedding_model(
  84. request: Request,
  85. embedding_model: str,
  86. auto_update: bool = False,
  87. ):
  88. if embedding_model and request.app.state.config.RAG_EMBEDDING_ENGINE == "":
  89. from sentence_transformers import SentenceTransformer
  90. try:
  91. request.app.state.sentence_transformer_ef = SentenceTransformer(
  92. get_model_path(embedding_model, auto_update),
  93. device=DEVICE_TYPE,
  94. trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
  95. )
  96. except Exception as e:
  97. log.debug(f"Error loading SentenceTransformer: {e}")
  98. request.app.state.sentence_transformer_ef = None
  99. else:
  100. request.app.state.sentence_transformer_ef = None
  101. def update_reranking_model(
  102. request: Request,
  103. reranking_model: str,
  104. auto_update: bool = False,
  105. ):
  106. if reranking_model:
  107. if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
  108. try:
  109. from open_webui.retrieval.models.colbert import ColBERT
  110. request.app.state.sentence_transformer_rf = ColBERT(
  111. get_model_path(reranking_model, auto_update),
  112. env="docker" if DOCKER else None,
  113. )
  114. except Exception as e:
  115. log.error(f"ColBERT: {e}")
  116. request.app.state.sentence_transformer_rf = None
  117. request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
  118. else:
  119. import sentence_transformers
  120. try:
  121. request.app.state.sentence_transformer_rf = (
  122. sentence_transformers.CrossEncoder(
  123. get_model_path(reranking_model, auto_update),
  124. device=DEVICE_TYPE,
  125. trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
  126. )
  127. )
  128. except:
  129. log.error("CrossEncoder error")
  130. request.app.state.sentence_transformer_rf = None
  131. request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
  132. else:
  133. request.app.state.sentence_transformer_rf = None
  134. ##########################################
  135. #
  136. # API routes
  137. #
  138. ##########################################
# Router carrying all retrieval/RAG endpoints.
router = APIRouter()


class CollectionNameForm(BaseModel):
    # Target vector-DB collection; None lets the handler derive one.
    collection_name: Optional[str] = None


class ProcessUrlForm(CollectionNameForm):
    url: str


class SearchForm(CollectionNameForm):
    query: str
  146. @router.get("/")
  147. async def get_status(request: Request):
  148. return {
  149. "status": True,
  150. "chunk_size": request.app.state.config.CHUNK_SIZE,
  151. "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
  152. "template": request.app.state.config.RAG_TEMPLATE,
  153. "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
  154. "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
  155. "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
  156. "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
  157. }
  158. @router.get("/embedding")
  159. async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
  160. return {
  161. "status": True,
  162. "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
  163. "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
  164. "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
  165. "openai_config": {
  166. "url": request.app.state.config.OPENAI_API_BASE_URL,
  167. "key": request.app.state.config.OPENAI_API_KEY,
  168. },
  169. "ollama_config": {
  170. "url": request.app.state.config.OLLAMA_BASE_URL,
  171. "key": request.app.state.config.OLLAMA_API_KEY,
  172. },
  173. }
  174. @router.get("/reranking")
  175. async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
  176. return {
  177. "status": True,
  178. "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
  179. }
class OpenAIConfigForm(BaseModel):
    # OpenAI-compatible embedding endpoint URL and API key.
    url: str
    key: str


class OllamaConfigForm(BaseModel):
    # Ollama endpoint URL and API key.
    url: str
    key: str


class EmbeddingModelUpdateForm(BaseModel):
    # Provider credentials; only applied when embedding_engine is
    # "openai" or "ollama" (see update_embedding_config).
    openai_config: Optional[OpenAIConfigForm] = None
    ollama_config: Optional[OllamaConfigForm] = None
    embedding_engine: str
    embedding_model: str
    embedding_batch_size: Optional[int] = 1
  192. @router.post("/embedding/update")
  193. async def update_embedding_config(
  194. request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
  195. ):
  196. log.info(
  197. f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
  198. )
  199. try:
  200. request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
  201. request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
  202. if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
  203. if form_data.openai_config is not None:
  204. request.app.state.config.OPENAI_API_BASE_URL = (
  205. form_data.openai_config.url
  206. )
  207. request.app.state.config.OPENAI_API_KEY = form_data.openai_config.key
  208. if form_data.ollama_config is not None:
  209. request.app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
  210. request.app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
  211. request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = (
  212. form_data.embedding_batch_size
  213. )
  214. update_embedding_model(request.app.state.config.RAG_EMBEDDING_MODEL)
  215. request.app.state.EMBEDDING_FUNCTION = get_embedding_function(
  216. request.app.state.config.RAG_EMBEDDING_ENGINE,
  217. request.app.state.config.RAG_EMBEDDING_MODEL,
  218. request.app.state.sentence_transformer_ef,
  219. (
  220. request.app.state.config.OPENAI_API_BASE_URL
  221. if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
  222. else request.app.state.config.OLLAMA_BASE_URL
  223. ),
  224. (
  225. request.app.state.config.OPENAI_API_KEY
  226. if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
  227. else request.app.state.config.OLLAMA_API_KEY
  228. ),
  229. request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
  230. )
  231. return {
  232. "status": True,
  233. "embedding_engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
  234. "embedding_model": request.app.state.config.RAG_EMBEDDING_MODEL,
  235. "embedding_batch_size": request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
  236. "openai_config": {
  237. "url": request.app.state.config.OPENAI_API_BASE_URL,
  238. "key": request.app.state.config.OPENAI_API_KEY,
  239. },
  240. "ollama_config": {
  241. "url": request.app.state.config.OLLAMA_BASE_URL,
  242. "key": request.app.state.config.OLLAMA_API_KEY,
  243. },
  244. }
  245. except Exception as e:
  246. log.exception(f"Problem updating embedding model: {e}")
  247. raise HTTPException(
  248. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  249. detail=ERROR_MESSAGES.DEFAULT(e),
  250. )
class RerankingModelUpdateForm(BaseModel):
    # Model id/path to load for reranking; a falsy value clears the loaded
    # model (see update_reranking_model).
    reranking_model: str
  253. @router.post("/reranking/update")
  254. async def update_reranking_config(
  255. request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
  256. ):
  257. log.info(
  258. f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
  259. )
  260. try:
  261. request.app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
  262. update_reranking_model(request.app.state.config.RAG_RERANKING_MODEL, True)
  263. return {
  264. "status": True,
  265. "reranking_model": request.app.state.config.RAG_RERANKING_MODEL,
  266. }
  267. except Exception as e:
  268. log.exception(f"Problem updating reranking model: {e}")
  269. raise HTTPException(
  270. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  271. detail=ERROR_MESSAGES.DEFAULT(e),
  272. )
  273. @router.get("/config")
  274. async def get_rag_config(request: Request, user=Depends(get_admin_user)):
  275. return {
  276. "status": True,
  277. "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
  278. "content_extraction": {
  279. "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
  280. "tika_server_url": request.app.state.config.TIKA_SERVER_URL,
  281. },
  282. "chunk": {
  283. "text_splitter": request.app.state.config.TEXT_SPLITTER,
  284. "chunk_size": request.app.state.config.CHUNK_SIZE,
  285. "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
  286. },
  287. "file": {
  288. "max_size": request.app.state.config.FILE_MAX_SIZE,
  289. "max_count": request.app.state.config.FILE_MAX_COUNT,
  290. },
  291. "youtube": {
  292. "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
  293. "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
  294. "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
  295. },
  296. "web": {
  297. "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  298. "search": {
  299. "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
  300. "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
  301. "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
  302. "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
  303. "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
  304. "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
  305. "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
  306. "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
  307. "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
  308. "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
  309. "serper_api_key": request.app.state.config.SERPER_API_KEY,
  310. "serply_api_key": request.app.state.config.SERPLY_API_KEY,
  311. "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
  312. "searchapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
  313. "seaarchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
  314. "jina_api_key": request.app.state.config.JINA_API_KEY,
  315. "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
  316. "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
  317. "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
  318. "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  319. },
  320. },
  321. }
class FileConfig(BaseModel):
    # Upload limits; None means no limit.
    max_size: Optional[int] = None
    max_count: Optional[int] = None


class ContentExtractionConfig(BaseModel):
    # Content-extraction engine id; empty string selects the default loader.
    engine: str = ""
    # Tika server URL (presumably only used when engine selects Tika — confirm).
    tika_server_url: Optional[str] = None


class ChunkParamUpdateForm(BaseModel):
    # One of "", "character", or "token" (see save_docs_to_vector_db).
    text_splitter: Optional[str] = None
    chunk_size: int
    chunk_overlap: int


class YoutubeLoaderConfig(BaseModel):
    # Preferred transcript languages, in priority order.
    language: list[str]
    translation: Optional[str] = None
    proxy_url: str = ""


class WebSearchConfig(BaseModel):
    enabled: bool
    engine: Optional[str] = None
    # Per-provider credentials/endpoints.
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    kagi_search_api_key: Optional[str] = None
    mojeek_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None
    searchapi_api_key: Optional[str] = None
    searchapi_engine: Optional[str] = None
    jina_api_key: Optional[str] = None
    bing_search_v7_endpoint: Optional[str] = None
    bing_search_v7_subscription_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None


class WebConfig(BaseModel):
    search: WebSearchConfig
    web_loader_ssl_verification: Optional[bool] = None


class ConfigUpdateForm(BaseModel):
    # All sections optional; only non-None sections are applied
    # (see update_rag_config).
    pdf_extract_images: Optional[bool] = None
    file: Optional[FileConfig] = None
    content_extraction: Optional[ContentExtractionConfig] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
  367. @router.post("/config/update")
  368. async def update_rag_config(
  369. request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user)
  370. ):
  371. request.app.state.config.PDF_EXTRACT_IMAGES = (
  372. form_data.pdf_extract_images
  373. if form_data.pdf_extract_images is not None
  374. else request.app.state.config.PDF_EXTRACT_IMAGES
  375. )
  376. if form_data.file is not None:
  377. request.app.state.config.FILE_MAX_SIZE = form_data.file.max_size
  378. request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
  379. if form_data.content_extraction is not None:
  380. log.info(f"Updating text settings: {form_data.content_extraction}")
  381. request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
  382. form_data.content_extraction.engine
  383. )
  384. request.app.state.config.TIKA_SERVER_URL = (
  385. form_data.content_extraction.tika_server_url
  386. )
  387. if form_data.chunk is not None:
  388. request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
  389. request.app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
  390. request.app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
  391. if form_data.youtube is not None:
  392. request.app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
  393. request.app.state.config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url
  394. request.app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
  395. if form_data.web is not None:
  396. request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
  397. # Note: When UI "Bypass SSL verification for Websites"=True then ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
  398. form_data.web.web_loader_ssl_verification
  399. )
  400. request.app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
  401. request.app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
  402. request.app.state.config.SEARXNG_QUERY_URL = (
  403. form_data.web.search.searxng_query_url
  404. )
  405. request.app.state.config.GOOGLE_PSE_API_KEY = (
  406. form_data.web.search.google_pse_api_key
  407. )
  408. request.app.state.config.GOOGLE_PSE_ENGINE_ID = (
  409. form_data.web.search.google_pse_engine_id
  410. )
  411. request.app.state.config.BRAVE_SEARCH_API_KEY = (
  412. form_data.web.search.brave_search_api_key
  413. )
  414. request.app.state.config.KAGI_SEARCH_API_KEY = (
  415. form_data.web.search.kagi_search_api_key
  416. )
  417. request.app.state.config.MOJEEK_SEARCH_API_KEY = (
  418. form_data.web.search.mojeek_search_api_key
  419. )
  420. request.app.state.config.SERPSTACK_API_KEY = (
  421. form_data.web.search.serpstack_api_key
  422. )
  423. request.app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
  424. request.app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
  425. request.app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
  426. request.app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
  427. request.app.state.config.SEARCHAPI_API_KEY = (
  428. form_data.web.search.searchapi_api_key
  429. )
  430. request.app.state.config.SEARCHAPI_ENGINE = (
  431. form_data.web.search.searchapi_engine
  432. )
  433. request.app.state.config.JINA_API_KEY = form_data.web.search.jina_api_key
  434. request.app.state.config.BING_SEARCH_V7_ENDPOINT = (
  435. form_data.web.search.bing_search_v7_endpoint
  436. )
  437. request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
  438. form_data.web.search.bing_search_v7_subscription_key
  439. )
  440. request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
  441. form_data.web.search.result_count
  442. )
  443. request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
  444. form_data.web.search.concurrent_requests
  445. )
  446. return {
  447. "status": True,
  448. "pdf_extract_images": request.app.state.config.PDF_EXTRACT_IMAGES,
  449. "file": {
  450. "max_size": request.app.state.config.FILE_MAX_SIZE,
  451. "max_count": request.app.state.config.FILE_MAX_COUNT,
  452. },
  453. "content_extraction": {
  454. "engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
  455. "tika_server_url": request.app.state.config.TIKA_SERVER_URL,
  456. },
  457. "chunk": {
  458. "text_splitter": request.app.state.config.TEXT_SPLITTER,
  459. "chunk_size": request.app.state.config.CHUNK_SIZE,
  460. "chunk_overlap": request.app.state.config.CHUNK_OVERLAP,
  461. },
  462. "youtube": {
  463. "language": request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
  464. "proxy_url": request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
  465. "translation": request.app.state.YOUTUBE_LOADER_TRANSLATION,
  466. },
  467. "web": {
  468. "web_loader_ssl_verification": request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  469. "search": {
  470. "enabled": request.app.state.config.ENABLE_RAG_WEB_SEARCH,
  471. "engine": request.app.state.config.RAG_WEB_SEARCH_ENGINE,
  472. "searxng_query_url": request.app.state.config.SEARXNG_QUERY_URL,
  473. "google_pse_api_key": request.app.state.config.GOOGLE_PSE_API_KEY,
  474. "google_pse_engine_id": request.app.state.config.GOOGLE_PSE_ENGINE_ID,
  475. "brave_search_api_key": request.app.state.config.BRAVE_SEARCH_API_KEY,
  476. "kagi_search_api_key": request.app.state.config.KAGI_SEARCH_API_KEY,
  477. "mojeek_search_api_key": request.app.state.config.MOJEEK_SEARCH_API_KEY,
  478. "serpstack_api_key": request.app.state.config.SERPSTACK_API_KEY,
  479. "serpstack_https": request.app.state.config.SERPSTACK_HTTPS,
  480. "serper_api_key": request.app.state.config.SERPER_API_KEY,
  481. "serply_api_key": request.app.state.config.SERPLY_API_KEY,
  482. "serachapi_api_key": request.app.state.config.SEARCHAPI_API_KEY,
  483. "searchapi_engine": request.app.state.config.SEARCHAPI_ENGINE,
  484. "tavily_api_key": request.app.state.config.TAVILY_API_KEY,
  485. "jina_api_key": request.app.state.config.JINA_API_KEY,
  486. "bing_search_v7_endpoint": request.app.state.config.BING_SEARCH_V7_ENDPOINT,
  487. "bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
  488. "result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
  489. "concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  490. },
  491. },
  492. }
  493. @router.get("/template")
  494. async def get_rag_template(request: Request, user=Depends(get_verified_user)):
  495. return {
  496. "status": True,
  497. "template": request.app.state.config.RAG_TEMPLATE,
  498. }
  499. @router.get("/query/settings")
  500. async def get_query_settings(request: Request, user=Depends(get_admin_user)):
  501. return {
  502. "status": True,
  503. "template": request.app.state.config.RAG_TEMPLATE,
  504. "k": request.app.state.config.TOP_K,
  505. "r": request.app.state.config.RELEVANCE_THRESHOLD,
  506. "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  507. }
class QuerySettingsForm(BaseModel):
    # All fields optional; falsy submissions fall back to defaults in
    # update_query_settings (k -> 4, r -> 0.0, hybrid -> False).
    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
  513. @router.post("/query/settings/update")
  514. async def update_query_settings(
  515. request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user)
  516. ):
  517. request.app.state.config.RAG_TEMPLATE = form_data.template
  518. request.app.state.config.TOP_K = form_data.k if form_data.k else 4
  519. request.app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
  520. request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
  521. form_data.hybrid if form_data.hybrid else False
  522. )
  523. return {
  524. "status": True,
  525. "template": request.app.state.config.RAG_TEMPLATE,
  526. "k": request.app.state.config.TOP_K,
  527. "r": request.app.state.config.RELEVANCE_THRESHOLD,
  528. "hybrid": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  529. }
  530. ####################################
  531. #
  532. # Document process and retrieval
  533. #
  534. ####################################
  535. def save_docs_to_vector_db(
  536. request: Request,
  537. docs,
  538. collection_name,
  539. metadata: Optional[dict] = None,
  540. overwrite: bool = False,
  541. split: bool = True,
  542. add: bool = False,
  543. ) -> bool:
  544. def _get_docs_info(docs: list[Document]) -> str:
  545. docs_info = set()
  546. # Trying to select relevant metadata identifying the document.
  547. for doc in docs:
  548. metadata = getattr(doc, "metadata", {})
  549. doc_name = metadata.get("name", "")
  550. if not doc_name:
  551. doc_name = metadata.get("title", "")
  552. if not doc_name:
  553. doc_name = metadata.get("source", "")
  554. if doc_name:
  555. docs_info.add(doc_name)
  556. return ", ".join(docs_info)
  557. log.info(
  558. f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
  559. )
  560. # Check if entries with the same hash (metadata.hash) already exist
  561. if metadata and "hash" in metadata:
  562. result = VECTOR_DB_CLIENT.query(
  563. collection_name=collection_name,
  564. filter={"hash": metadata["hash"]},
  565. )
  566. if result is not None:
  567. existing_doc_ids = result.ids[0]
  568. if existing_doc_ids:
  569. log.info(f"Document with hash {metadata['hash']} already exists")
  570. raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)
  571. if split:
  572. if request.app.state.config.TEXT_SPLITTER in ["", "character"]:
  573. text_splitter = RecursiveCharacterTextSplitter(
  574. chunk_size=request.app.state.config.CHUNK_SIZE,
  575. chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
  576. add_start_index=True,
  577. )
  578. elif request.app.state.config.TEXT_SPLITTER == "token":
  579. log.info(
  580. f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}"
  581. )
  582. tiktoken.get_encoding(str(request.app.state.config.TIKTOKEN_ENCODING_NAME))
  583. text_splitter = TokenTextSplitter(
  584. encoding_name=str(request.app.state.config.TIKTOKEN_ENCODING_NAME),
  585. chunk_size=request.app.state.config.CHUNK_SIZE,
  586. chunk_overlap=request.app.state.config.CHUNK_OVERLAP,
  587. add_start_index=True,
  588. )
  589. else:
  590. raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))
  591. docs = text_splitter.split_documents(docs)
  592. if len(docs) == 0:
  593. raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
  594. texts = [doc.page_content for doc in docs]
  595. metadatas = [
  596. {
  597. **doc.metadata,
  598. **(metadata if metadata else {}),
  599. "embedding_config": json.dumps(
  600. {
  601. "engine": request.app.state.config.RAG_EMBEDDING_ENGINE,
  602. "model": request.app.state.config.RAG_EMBEDDING_MODEL,
  603. }
  604. ),
  605. }
  606. for doc in docs
  607. ]
  608. # ChromaDB does not like datetime formats
  609. # for meta-data so convert them to string.
  610. for metadata in metadatas:
  611. for key, value in metadata.items():
  612. if isinstance(value, datetime):
  613. metadata[key] = str(value)
  614. try:
  615. if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
  616. log.info(f"collection {collection_name} already exists")
  617. if overwrite:
  618. VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
  619. log.info(f"deleting existing collection {collection_name}")
  620. elif add is False:
  621. log.info(
  622. f"collection {collection_name} already exists, overwrite is False and add is False"
  623. )
  624. return True
  625. log.info(f"adding to collection {collection_name}")
  626. embedding_function = get_embedding_function(
  627. request.app.state.config.RAG_EMBEDDING_ENGINE,
  628. request.app.state.config.RAG_EMBEDDING_MODEL,
  629. request.app.state.sentence_transformer_ef,
  630. (
  631. request.app.state.config.OPENAI_API_BASE_URL
  632. if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
  633. else request.app.state.config.OLLAMA_BASE_URL
  634. ),
  635. (
  636. request.app.state.config.OPENAI_API_KEY
  637. if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai"
  638. else request.app.state.config.OLLAMA_API_KEY
  639. ),
  640. request.app.state.config.RAG_EMBEDDING_BATCH_SIZE,
  641. )
  642. embeddings = embedding_function(
  643. list(map(lambda x: x.replace("\n", " "), texts))
  644. )
  645. items = [
  646. {
  647. "id": str(uuid.uuid4()),
  648. "text": text,
  649. "vector": embeddings[idx],
  650. "metadata": metadatas[idx],
  651. }
  652. for idx, text in enumerate(texts)
  653. ]
  654. VECTOR_DB_CLIENT.insert(
  655. collection_name=collection_name,
  656. items=items,
  657. )
  658. return True
  659. except Exception as e:
  660. log.exception(e)
  661. raise e
class ProcessFileForm(BaseModel):
    """Request body for POST /process/file."""

    file_id: str
    # Optional replacement text; when set, the stored file content is
    # re-indexed from this value instead of being re-extracted.
    content: Optional[str] = None
    # Target collection; when omitted the handler defaults to "file-{file_id}".
    collection_name: Optional[str] = None
@router.post("/process/file")
def process_file(
    request: Request,
    form_data: ProcessFileForm,
    user=Depends(get_verified_user),
):
    """Extract (or accept) a file's text content and index it into a vector collection.

    Three mutually exclusive paths, selected by the form fields:
      1. ``content`` set — replace the file's indexed content with the given text.
      2. ``collection_name`` set — add an already-processed file to a knowledge
         collection, reusing previously stored chunks when available.
      3. neither — first-time processing: extract text via the configured loader.

    In every path the file record's content and hash are updated, then the
    documents are saved to the vector DB. Raises HTTP 400 on any failure.
    """
    try:
        file = Files.get_file_by_id(form_data.file_id)
        collection_name = form_data.collection_name
        if collection_name is None:
            # Per-file default collection.
            collection_name = f"file-{file.id}"
        if form_data.content:
            # Update the content in the file
            # Usage: /files/{file_id}/data/content/update
            # Drop the old per-file vectors before re-indexing the new text.
            VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
            docs = [
                Document(
                    # "<br/>" arrives from rich-text editors; normalize to newlines.
                    page_content=form_data.content.replace("<br/>", "\n"),
                    metadata={
                        **file.meta,
                        "name": file.filename,
                        "created_by": file.user_id,
                        "file_id": file.id,
                        "source": file.filename,
                    },
                )
            ]
            text_content = form_data.content
        elif form_data.collection_name:
            # Check if the file has already been processed and save the content
            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
            result = VECTOR_DB_CLIENT.query(
                collection_name=f"file-{file.id}", filter={"file_id": file.id}
            )
            if result is not None and len(result.ids[0]) > 0:
                # Reuse the already-chunked documents from the per-file collection.
                docs = [
                    Document(
                        page_content=result.documents[0][idx],
                        metadata=result.metadatas[0][idx],
                    )
                    for idx, id in enumerate(result.ids[0])
                ]
            else:
                # Fall back to the raw stored content.
                docs = [
                    Document(
                        page_content=file.data.get("content", ""),
                        metadata={
                            **file.meta,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                ]
            text_content = file.data.get("content", "")
        else:
            # Process the file and save the content
            # Usage: /files/
            file_path = file.path
            if file_path:
                file_path = Storage.get_file(file_path)
                loader = Loader(
                    engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
                    TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                    PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                )
                docs = loader.load(
                    file.filename, file.meta.get("content_type"), file_path
                )
                # Re-wrap loader output so every chunk carries file identity metadata.
                docs = [
                    Document(
                        page_content=doc.page_content,
                        metadata={
                            **doc.metadata,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                    for doc in docs
                ]
            else:
                # No backing file on disk (e.g. content-only uploads).
                docs = [
                    Document(
                        page_content=file.data.get("content", ""),
                        metadata={
                            **file.meta,
                            "name": file.filename,
                            "created_by": file.user_id,
                            "file_id": file.id,
                            "source": file.filename,
                        },
                    )
                ]
            text_content = " ".join([doc.page_content for doc in docs])
        log.debug(f"text_content: {text_content}")
        # Persist extracted content and its hash on the file record.
        Files.update_file_data_by_id(
            file.id,
            {"content": text_content},
        )
        hash = calculate_sha256_string(text_content)
        Files.update_file_hash_by_id(file.id, hash)
        try:
            result = save_docs_to_vector_db(
                request,
                docs=docs,
                collection_name=collection_name,
                metadata={
                    "file_id": file.id,
                    "name": file.filename,
                    "hash": hash,
                },
                # Append when targeting an explicit knowledge collection;
                # otherwise the per-file collection is (re)built.
                add=(True if form_data.collection_name else False),
            )
            if result:
                Files.update_file_metadata_by_id(
                    file.id,
                    {
                        "collection_name": collection_name,
                    },
                )
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "filename": file.filename,
                    "content": text_content,
                }
        except Exception as e:
            raise e
    except Exception as e:
        log.exception(e)
        # Surface a friendlier message for the common missing-pandoc case.
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )
class ProcessTextForm(BaseModel):
    """Request body for POST /process/text."""

    name: str  # display name stored in chunk metadata
    content: str  # raw text to embed
    # Target collection; defaults to the SHA-256 of the content when omitted.
    collection_name: Optional[str] = None
  813. @router.post("/process/text")
  814. def process_text(
  815. request: Request,
  816. form_data: ProcessTextForm,
  817. user=Depends(get_verified_user),
  818. ):
  819. collection_name = form_data.collection_name
  820. if collection_name is None:
  821. collection_name = calculate_sha256_string(form_data.content)
  822. docs = [
  823. Document(
  824. page_content=form_data.content,
  825. metadata={"name": form_data.name, "created_by": user.id},
  826. )
  827. ]
  828. text_content = form_data.content
  829. log.debug(f"text_content: {text_content}")
  830. result = save_docs_to_vector_db(request, docs, collection_name)
  831. if result:
  832. return {
  833. "status": True,
  834. "collection_name": collection_name,
  835. "content": text_content,
  836. }
  837. else:
  838. raise HTTPException(
  839. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  840. detail=ERROR_MESSAGES.DEFAULT(),
  841. )
  842. @router.post("/process/youtube")
  843. def process_youtube_video(
  844. request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
  845. ):
  846. try:
  847. collection_name = form_data.collection_name
  848. if not collection_name:
  849. collection_name = calculate_sha256_string(form_data.url)[:63]
  850. loader = YoutubeLoader(
  851. form_data.url,
  852. language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
  853. proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
  854. )
  855. docs = loader.load()
  856. content = " ".join([doc.page_content for doc in docs])
  857. log.debug(f"text_content: {content}")
  858. save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
  859. return {
  860. "status": True,
  861. "collection_name": collection_name,
  862. "filename": form_data.url,
  863. "file": {
  864. "data": {
  865. "content": content,
  866. },
  867. "meta": {
  868. "name": form_data.url,
  869. },
  870. },
  871. }
  872. except Exception as e:
  873. log.exception(e)
  874. raise HTTPException(
  875. status_code=status.HTTP_400_BAD_REQUEST,
  876. detail=ERROR_MESSAGES.DEFAULT(e),
  877. )
  878. @router.post("/process/web")
  879. def process_web(
  880. request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
  881. ):
  882. try:
  883. collection_name = form_data.collection_name
  884. if not collection_name:
  885. collection_name = calculate_sha256_string(form_data.url)[:63]
  886. loader = get_web_loader(
  887. form_data.url,
  888. verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  889. requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  890. )
  891. docs = loader.load()
  892. content = " ".join([doc.page_content for doc in docs])
  893. log.debug(f"text_content: {content}")
  894. save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
  895. return {
  896. "status": True,
  897. "collection_name": collection_name,
  898. "filename": form_data.url,
  899. "file": {
  900. "data": {
  901. "content": content,
  902. },
  903. "meta": {
  904. "name": form_data.url,
  905. },
  906. },
  907. }
  908. except Exception as e:
  909. log.exception(e)
  910. raise HTTPException(
  911. status_code=status.HTTP_400_BAD_REQUEST,
  912. detail=ERROR_MESSAGES.DEFAULT(e),
  913. )
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
    """Search the web using the selected engine and return a list of SearchResult objects.

    Dispatches on ``engine``. Each branch reads its credentials and endpoint
    from ``request.app.state.config`` and raises a plain ``Exception`` when the
    required key is missing. Supported engines and the config keys they need:
    - searxng: SEARXNG_QUERY_URL
    - google_pse: GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - brave: BRAVE_SEARCH_API_KEY
    - kagi: KAGI_SEARCH_API_KEY
    - mojeek: MOJEEK_SEARCH_API_KEY
    - serpstack: SERPSTACK_API_KEY
    - serper: SERPER_API_KEY
    - serply: SERPLY_API_KEY
    - duckduckgo: (no key required)
    - tavily: TAVILY_API_KEY
    - searchapi: SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
    - jina, bing: see branches below

    Result count and the optional domain filter list also come from app config.

    Args:
        request (Request): current request; carries the app-level RAG config
        engine (str): engine identifier (see list above)
        query (str): The query to search for

    Raises:
        Exception: unknown engine, or the engine's credentials are not configured.
    """
    # TODO: add playwright to search the web
    if engine == "searxng":
        if request.app.state.config.SEARXNG_QUERY_URL:
            return search_searxng(
                request.app.state.config.SEARXNG_QUERY_URL,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
    elif engine == "google_pse":
        if (
            request.app.state.config.GOOGLE_PSE_API_KEY
            and request.app.state.config.GOOGLE_PSE_ENGINE_ID
        ):
            return search_google_pse(
                request.app.state.config.GOOGLE_PSE_API_KEY,
                request.app.state.config.GOOGLE_PSE_ENGINE_ID,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
    elif engine == "brave":
        if request.app.state.config.BRAVE_SEARCH_API_KEY:
            return search_brave(
                request.app.state.config.BRAVE_SEARCH_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
    elif engine == "kagi":
        if request.app.state.config.KAGI_SEARCH_API_KEY:
            return search_kagi(
                request.app.state.config.KAGI_SEARCH_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No KAGI_SEARCH_API_KEY found in environment variables")
    elif engine == "mojeek":
        if request.app.state.config.MOJEEK_SEARCH_API_KEY:
            return search_mojeek(
                request.app.state.config.MOJEEK_SEARCH_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
    elif engine == "serpstack":
        if request.app.state.config.SERPSTACK_API_KEY:
            return search_serpstack(
                request.app.state.config.SERPSTACK_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
                https_enabled=request.app.state.config.SERPSTACK_HTTPS,
            )
        else:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
    elif engine == "serper":
        if request.app.state.config.SERPER_API_KEY:
            return search_serper(
                request.app.state.config.SERPER_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SERPER_API_KEY found in environment variables")
    elif engine == "serply":
        if request.app.state.config.SERPLY_API_KEY:
            return search_serply(
                request.app.state.config.SERPLY_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SERPLY_API_KEY found in environment variables")
    elif engine == "duckduckgo":
        # DuckDuckGo requires no API key.
        return search_duckduckgo(
            query,
            request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    elif engine == "tavily":
        if request.app.state.config.TAVILY_API_KEY:
            # NOTE: tavily takes no domain filter list, unlike the other engines.
            return search_tavily(
                request.app.state.config.TAVILY_API_KEY,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        else:
            raise Exception("No TAVILY_API_KEY found in environment variables")
    elif engine == "searchapi":
        if request.app.state.config.SEARCHAPI_API_KEY:
            return search_searchapi(
                request.app.state.config.SEARCHAPI_API_KEY,
                request.app.state.config.SEARCHAPI_ENGINE,
                query,
                request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            )
        else:
            raise Exception("No SEARCHAPI_API_KEY found in environment variables")
    elif engine == "jina":
        # NOTE(review): no presence check on JINA_API_KEY here, unlike most
        # branches — confirm search_jina tolerates an empty key.
        return search_jina(
            request.app.state.config.JINA_API_KEY,
            query,
            request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
        )
    elif engine == "bing":
        # NOTE(review): bing branch does not validate its subscription key
        # before calling — confirm this is intended.
        return search_bing(
            request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
            request.app.state.config.BING_SEARCH_V7_ENDPOINT,
            str(DEFAULT_LOCALE),
            query,
            request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
        )
    else:
        raise Exception("No search engine API key found in environment variables")
  1061. @router.post("/process/web/search")
  1062. def process_web_search(
  1063. request: Request, form_data: SearchForm, user=Depends(get_verified_user)
  1064. ):
  1065. try:
  1066. logging.info(
  1067. f"trying to web search with {request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
  1068. )
  1069. web_results = search_web(
  1070. request, request.app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
  1071. )
  1072. except Exception as e:
  1073. log.exception(e)
  1074. raise HTTPException(
  1075. status_code=status.HTTP_400_BAD_REQUEST,
  1076. detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
  1077. )
  1078. try:
  1079. collection_name = form_data.collection_name
  1080. if collection_name == "":
  1081. collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
  1082. :63
  1083. ]
  1084. urls = [result.link for result in web_results]
  1085. loader = get_web_loader(
  1086. urls=urls,
  1087. verify_ssl=request.app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
  1088. requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
  1089. )
  1090. docs = loader.aload()
  1091. save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
  1092. return {
  1093. "status": True,
  1094. "collection_name": collection_name,
  1095. "filenames": urls,
  1096. }
  1097. except Exception as e:
  1098. log.exception(e)
  1099. raise HTTPException(
  1100. status_code=status.HTTP_400_BAD_REQUEST,
  1101. detail=ERROR_MESSAGES.DEFAULT(e),
  1102. )
class QueryDocForm(BaseModel):
    """Request body for POST /query/doc."""

    collection_name: str
    query: str
    k: Optional[int] = None  # top-k override; falls back to config TOP_K
    r: Optional[float] = None  # relevance threshold override (hybrid search)
    hybrid: Optional[bool] = None  # accepted but unused by the handler below
  1109. @router.post("/query/doc")
  1110. def query_doc_handler(
  1111. request: Request,
  1112. form_data: QueryDocForm,
  1113. user=Depends(get_verified_user),
  1114. ):
  1115. try:
  1116. if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
  1117. return query_doc_with_hybrid_search(
  1118. collection_name=form_data.collection_name,
  1119. query=form_data.query,
  1120. embedding_function=request.app.state.EMBEDDING_FUNCTION,
  1121. k=form_data.k if form_data.k else request.app.state.config.TOP_K,
  1122. reranking_function=request.app.state.sentence_transformer_rf,
  1123. r=(
  1124. form_data.r
  1125. if form_data.r
  1126. else request.app.state.config.RELEVANCE_THRESHOLD
  1127. ),
  1128. )
  1129. else:
  1130. return query_doc(
  1131. collection_name=form_data.collection_name,
  1132. query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
  1133. k=form_data.k if form_data.k else request.app.state.config.TOP_K,
  1134. )
  1135. except Exception as e:
  1136. log.exception(e)
  1137. raise HTTPException(
  1138. status_code=status.HTTP_400_BAD_REQUEST,
  1139. detail=ERROR_MESSAGES.DEFAULT(e),
  1140. )
class QueryCollectionsForm(BaseModel):
    """Request body for POST /query/collection."""

    collection_names: list[str]
    query: str
    k: Optional[int] = None  # top-k override; falls back to config TOP_K
    r: Optional[float] = None  # relevance threshold override (hybrid search)
    hybrid: Optional[bool] = None  # accepted but unused by the handler below
  1147. @router.post("/query/collection")
  1148. def query_collection_handler(
  1149. request: Request,
  1150. form_data: QueryCollectionsForm,
  1151. user=Depends(get_verified_user),
  1152. ):
  1153. try:
  1154. if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
  1155. return query_collection_with_hybrid_search(
  1156. collection_names=form_data.collection_names,
  1157. queries=[form_data.query],
  1158. embedding_function=request.app.state.EMBEDDING_FUNCTION,
  1159. k=form_data.k if form_data.k else request.app.state.config.TOP_K,
  1160. reranking_function=request.app.state.sentence_transformer_rf,
  1161. r=(
  1162. form_data.r
  1163. if form_data.r
  1164. else request.app.state.config.RELEVANCE_THRESHOLD
  1165. ),
  1166. )
  1167. else:
  1168. return query_collection(
  1169. collection_names=form_data.collection_names,
  1170. queries=[form_data.query],
  1171. embedding_function=request.app.state.EMBEDDING_FUNCTION,
  1172. k=form_data.k if form_data.k else request.app.state.config.TOP_K,
  1173. )
  1174. except Exception as e:
  1175. log.exception(e)
  1176. raise HTTPException(
  1177. status_code=status.HTTP_400_BAD_REQUEST,
  1178. detail=ERROR_MESSAGES.DEFAULT(e),
  1179. )
  1180. ####################################
  1181. #
  1182. # Vector DB operations
  1183. #
  1184. ####################################
class DeleteForm(BaseModel):
    """Request body for POST /delete."""

    collection_name: str  # collection to delete entries from
    file_id: str  # file whose hash selects the entries to delete
  1188. @router.post("/delete")
  1189. def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
  1190. try:
  1191. if VECTOR_DB_CLIENT.has_collection(collection_name=form_data.collection_name):
  1192. file = Files.get_file_by_id(form_data.file_id)
  1193. hash = file.hash
  1194. VECTOR_DB_CLIENT.delete(
  1195. collection_name=form_data.collection_name,
  1196. metadata={"hash": hash},
  1197. )
  1198. return {"status": True}
  1199. else:
  1200. return {"status": False}
  1201. except Exception as e:
  1202. log.exception(e)
  1203. return {"status": False}
@router.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    """Drop all vector collections and delete every knowledge record (admin only)."""
    VECTOR_DB_CLIENT.reset()
    Knowledges.delete_all_knowledge()
  1208. @router.post("/reset/uploads")
  1209. def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
  1210. folder = f"{UPLOAD_DIR}"
  1211. try:
  1212. # Check if the directory exists
  1213. if os.path.exists(folder):
  1214. # Iterate over all the files and directories in the specified directory
  1215. for filename in os.listdir(folder):
  1216. file_path = os.path.join(folder, filename)
  1217. try:
  1218. if os.path.isfile(file_path) or os.path.islink(file_path):
  1219. os.unlink(file_path) # Remove the file or link
  1220. elif os.path.isdir(file_path):
  1221. shutil.rmtree(file_path) # Remove the directory
  1222. except Exception as e:
  1223. print(f"Failed to delete {file_path}. Reason: {e}")
  1224. else:
  1225. print(f"The directory {folder} does not exist")
  1226. except Exception as e:
  1227. print(f"Failed to process the directory {folder}. Reason: {e}")
  1228. return True
# Debug endpoint for inspecting raw embeddings, registered in dev builds only.
if ENV == "dev":

    @router.get("/ef/{text}")
    async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
        """Return the embedding vector for `text` via the active embedding function."""
        return {"result": request.app.state.EMBEDDING_FUNCTION(text)}