12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739 |
- import json
- import logging
- import mimetypes
- import os
- import shutil
- import uuid
- from datetime import datetime
- from pathlib import Path
- from typing import Iterator, List, Optional, Sequence, Union
- from fastapi import (
- Depends,
- FastAPI,
- File,
- Form,
- HTTPException,
- UploadFile,
- Request,
- status,
- APIRouter,
- )
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.concurrency import run_in_threadpool
- from pydantic import BaseModel
- import tiktoken
- from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
- from langchain_core.documents import Document
- from open_webui.models.files import FileModel, Files
- from open_webui.models.knowledge import Knowledges
- from open_webui.storage.provider import Storage
- from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
- # Document loaders
- from open_webui.retrieval.loaders.main import Loader
- from open_webui.retrieval.loaders.youtube import YoutubeLoader
- # Web search engines
- from open_webui.retrieval.web.main import SearchResult
- from open_webui.retrieval.web.utils import get_web_loader
- from open_webui.retrieval.web.brave import search_brave
- from open_webui.retrieval.web.kagi import search_kagi
- from open_webui.retrieval.web.mojeek import search_mojeek
- from open_webui.retrieval.web.bocha import search_bocha
- from open_webui.retrieval.web.duckduckgo import search_duckduckgo
- from open_webui.retrieval.web.google_pse import search_google_pse
- from open_webui.retrieval.web.jina_search import search_jina
- from open_webui.retrieval.web.searchapi import search_searchapi
- from open_webui.retrieval.web.serpapi import search_serpapi
- from open_webui.retrieval.web.searxng import search_searxng
- from open_webui.retrieval.web.serper import search_serper
- from open_webui.retrieval.web.serply import search_serply
- from open_webui.retrieval.web.serpstack import search_serpstack
- from open_webui.retrieval.web.tavily import search_tavily
- from open_webui.retrieval.web.bing import search_bing
- from open_webui.retrieval.web.exa import search_exa
- from open_webui.retrieval.web.perplexity import search_perplexity
- from open_webui.retrieval.utils import (
- get_embedding_function,
- get_model_path,
- query_collection,
- query_collection_with_hybrid_search,
- query_doc,
- query_doc_with_hybrid_search,
- )
- from open_webui.utils.misc import (
- calculate_sha256_string,
- )
- from open_webui.utils.auth import get_admin_user, get_verified_user
- from open_webui.config import (
- ENV,
- RAG_EMBEDDING_MODEL_AUTO_UPDATE,
- RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
- RAG_RERANKING_MODEL_AUTO_UPDATE,
- RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
- UPLOAD_DIR,
- DEFAULT_LOCALE,
- )
- from open_webui.env import (
- SRC_LOG_LEVELS,
- DEVICE_TYPE,
- DOCKER,
- )
- from open_webui.constants import ERROR_MESSAGES
# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
- ##########################################
- #
- # Utility functions
- #
- ##########################################
def get_ef(
    engine: str,
    embedding_model: str,
    auto_update: bool = False,
):
    """Load a local SentenceTransformer embedding model, if one is requested.

    A model is only loaded when ``embedding_model`` is truthy and ``engine``
    is the empty string; any other combination yields ``None``.

    Args:
        engine: Embedding engine name; the empty string selects local loading.
        embedding_model: Model identifier passed to ``get_model_path``.
        auto_update: Forwarded unchanged to ``get_model_path``.

    Returns:
        The loaded ``SentenceTransformer``, or ``None`` when no local model
        is requested or loading raised an exception.
    """
    if not embedding_model or engine != "":
        return None

    # Imported lazily so the dependency is only touched when it is needed.
    from sentence_transformers import SentenceTransformer

    try:
        return SentenceTransformer(
            get_model_path(embedding_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    except Exception as e:
        # Best effort: a failed load is only logged and reported as None.
        log.debug(f"Error loading SentenceTransformer: {e}")
        return None
def get_rf(
    reranking_model: str,
    auto_update: bool = False,
):
    """Load the reranking model named by ``reranking_model``.

    Names containing ``jinaai/jina-colbert-v2`` are loaded through the
    project's ``ColBERT`` wrapper; every other name is loaded as a
    ``sentence_transformers.CrossEncoder``.

    Args:
        reranking_model: Model identifier; a falsy value means no reranker.
        auto_update: Forwarded unchanged to ``get_model_path``.

    Returns:
        The loaded reranker, or ``None`` when ``reranking_model`` is falsy.

    Raises:
        Exception: When the model cannot be loaded. The message is built with
            ``ERROR_MESSAGES.DEFAULT`` and the underlying error is chained as
            ``__cause__``.
    """
    if not reranking_model:
        return None

    if any(model in reranking_model for model in ["jinaai/jina-colbert-v2"]):
        try:
            from open_webui.retrieval.models.colbert import ColBERT

            return ColBERT(
                get_model_path(reranking_model, auto_update),
                env="docker" if DOCKER else None,
            )
        except Exception as e:
            log.error(f"ColBERT: {e}")
            raise Exception(ERROR_MESSAGES.DEFAULT(e)) from e

    import sentence_transformers

    try:
        return sentence_transformers.CrossEncoder(
            get_model_path(reranking_model, auto_update),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate, and log the real cause instead of
        # discarding it. The raised message is unchanged for existing callers.
        log.error(f"CrossEncoder error: {e}")
        raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) from e
- ##########################################
- #
- # API routes
- #
- ##########################################
# Router on which the endpoints declared below are registered.
router = APIRouter()
# Request body carrying an optional collection name; base for the forms below.
class CollectionNameForm(BaseModel):
    collection_name: Optional[str] = None
# CollectionNameForm extended with a required URL string.
class ProcessUrlForm(CollectionNameForm):
    url: str
# CollectionNameForm extended with a required query string.
class SearchForm(CollectionNameForm):
    query: str
@router.get("/")
async def get_status(request: Request):
    # Report a constant "status": True flag together with the chunking,
    # template, embedding and reranking settings held in the app config.
    config = request.app.state.config
    return {
        "status": True,
        "chunk_size": config.CHUNK_SIZE,
        "chunk_overlap": config.CHUNK_OVERLAP,
        "template": config.RAG_TEMPLATE,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "reranking_model": config.RAG_RERANKING_MODEL,
        "embedding_batch_size": config.RAG_EMBEDDING_BATCH_SIZE,
    }
@router.get("/embedding")
async def get_embedding_config(request: Request, user=Depends(get_admin_user)):
    # Return the embedding engine, model and batch size plus the stored
    # OpenAI and Ollama connection settings (URL and key for each).
    # Access is gated by the get_admin_user dependency.
    config = request.app.state.config
    return {
        "status": True,
        "embedding_engine": config.RAG_EMBEDDING_ENGINE,
        "embedding_model": config.RAG_EMBEDDING_MODEL,
        "embedding_batch_size": config.RAG_EMBEDDING_BATCH_SIZE,
        "openai_config": {
            "url": config.RAG_OPENAI_API_BASE_URL,
            "key": config.RAG_OPENAI_API_KEY,
        },
        "ollama_config": {
            "url": config.RAG_OLLAMA_BASE_URL,
            "key": config.RAG_OLLAMA_API_KEY,
        },
    }
@router.get("/reranking")
async def get_reraanking_config(request: Request, user=Depends(get_admin_user)):
    # Return the configured reranking model name.
    # NOTE(review): the function name is misspelled ("reraanking"); it is kept
    # as-is here because other code may refer to it by name.
    config = request.app.state.config
    return {
        "status": True,
        "reranking_model": config.RAG_RERANKING_MODEL,
    }
# URL and key pair submitted as EmbeddingModelUpdateForm.openai_config.
class OpenAIConfigForm(BaseModel):
    url: str
    key: str
# URL and key pair submitted as EmbeddingModelUpdateForm.ollama_config.
class OllamaConfigForm(BaseModel):
    url: str
    key: str
# Request body for POST /embedding/update. Engine and model are required;
# the two connection configs are optional and the batch size defaults to 1.
class EmbeddingModelUpdateForm(BaseModel):
    openai_config: Optional[OpenAIConfigForm] = None
    ollama_config: Optional[OllamaConfigForm] = None
    embedding_engine: str
    embedding_model: str
    embedding_batch_size: Optional[int] = 1
@router.post("/embedding/update")
async def update_embedding_config(
    request: Request, form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
    # Store the submitted embedding engine/model (and, for the "ollama" and
    # "openai" engines, the connection settings and batch size), rebuild the
    # embedding model and embedding function on the app state, then return the
    # resulting configuration. Any failure is logged and reported as HTTP 500.
    state = request.app.state
    config = state.config

    log.info(
        f"Updating embedding model: {config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
    )

    try:
        config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
        config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
            openai_config = form_data.openai_config
            if openai_config is not None:
                config.RAG_OPENAI_API_BASE_URL = openai_config.url
                config.RAG_OPENAI_API_KEY = openai_config.key

            ollama_config = form_data.ollama_config
            if ollama_config is not None:
                config.RAG_OLLAMA_BASE_URL = ollama_config.url
                config.RAG_OLLAMA_API_KEY = ollama_config.key

            # NOTE(review): the pasted source lost its indentation; the batch
            # size is assumed to be updated only for these remote engines.
            # Confirm against the original file.
            config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size

        state.ef = get_ef(
            config.RAG_EMBEDDING_ENGINE,
            config.RAG_EMBEDDING_MODEL,
        )

        # Every engine other than "openai" is handed the Ollama URL and key.
        uses_openai = config.RAG_EMBEDDING_ENGINE == "openai"
        state.EMBEDDING_FUNCTION = get_embedding_function(
            config.RAG_EMBEDDING_ENGINE,
            config.RAG_EMBEDDING_MODEL,
            state.ef,
            (
                config.RAG_OPENAI_API_BASE_URL
                if uses_openai
                else config.RAG_OLLAMA_BASE_URL
            ),
            (
                config.RAG_OPENAI_API_KEY
                if uses_openai
                else config.RAG_OLLAMA_API_KEY
            ),
            config.RAG_EMBEDDING_BATCH_SIZE,
        )

        return {
            "status": True,
            "embedding_engine": config.RAG_EMBEDDING_ENGINE,
            "embedding_model": config.RAG_EMBEDDING_MODEL,
            "embedding_batch_size": config.RAG_EMBEDDING_BATCH_SIZE,
            "openai_config": {
                "url": config.RAG_OPENAI_API_BASE_URL,
                "key": config.RAG_OPENAI_API_KEY,
            },
            "ollama_config": {
                "url": config.RAG_OLLAMA_BASE_URL,
                "key": config.RAG_OLLAMA_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
# Request body for POST /reranking/update: the reranking model name.
class RerankingModelUpdateForm(BaseModel):
    reranking_model: str
@router.post("/reranking/update")
async def update_reranking_config(
    request: Request, form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
    # Store the submitted reranking model name and try to load it with
    # auto-update enabled. A load failure is logged and hybrid search is
    # switched off, but the endpoint still answers with "status": True.
    # Any other failure is logged and reported as HTTP 500.
    state = request.app.state
    config = state.config

    log.info(
        f"Updating reranking model: {config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
    )

    try:
        config.RAG_RERANKING_MODEL = form_data.reranking_model

        try:
            state.rf = get_rf(
                config.RAG_RERANKING_MODEL,
                True,
            )
        except Exception as e:
            log.error(f"Error loading reranking model: {e}")
            config.ENABLE_RAG_HYBRID_SEARCH = False

        return {
            "status": True,
            "reranking_model": config.RAG_RERANKING_MODEL,
        }
    except Exception as e:
        log.exception(f"Problem updating reranking model: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
@router.get("/config")
async def get_rag_config(request: Request, user=Depends(get_admin_user)):
    # Return the current retrieval configuration: top-level flags, content
    # extraction, chunking, file limits, YouTube loader and web search
    # settings. Access is gated by the get_admin_user dependency.
    state = request.app.state
    config = state.config
    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "RAG_FULL_CONTEXT": config.RAG_FULL_CONTEXT,
        "BYPASS_EMBEDDING_AND_RETRIEVAL": config.BYPASS_EMBEDDING_AND_RETRIEVAL,
        "enable_google_drive_integration": config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
        "enable_onedrive_integration": config.ENABLE_ONEDRIVE_INTEGRATION,
        "content_extraction": {
            "engine": config.CONTENT_EXTRACTION_ENGINE,
            "tika_server_url": config.TIKA_SERVER_URL,
            "document_intelligence_config": {
                "endpoint": config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                "key": config.DOCUMENT_INTELLIGENCE_KEY,
            },
        },
        "chunk": {
            "text_splitter": config.TEXT_SPLITTER,
            "chunk_size": config.CHUNK_SIZE,
            "chunk_overlap": config.CHUNK_OVERLAP,
        },
        "file": {
            "max_size": config.FILE_MAX_SIZE,
            "max_count": config.FILE_MAX_COUNT,
        },
        "youtube": {
            "language": config.YOUTUBE_LOADER_LANGUAGE,
            # Read from the app state itself, not from state.config.
            "translation": state.YOUTUBE_LOADER_TRANSLATION,
            "proxy_url": config.YOUTUBE_LOADER_PROXY_URL,
        },
        "web": {
            "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
            "search": {
                "enabled": config.ENABLE_RAG_WEB_SEARCH,
                "drive": config.ENABLE_GOOGLE_DRIVE_INTEGRATION,
                "onedrive": config.ENABLE_ONEDRIVE_INTEGRATION,
                "engine": config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": config.SEARXNG_QUERY_URL,
                "google_pse_api_key": config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": config.BRAVE_SEARCH_API_KEY,
                "kagi_search_api_key": config.KAGI_SEARCH_API_KEY,
                "mojeek_search_api_key": config.MOJEEK_SEARCH_API_KEY,
                "bocha_search_api_key": config.BOCHA_SEARCH_API_KEY,
                "serpstack_api_key": config.SERPSTACK_API_KEY,
                "serpstack_https": config.SERPSTACK_HTTPS,
                "serper_api_key": config.SERPER_API_KEY,
                "serply_api_key": config.SERPLY_API_KEY,
                "tavily_api_key": config.TAVILY_API_KEY,
                "searchapi_api_key": config.SEARCHAPI_API_KEY,
                "searchapi_engine": config.SEARCHAPI_ENGINE,
                "serpapi_api_key": config.SERPAPI_API_KEY,
                "serpapi_engine": config.SERPAPI_ENGINE,
                "jina_api_key": config.JINA_API_KEY,
                "bing_search_v7_endpoint": config.BING_SEARCH_V7_ENDPOINT,
                "bing_search_v7_subscription_key": config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
                "exa_api_key": config.EXA_API_KEY,
                "perplexity_api_key": config.PERPLEXITY_API_KEY,
                "result_count": config.RAG_WEB_SEARCH_RESULT_COUNT,
                "trust_env": config.RAG_WEB_SEARCH_TRUST_ENV,
                "concurrent_requests": config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
                "domain_filter_list": config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            },
        },
    }
# "file" section of ConfigUpdateForm; both limits are optional.
class FileConfig(BaseModel):
    max_size: Optional[int] = None
    max_count: Optional[int] = None
# Endpoint and key pair nested inside ContentExtractionConfig.
class DocumentIntelligenceConfigForm(BaseModel):
    endpoint: str
    key: str
# "content_extraction" section of ConfigUpdateForm. The engine defaults to
# the empty string; the Tika URL and document-intelligence config are optional.
class ContentExtractionConfig(BaseModel):
    engine: str = ""
    tika_server_url: Optional[str] = None
    document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
# "chunk" section of ConfigUpdateForm. Size and overlap are required; the
# text splitter name is optional.
class ChunkParamUpdateForm(BaseModel):
    text_splitter: Optional[str] = None
    chunk_size: int
    chunk_overlap: int
# "youtube" section of ConfigUpdateForm. The language list is required; the
# translation is optional and the proxy URL defaults to the empty string.
class YoutubeLoaderConfig(BaseModel):
    language: list[str]
    translation: Optional[str] = None
    proxy_url: str = ""
# "search" section of WebConfig. Only `enabled` is required; every
# engine-specific setting is optional and defaults to None, except the
# domain filter list, which defaults to an empty list.
class WebSearchConfig(BaseModel):
    enabled: bool
    engine: Optional[str] = None
    searxng_query_url: Optional[str] = None
    google_pse_api_key: Optional[str] = None
    google_pse_engine_id: Optional[str] = None
    brave_search_api_key: Optional[str] = None
    kagi_search_api_key: Optional[str] = None
    mojeek_search_api_key: Optional[str] = None
    bocha_search_api_key: Optional[str] = None
    serpstack_api_key: Optional[str] = None
    serpstack_https: Optional[bool] = None
    serper_api_key: Optional[str] = None
    serply_api_key: Optional[str] = None
    tavily_api_key: Optional[str] = None
    searchapi_api_key: Optional[str] = None
    searchapi_engine: Optional[str] = None
    serpapi_api_key: Optional[str] = None
    serpapi_engine: Optional[str] = None
    jina_api_key: Optional[str] = None
    bing_search_v7_endpoint: Optional[str] = None
    bing_search_v7_subscription_key: Optional[str] = None
    exa_api_key: Optional[str] = None
    perplexity_api_key: Optional[str] = None
    result_count: Optional[int] = None
    concurrent_requests: Optional[int] = None
    trust_env: Optional[bool] = None
    domain_filter_list: Optional[List[str]] = []
# "web" section of ConfigUpdateForm: the required search settings plus two
# optional web-loader flags.
class WebConfig(BaseModel):
    search: WebSearchConfig
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None
    BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
# Request body for POST /config/update. Every field is optional, so a caller
# may submit any subset of the top-level flags and nested sections.
class ConfigUpdateForm(BaseModel):
    RAG_FULL_CONTEXT: Optional[bool] = None
    BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None
    pdf_extract_images: Optional[bool] = None
    enable_google_drive_integration: Optional[bool] = None
    enable_onedrive_integration: Optional[bool] = None
    file: Optional[FileConfig] = None
    content_extraction: Optional[ContentExtractionConfig] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    youtube: Optional[YoutubeLoaderConfig] = None
    web: Optional[WebConfig] = None
@router.post("/config/update")
async def update_rag_config(
    request: Request, form_data: ConfigUpdateForm, user=Depends(get_admin_user)
):
    # Apply the submitted settings to the application config and return the
    # resulting values. Access is gated by the get_admin_user dependency.
    #
    # Top-level flags keep their current value when the form leaves them as
    # None. Each nested section (file, content_extraction, chunk, youtube,
    # web) is skipped when absent; when present, its fields are copied exactly
    # as submitted, so a field left unset inside a supplied section is stored
    # with that model's default (None for most of them).
    state = request.app.state
    config = state.config

    config.PDF_EXTRACT_IMAGES = (
        form_data.pdf_extract_images
        if form_data.pdf_extract_images is not None
        else config.PDF_EXTRACT_IMAGES
    )
    config.RAG_FULL_CONTEXT = (
        form_data.RAG_FULL_CONTEXT
        if form_data.RAG_FULL_CONTEXT is not None
        else config.RAG_FULL_CONTEXT
    )
    config.BYPASS_EMBEDDING_AND_RETRIEVAL = (
        form_data.BYPASS_EMBEDDING_AND_RETRIEVAL
        if form_data.BYPASS_EMBEDDING_AND_RETRIEVAL is not None
        else config.BYPASS_EMBEDDING_AND_RETRIEVAL
    )
    config.ENABLE_GOOGLE_DRIVE_INTEGRATION = (
        form_data.enable_google_drive_integration
        if form_data.enable_google_drive_integration is not None
        else config.ENABLE_GOOGLE_DRIVE_INTEGRATION
    )
    config.ENABLE_ONEDRIVE_INTEGRATION = (
        form_data.enable_onedrive_integration
        if form_data.enable_onedrive_integration is not None
        else config.ENABLE_ONEDRIVE_INTEGRATION
    )

    if form_data.file is not None:
        config.FILE_MAX_SIZE = form_data.file.max_size
        config.FILE_MAX_COUNT = form_data.file.max_count

    if form_data.content_extraction is not None:
        extraction = form_data.content_extraction
        log.info(
            f"Updating content extraction: {config.CONTENT_EXTRACTION_ENGINE} to {extraction.engine}"
        )
        config.CONTENT_EXTRACTION_ENGINE = extraction.engine
        config.TIKA_SERVER_URL = extraction.tika_server_url
        if extraction.document_intelligence_config is not None:
            config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
                extraction.document_intelligence_config.endpoint
            )
            config.DOCUMENT_INTELLIGENCE_KEY = (
                extraction.document_intelligence_config.key
            )

    if form_data.chunk is not None:
        config.TEXT_SPLITTER = form_data.chunk.text_splitter
        config.CHUNK_SIZE = form_data.chunk.chunk_size
        config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    if form_data.youtube is not None:
        config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        config.YOUTUBE_LOADER_PROXY_URL = form_data.youtube.proxy_url
        # The translation lives on the app state itself, not on state.config.
        state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    if form_data.web is not None:
        web = form_data.web
        search = web.search

        # Note: When UI "Bypass SSL verification for Websites"=True then
        # ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION=False
        config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            web.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
        )
        config.ENABLE_RAG_WEB_SEARCH = search.enabled
        config.RAG_WEB_SEARCH_ENGINE = search.engine
        config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = (
            web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
        )
        config.SEARXNG_QUERY_URL = search.searxng_query_url
        config.GOOGLE_PSE_API_KEY = search.google_pse_api_key
        config.GOOGLE_PSE_ENGINE_ID = search.google_pse_engine_id
        config.BRAVE_SEARCH_API_KEY = search.brave_search_api_key
        config.KAGI_SEARCH_API_KEY = search.kagi_search_api_key
        config.MOJEEK_SEARCH_API_KEY = search.mojeek_search_api_key
        config.BOCHA_SEARCH_API_KEY = search.bocha_search_api_key
        config.SERPSTACK_API_KEY = search.serpstack_api_key
        config.SERPSTACK_HTTPS = search.serpstack_https
        config.SERPER_API_KEY = search.serper_api_key
        config.SERPLY_API_KEY = search.serply_api_key
        config.TAVILY_API_KEY = search.tavily_api_key
        config.SEARCHAPI_API_KEY = search.searchapi_api_key
        config.SEARCHAPI_ENGINE = search.searchapi_engine
        config.SERPAPI_API_KEY = search.serpapi_api_key
        config.SERPAPI_ENGINE = search.serpapi_engine
        config.JINA_API_KEY = search.jina_api_key
        config.BING_SEARCH_V7_ENDPOINT = search.bing_search_v7_endpoint
        config.BING_SEARCH_V7_SUBSCRIPTION_KEY = (
            search.bing_search_v7_subscription_key
        )
        config.EXA_API_KEY = search.exa_api_key
        config.PERPLEXITY_API_KEY = search.perplexity_api_key
        config.RAG_WEB_SEARCH_RESULT_COUNT = search.result_count
        config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = search.concurrent_requests
        config.RAG_WEB_SEARCH_TRUST_ENV = search.trust_env
        config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = search.domain_filter_list

    return {
        "status": True,
        "pdf_extract_images": config.PDF_EXTRACT_IMAGES,
        "RAG_FULL_CONTEXT": config.RAG_FULL_CONTEXT,
        "BYPASS_EMBEDDING_AND_RETRIEVAL": config.BYPASS_EMBEDDING_AND_RETRIEVAL,
        "file": {
            "max_size": config.FILE_MAX_SIZE,
            "max_count": config.FILE_MAX_COUNT,
        },
        "content_extraction": {
            "engine": config.CONTENT_EXTRACTION_ENGINE,
            "tika_server_url": config.TIKA_SERVER_URL,
            "document_intelligence_config": {
                "endpoint": config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                "key": config.DOCUMENT_INTELLIGENCE_KEY,
            },
        },
        "chunk": {
            "text_splitter": config.TEXT_SPLITTER,
            "chunk_size": config.CHUNK_SIZE,
            "chunk_overlap": config.CHUNK_OVERLAP,
        },
        "youtube": {
            "language": config.YOUTUBE_LOADER_LANGUAGE,
            "proxy_url": config.YOUTUBE_LOADER_PROXY_URL,
            "translation": state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION": config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL,
            "search": {
                "enabled": config.ENABLE_RAG_WEB_SEARCH,
                "engine": config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": config.SEARXNG_QUERY_URL,
                "google_pse_api_key": config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": config.BRAVE_SEARCH_API_KEY,
                "kagi_search_api_key": config.KAGI_SEARCH_API_KEY,
                "mojeek_search_api_key": config.MOJEEK_SEARCH_API_KEY,
                "bocha_search_api_key": config.BOCHA_SEARCH_API_KEY,
                "serpstack_api_key": config.SERPSTACK_API_KEY,
                "serpstack_https": config.SERPSTACK_HTTPS,
                "serper_api_key": config.SERPER_API_KEY,
                "serply_api_key": config.SERPLY_API_KEY,
                # Correctly spelled key, matching the request field and the
                # GET /config response.
                "searchapi_api_key": config.SEARCHAPI_API_KEY,
                # Legacy misspelling kept so existing readers of this
                # response do not break.
                "serachapi_api_key": config.SEARCHAPI_API_KEY,
                "searchapi_engine": config.SEARCHAPI_ENGINE,
                "serpapi_api_key": config.SERPAPI_API_KEY,
                "serpapi_engine": config.SERPAPI_ENGINE,
                "tavily_api_key": config.TAVILY_API_KEY,
                "jina_api_key": config.JINA_API_KEY,
                "bing_search_v7_endpoint": config.BING_SEARCH_V7_ENDPOINT,
                "bing_search_v7_subscription_key": config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
                "exa_api_key": config.EXA_API_KEY,
                "perplexity_api_key": config.PERPLEXITY_API_KEY,
                "result_count": config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
                "trust_env": config.RAG_WEB_SEARCH_TRUST_ENV,
                "domain_filter_list": config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
            },
        },
    }
@router.get("/template")
async def get_rag_template(request: Request, user=Depends(get_verified_user)):
    # Return the configured RAG template. Unlike the other settings
    # endpoints here, this one depends on get_verified_user.
    config = request.app.state.config
    return {
        "status": True,
        "template": config.RAG_TEMPLATE,
    }
@router.get("/query/settings")
async def get_query_settings(request: Request, user=Depends(get_admin_user)):
    # Return the query settings: template, TOP_K ("k"), relevance threshold
    # ("r") and the hybrid-search flag.
    config = request.app.state.config
    return {
        "status": True,
        "template": config.RAG_TEMPLATE,
        "k": config.TOP_K,
        "r": config.RELEVANCE_THRESHOLD,
        "hybrid": config.ENABLE_RAG_HYBRID_SEARCH,
    }
# Request body for POST /query/settings/update; every field is optional.
class QuerySettingsForm(BaseModel):
    k: Optional[int] = None
    r: Optional[float] = None
    template: Optional[str] = None
    hybrid: Optional[bool] = None
@router.post("/query/settings/update")
async def update_query_settings(
    request: Request, form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    # Overwrite the query settings from the form and return the new values.
    # Falsy inputs fall back to fixed defaults: k -> 4, r -> 0.0,
    # hybrid -> False. The template is stored exactly as submitted, so an
    # omitted template is stored as None.
    config = request.app.state.config

    config.RAG_TEMPLATE = form_data.template
    config.TOP_K = form_data.k or 4
    config.RELEVANCE_THRESHOLD = form_data.r or 0.0
    config.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid or False

    return {
        "status": True,
        "template": config.RAG_TEMPLATE,
        "k": config.TOP_K,
        "r": config.RELEVANCE_THRESHOLD,
        "hybrid": config.ENABLE_RAG_HYBRID_SEARCH,
    }
- ####################################
- #
- # Document process and retrieval
- #
- ####################################
def save_docs_to_vector_db(
    request: Request,
    docs,
    collection_name,
    metadata: Optional[dict] = None,
    overwrite: bool = False,
    split: bool = True,
    add: bool = False,
    user=None,
) -> bool:
    """Embed ``docs`` and insert them into ``collection_name``.

    Args:
        request: Incoming request; settings are read from ``request.app.state``.
        docs: Documents exposing ``page_content`` and ``metadata``.
        collection_name: Target collection in the vector DB.
        metadata: Extra metadata merged into every chunk. When it contains a
            ``"hash"`` key, the collection is first queried for that hash.
        overwrite: Delete an already existing collection before inserting.
        split: Run the configured text splitter over ``docs`` first.
        add: Insert into an already existing collection instead of returning.
        user: Forwarded to the embedding function.

    Returns:
        ``True`` after inserting, and also when the collection already exists
        while both ``overwrite`` and ``add`` are off (nothing is inserted).

    Raises:
        ValueError: Duplicate hash found, unknown text splitter, or no
            documents left to store.
        Exception: Anything raised while talking to the vector DB or the
            embedding function is logged and re-raised.
    """
    config = request.app.state.config

    def _get_docs_info(docs: list[Document]) -> str:
        # Collect one human-readable label per document for the log line:
        # the first non-empty of name, title and source.
        labels = set()
        for doc in docs:
            doc_meta = getattr(doc, "metadata", {})
            label = (
                doc_meta.get("name", "")
                or doc_meta.get("title", "")
                or doc_meta.get("source", "")
            )
            if label:
                labels.add(label)
        return ", ".join(labels)

    log.info(
        f"save_docs_to_vector_db: document {_get_docs_info(docs)} {collection_name}"
    )

    # Refuse to store content whose hash is already present in the collection.
    if metadata and "hash" in metadata:
        existing = VECTOR_DB_CLIENT.query(
            collection_name=collection_name,
            filter={"hash": metadata["hash"]},
        )
        if existing is not None and existing.ids[0]:
            log.info(f"Document with hash {metadata['hash']} already exists")
            raise ValueError(ERROR_MESSAGES.DUPLICATE_CONTENT)

    if split:
        if config.TEXT_SPLITTER in ["", "character"]:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=config.CHUNK_SIZE,
                chunk_overlap=config.CHUNK_OVERLAP,
                add_start_index=True,
            )
        elif config.TEXT_SPLITTER == "token":
            log.info(f"Using token text splitter: {config.TIKTOKEN_ENCODING_NAME}")
            encoding_name = str(config.TIKTOKEN_ENCODING_NAME)
            # Called for its effect only; the returned encoding is unused.
            tiktoken.get_encoding(encoding_name)
            text_splitter = TokenTextSplitter(
                encoding_name=str(config.TIKTOKEN_ENCODING_NAME),
                chunk_size=config.CHUNK_SIZE,
                chunk_overlap=config.CHUNK_OVERLAP,
                add_start_index=True,
            )
        else:
            raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter"))

        docs = text_splitter.split_documents(docs)

    if len(docs) == 0:
        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)

    texts = [doc.page_content for doc in docs]

    # Every chunk records which embedding engine/model produced its vector.
    embedding_config = json.dumps(
        {
            "engine": config.RAG_EMBEDDING_ENGINE,
            "model": config.RAG_EMBEDDING_MODEL,
        }
    )
    extra_metadata = metadata if metadata else {}
    metadatas = [
        {**doc.metadata, **extra_metadata, "embedding_config": embedding_config}
        for doc in docs
    ]

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for chunk_metadata in metadatas:
        for key, value in chunk_metadata.items():
            if isinstance(value, (datetime, list, dict)):
                chunk_metadata[key] = str(value)

    try:
        if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name):
            log.info(f"collection {collection_name} already exists")

            if overwrite:
                VECTOR_DB_CLIENT.delete_collection(collection_name=collection_name)
                log.info(f"deleting existing collection {collection_name}")
            elif add is False:
                log.info(
                    f"collection {collection_name} already exists, overwrite is False and add is False"
                )
                return True

        log.info(f"adding to collection {collection_name}")

        # OpenAI and Ollama each have their own base URL / API key settings.
        uses_openai = config.RAG_EMBEDDING_ENGINE == "openai"
        base_url = (
            config.RAG_OPENAI_API_BASE_URL
            if uses_openai
            else config.RAG_OLLAMA_BASE_URL
        )
        api_key = (
            config.RAG_OPENAI_API_KEY if uses_openai else config.RAG_OLLAMA_API_KEY
        )
        embedding_function = get_embedding_function(
            config.RAG_EMBEDDING_ENGINE,
            config.RAG_EMBEDDING_MODEL,
            request.app.state.ef,
            base_url,
            api_key,
            config.RAG_EMBEDDING_BATCH_SIZE,
        )

        # Newlines are flattened to spaces for embedding only; the stored
        # text keeps its original form.
        flattened_texts = [text.replace("\n", " ") for text in texts]
        embeddings = embedding_function(flattened_texts, user=user)

        items = []
        for idx, text in enumerate(texts):
            items.append(
                {
                    "id": str(uuid.uuid4()),
                    "text": text,
                    "vector": embeddings[idx],
                    "metadata": metadatas[idx],
                }
            )

        VECTOR_DB_CLIENT.insert(
            collection_name=collection_name,
            items=items,
        )
        return True
    except Exception as e:
        log.exception(e)
        raise e
# Request body for POST /process/file.
class ProcessFileForm(BaseModel):
    # Id of the stored file to process.
    file_id: str
    # Replacement text content; when set it is used instead of the stored or
    # extracted content.
    content: Optional[str] = None
    # Target collection; the handler falls back to "file-<id>" when omitted.
    collection_name: Optional[str] = None
@router.post("/process/file")
def process_file(
    request: Request,
    form_data: ProcessFileForm,
    user=Depends(get_verified_user),
):
    # Build documents for a stored file, persist its text content and hash,
    # and (unless embedding is bypassed) save the documents to the vector DB.
    #
    # The documents come from one of three sources:
    #   1. form_data.content          -> the caller supplies new content
    #   2. form_data.collection_name  -> reuse what is already stored
    #   3. neither                    -> extract from the file on disk
    #
    # Any failure is logged and reported as HTTP 400.
    try:
        file = Files.get_file_by_id(form_data.file_id)

        def _file_metadata(base: dict) -> dict:
            # Metadata attached to every document built for this file.
            return {
                **base,
                "name": file.filename,
                "created_by": file.user_id,
                "file_id": file.id,
                "source": file.filename,
            }

        collection_name = form_data.collection_name
        if collection_name is None:
            collection_name = f"file-{file.id}"

        if form_data.content:
            # Update the content in the file
            # Usage: /files/{file_id}/data/content/update
            try:
                VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}")
            except Exception as e:
                # Best effort: deletion may fail on this path (the original
                # comment mentions the audio file upload pipeline), so keep
                # going. Exception, not a bare except, lets KeyboardInterrupt
                # and SystemExit propagate.
                log.debug(f"Could not delete collection file-{file.id}: {e}")

            docs = [
                Document(
                    page_content=form_data.content.replace("<br/>", "\n"),
                    metadata=_file_metadata(file.meta),
                )
            ]
            text_content = form_data.content
        elif form_data.collection_name:
            # Check if the file has already been processed and save the content
            # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update
            result = VECTOR_DB_CLIENT.query(
                collection_name=f"file-{file.id}", filter={"file_id": file.id}
            )

            if result is not None and len(result.ids[0]) > 0:
                # Rebuild the documents from the stored chunks.
                docs = [
                    Document(
                        page_content=result.documents[0][idx],
                        metadata=result.metadatas[0][idx],
                    )
                    for idx in range(len(result.ids[0]))
                ]
            else:
                docs = [
                    Document(
                        page_content=file.data.get("content", ""),
                        metadata=_file_metadata(file.meta),
                    )
                ]

            text_content = file.data.get("content", "")
        else:
            # Process the file and save the content
            # Usage: /files/
            file_path = file.path
            if file_path:
                file_path = Storage.get_file(file_path)
                loader = Loader(
                    engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
                    TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
                    PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
                    DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
                    DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
                )
                loaded_docs = loader.load(
                    file.filename, file.meta.get("content_type"), file_path
                )
                # Keep the loader's own metadata and add the file identity.
                docs = [
                    Document(
                        page_content=doc.page_content,
                        metadata=_file_metadata(doc.metadata),
                    )
                    for doc in loaded_docs
                ]
            else:
                docs = [
                    Document(
                        page_content=file.data.get("content", ""),
                        metadata=_file_metadata(file.meta),
                    )
                ]

            text_content = " ".join([doc.page_content for doc in docs])

        log.debug(f"text_content: {text_content}")
        Files.update_file_data_by_id(
            file.id,
            {"content": text_content},
        )

        content_hash = calculate_sha256_string(text_content)
        Files.update_file_hash_by_id(file.id, content_hash)

        if request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
            # Embedding is disabled: only the text content is returned.
            return {
                "status": True,
                "collection_name": None,
                "filename": file.filename,
                "content": text_content,
            }

        result = save_docs_to_vector_db(
            request,
            docs=docs,
            collection_name=collection_name,
            metadata={
                "file_id": file.id,
                "name": file.filename,
                "hash": content_hash,
            },
            # Only append to an existing collection when the caller named one.
            add=(True if form_data.collection_name else False),
            user=user,
        )

        if result:
            Files.update_file_metadata_by_id(
                file.id,
                {
                    "collection_name": collection_name,
                },
            )

            return {
                "status": True,
                "collection_name": collection_name,
                "filename": file.filename,
                "content": text_content,
            }
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(e),
            )
# Request body for POST /process/text.
class ProcessTextForm(BaseModel):
    # Stored as the "name" metadata of the document.
    name: str
    # Raw text to embed.
    content: str
    # Target collection; the handler derives one from the content when omitted.
    collection_name: Optional[str] = None
@router.post("/process/text")
def process_text(
    request: Request,
    form_data: ProcessTextForm,
    user=Depends(get_verified_user),
):
    # Wrap the submitted text in a single document and save it to the vector
    # DB. Without an explicit collection name, the SHA-256 string of the
    # content is used as the collection name.
    text_content = form_data.content

    collection_name = form_data.collection_name
    if collection_name is None:
        collection_name = calculate_sha256_string(form_data.content)

    document = Document(
        page_content=form_data.content,
        metadata={"name": form_data.name, "created_by": user.id},
    )
    docs = [document]
    log.debug(f"text_content: {text_content}")

    saved = save_docs_to_vector_db(request, docs, collection_name, user=user)
    if not saved:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return {
        "status": True,
        "collection_name": collection_name,
        "content": text_content,
    }
@router.post("/process/youtube")
def process_youtube_video(
    request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
):
    # Load documents for a YouTube URL via YoutubeLoader and store them,
    # replacing any existing collection of the same name. Failures are logged
    # and reported as HTTP 400.
    try:
        config = request.app.state.config

        # Default collection name: SHA-256 string of the URL, cut to 63 chars.
        collection_name = (
            form_data.collection_name
            or calculate_sha256_string(form_data.url)[:63]
        )

        youtube_loader = YoutubeLoader(
            form_data.url,
            language=config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=config.YOUTUBE_LOADER_PROXY_URL,
        )
        docs = youtube_loader.load()

        content = " ".join(doc.page_content for doc in docs)
        log.debug(f"text_content: {content}")

        save_docs_to_vector_db(
            request, docs, collection_name, overwrite=True, user=user
        )

        file_payload = {
            "data": {"content": content},
            "meta": {"name": form_data.url},
        }
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
            "file": file_payload,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
@router.post("/process/web")
def process_web(
    request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user)
):
    # Load a web page and, unless web-search embedding is bypassed, store it
    # in the vector DB (replacing any existing collection of the same name).
    # Failures are logged and reported as HTTP 400.
    try:
        config = request.app.state.config

        # Default collection name: SHA-256 string of the URL, cut to 63 chars.
        collection_name = (
            form_data.collection_name
            or calculate_sha256_string(form_data.url)[:63]
        )

        web_loader = get_web_loader(
            form_data.url,
            verify_ssl=config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
        )
        docs = web_loader.load()

        content = " ".join(doc.page_content for doc in docs)
        log.debug(f"text_content: {content}")

        if config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
            # Nothing was stored, so there is no collection to report.
            collection_name = None
        else:
            save_docs_to_vector_db(
                request, docs, collection_name, overwrite=True, user=user
            )

        file_payload = {
            "data": {"content": content},
            "meta": {"name": form_data.url, "source": form_data.url},
        }
        return {
            "status": True,
            "collection_name": collection_name,
            "filename": form_data.url,
            "file": file_payload,
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
    """Run ``query`` against the search backend named by ``engine``.

    The backend is selected only by the ``engine`` string. Credentials, result
    count and the domain filter list are read from ``request.app.state.config``.

    Recognised engines: ``searxng``, ``google_pse``, ``brave``, ``kagi``,
    ``mojeek``, ``bocha``, ``serpstack``, ``serper``, ``serply``,
    ``duckduckgo``, ``tavily``, ``searchapi``, ``serpapi``, ``jina``,
    ``bing``, ``exa`` and ``perplexity``.

    Args:
        request: Incoming request giving access to the app config.
        engine (str): Name of the search backend to use.
        query (str): The query to search for

    Returns:
        Whatever the selected ``search_*`` helper returns.

    Raises:
        Exception: When a required setting for the selected engine is empty,
            or when ``engine`` is not one of the recognised names.
    """
    # TODO: add playwright to search the web
    config = request.app.state.config
    result_count = config.RAG_WEB_SEARCH_RESULT_COUNT
    domain_filter = config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST

    if engine == "searxng":
        if not config.SEARXNG_QUERY_URL:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        return search_searxng(
            config.SEARXNG_QUERY_URL, query, result_count, domain_filter
        )

    if engine == "google_pse":
        if not (config.GOOGLE_PSE_API_KEY and config.GOOGLE_PSE_ENGINE_ID):
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
        return search_google_pse(
            config.GOOGLE_PSE_API_KEY,
            config.GOOGLE_PSE_ENGINE_ID,
            query,
            result_count,
            domain_filter,
        )

    if engine == "brave":
        if not config.BRAVE_SEARCH_API_KEY:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
        return search_brave(
            config.BRAVE_SEARCH_API_KEY, query, result_count, domain_filter
        )

    if engine == "kagi":
        if not config.KAGI_SEARCH_API_KEY:
            raise Exception("No KAGI_SEARCH_API_KEY found in environment variables")
        return search_kagi(
            config.KAGI_SEARCH_API_KEY, query, result_count, domain_filter
        )

    if engine == "mojeek":
        if not config.MOJEEK_SEARCH_API_KEY:
            raise Exception("No MOJEEK_SEARCH_API_KEY found in environment variables")
        return search_mojeek(
            config.MOJEEK_SEARCH_API_KEY, query, result_count, domain_filter
        )

    if engine == "bocha":
        if not config.BOCHA_SEARCH_API_KEY:
            raise Exception("No BOCHA_SEARCH_API_KEY found in environment variables")
        return search_bocha(
            config.BOCHA_SEARCH_API_KEY, query, result_count, domain_filter
        )

    if engine == "serpstack":
        if not config.SERPSTACK_API_KEY:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
        return search_serpstack(
            config.SERPSTACK_API_KEY,
            query,
            result_count,
            domain_filter,
            https_enabled=config.SERPSTACK_HTTPS,
        )

    if engine == "serper":
        if not config.SERPER_API_KEY:
            raise Exception("No SERPER_API_KEY found in environment variables")
        return search_serper(
            config.SERPER_API_KEY, query, result_count, domain_filter
        )

    if engine == "serply":
        if not config.SERPLY_API_KEY:
            raise Exception("No SERPLY_API_KEY found in environment variables")
        return search_serply(
            config.SERPLY_API_KEY, query, result_count, domain_filter
        )

    if engine == "duckduckgo":
        # No credential is involved for this backend.
        return search_duckduckgo(query, result_count, domain_filter)

    if engine == "tavily":
        if not config.TAVILY_API_KEY:
            raise Exception("No TAVILY_API_KEY found in environment variables")
        return search_tavily(
            config.TAVILY_API_KEY, query, result_count, domain_filter
        )

    if engine == "searchapi":
        if not config.SEARCHAPI_API_KEY:
            raise Exception("No SEARCHAPI_API_KEY found in environment variables")
        return search_searchapi(
            config.SEARCHAPI_API_KEY,
            config.SEARCHAPI_ENGINE,
            query,
            result_count,
            domain_filter,
        )

    if engine == "serpapi":
        if not config.SERPAPI_API_KEY:
            raise Exception("No SERPAPI_API_KEY found in environment variables")
        return search_serpapi(
            config.SERPAPI_API_KEY,
            config.SERPAPI_ENGINE,
            query,
            result_count,
            domain_filter,
        )

    # The remaining backends are called without checking their key first.
    if engine == "jina":
        # Note: no domain filter is passed to this backend.
        return search_jina(config.JINA_API_KEY, query, result_count)

    if engine == "bing":
        return search_bing(
            config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
            config.BING_SEARCH_V7_ENDPOINT,
            str(DEFAULT_LOCALE),
            query,
            result_count,
            domain_filter,
        )

    if engine == "exa":
        return search_exa(config.EXA_API_KEY, query, result_count, domain_filter)

    if engine == "perplexity":
        return search_perplexity(
            config.PERPLEXITY_API_KEY, query, result_count, domain_filter
        )

    # Unrecognised engine name.
    raise Exception("No search engine API key found in environment variables")
@router.post("/process/web/search")
async def process_web_search(
    request: Request, form_data: SearchForm, user=Depends(get_verified_user)
):
    # Search the web with the configured engine, load the result pages, and
    # either return their content directly (bypass mode) or store them in the
    # vector DB. Search failures and load/store failures are both logged and
    # reported as HTTP 400, with different error messages.
    config = request.app.state.config

    try:
        # Use the module logger like the rest of this file, not the root logger.
        log.info(
            f"trying to web search with {config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
        )
        # search_web is a plain synchronous function; run it in the thread
        # pool so it does not block the event loop of this async handler.
        web_results = await run_in_threadpool(
            search_web, request, config.RAG_WEB_SEARCH_ENGINE, form_data.query
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
        )

    log.debug(f"web_results: {web_results}")

    try:
        collection_name = form_data.collection_name
        if collection_name == "" or collection_name is None:
            # Default name derived from the query, cut to 63 characters.
            collection_name = f"web-search-{calculate_sha256_string(form_data.query)}"[
                :63
            ]

        urls = [result.link for result in web_results]
        loader = get_web_loader(
            urls,
            verify_ssl=config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            trust_env=config.RAG_WEB_SEARCH_TRUST_ENV,
        )
        docs = await loader.aload()

        if config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
            # Bypass mode: hand the loaded pages back without storing them.
            return {
                "status": True,
                "collection_name": None,
                "filenames": urls,
                "docs": [
                    {
                        "content": doc.page_content,
                        "metadata": doc.metadata,
                    }
                    for doc in docs
                ],
                "loaded_count": len(docs),
            }

        # save_docs_to_vector_db is synchronous, so it also runs in the
        # thread pool.
        await run_in_threadpool(
            save_docs_to_vector_db,
            request,
            docs,
            collection_name,
            overwrite=True,
            user=user,
        )

        return {
            "status": True,
            "collection_name": collection_name,
            "filenames": urls,
            "loaded_count": len(docs),
        }
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
# Request body for POST /query/doc.
class QueryDocForm(BaseModel):
    # Collection to search.
    collection_name: str
    # Query text.
    query: str
    # Optional per-request override of the configured TOP_K.
    k: Optional[int] = None
    # Optional per-request override of the configured RELEVANCE_THRESHOLD.
    r: Optional[float] = None
    # Declared on the form; the handler reads the hybrid flag from the config.
    hybrid: Optional[bool] = None
@router.post("/query/doc")
def query_doc_handler(
    request: Request,
    form_data: QueryDocForm,
    user=Depends(get_verified_user),
):
    # Query a single collection, using hybrid search when it is enabled in
    # the app config. Failures are logged and reported as HTTP 400.
    try:
        config = request.app.state.config
        # A falsy k (None or 0) falls back to the configured TOP_K.
        top_k = form_data.k or config.TOP_K

        if not config.ENABLE_RAG_HYBRID_SEARCH:
            query_embedding = request.app.state.EMBEDDING_FUNCTION(
                form_data.query, user=user
            )
            return query_doc(
                collection_name=form_data.collection_name,
                query_embedding=query_embedding,
                k=top_k,
                user=user,
            )

        def embed(query):
            # Looked up at call time, bound to the requesting user.
            return request.app.state.EMBEDDING_FUNCTION(query, user=user)

        # A falsy r (None or 0.0) falls back to the configured threshold.
        relevance_threshold = form_data.r or config.RELEVANCE_THRESHOLD
        return query_doc_with_hybrid_search(
            collection_name=form_data.collection_name,
            query=form_data.query,
            embedding_function=embed,
            k=top_k,
            reranking_function=request.app.state.rf,
            r=relevance_threshold,
            user=user,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
# Request body for POST /query/collection.
class QueryCollectionsForm(BaseModel):
    # Collections to search.
    collection_names: list[str]
    # Query text.
    query: str
    # Optional per-request override of the configured TOP_K.
    k: Optional[int] = None
    # Optional per-request override of the configured RELEVANCE_THRESHOLD.
    r: Optional[float] = None
    # Declared on the form; the handler reads the hybrid flag from the config.
    hybrid: Optional[bool] = None
@router.post("/query/collection")
def query_collection_handler(
    request: Request,
    form_data: QueryCollectionsForm,
    user=Depends(get_verified_user),
):
    # Query several collections with one query string, using hybrid search
    # when it is enabled in the app config. Failures are logged and reported
    # as HTTP 400.
    try:
        config = request.app.state.config
        # A falsy k (None or 0) falls back to the configured TOP_K.
        top_k = form_data.k or config.TOP_K

        def embed(query):
            # Looked up at call time, bound to the requesting user.
            return request.app.state.EMBEDDING_FUNCTION(query, user=user)

        if not config.ENABLE_RAG_HYBRID_SEARCH:
            return query_collection(
                collection_names=form_data.collection_names,
                queries=[form_data.query],
                embedding_function=embed,
                k=top_k,
            )

        # A falsy r (None or 0.0) falls back to the configured threshold.
        relevance_threshold = form_data.r or config.RELEVANCE_THRESHOLD
        return query_collection_with_hybrid_search(
            collection_names=form_data.collection_names,
            queries=[form_data.query],
            embedding_function=embed,
            k=top_k,
            reranking_function=request.app.state.rf,
            r=relevance_threshold,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
- ####################################
- #
- # Vector DB operations
- #
- ####################################
# Request body for POST /delete.
class DeleteForm(BaseModel):
    # Collection to delete entries from.
    collection_name: str
    # File whose stored hash selects the entries to delete.
    file_id: str
@router.post("/delete")
def delete_entries_from_collection(form_data: DeleteForm, user=Depends(get_admin_user)):
    # Admin-only: delete the entries of a collection whose "hash" metadata
    # equals the stored hash of the given file. Reports {"status": False}
    # when the collection does not exist or anything goes wrong.
    try:
        collection_exists = VECTOR_DB_CLIENT.has_collection(
            collection_name=form_data.collection_name
        )
        if not collection_exists:
            return {"status": False}

        file = Files.get_file_by_id(form_data.file_id)
        file_hash = file.hash

        VECTOR_DB_CLIENT.delete(
            collection_name=form_data.collection_name,
            metadata={"hash": file_hash},
        )
        return {"status": True}
    except Exception as e:
        log.exception(e)
        return {"status": False}
@router.post("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
    # Admin-only: reset the vector DB client, then delete all knowledge
    # records. Nothing is returned.
    VECTOR_DB_CLIENT.reset()
    Knowledges.delete_all_knowledge()
@router.post("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    # Admin-only: remove every entry directly inside UPLOAD_DIR.
    # Always returns True; failures are only logged.
    folder = f"{UPLOAD_DIR}"
    try:
        if not os.path.exists(folder):
            log.warning(f"The directory {folder} does not exist")
        else:
            for entry_name in os.listdir(folder):
                file_path = os.path.join(folder, entry_name)
                try:
                    is_file_or_link = os.path.isfile(file_path) or os.path.islink(
                        file_path
                    )
                    if is_file_or_link:
                        # Plain files and symlinks are unlinked.
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        # Directories are removed with their contents.
                        shutil.rmtree(file_path)
                except Exception as e:
                    # One failing entry must not stop the rest of the cleanup.
                    log.exception(f"Failed to delete {file_path}. Reason: {e}")
    except Exception as e:
        log.exception(f"Failed to process the directory {folder}. Reason: {e}")
    return True
# This debugging route is only registered when ENV is "dev".
if ENV == "dev":

    @router.get("/ef/{text}")
    async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
        # Run the app's embedding function on the text from the URL path.
        embedding = request.app.state.EMBEDDING_FUNCTION(text)
        return {"result": embedding}
# Request body for POST /process/files/batch.
class BatchProcessFilesForm(BaseModel):
    # Files whose stored content should be embedded.
    files: List[FileModel]
    # Collection that receives all documents of the batch.
    collection_name: str
# Outcome for one file of a batch.
class BatchProcessFilesResult(BaseModel):
    # Id of the file this result refers to.
    file_id: str
    # Required. The batch handler sets "prepared", "completed" or "failed".
    status: str
    # Error text, set for failures only.
    error: Optional[str] = None
# Response body of POST /process/files/batch.
class BatchProcessFilesResponse(BaseModel):
    # One entry per file that was prepared successfully.
    results: List[BatchProcessFilesResult]
    # One entry per failure (preparation or saving).
    errors: List[BatchProcessFilesResult]
@router.post("/process/files/batch")
def process_files_batch(
    request: Request,
    form_data: BatchProcessFilesForm,
    user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
    """
    Process a batch of files and save them to the vector database.
    """
    # Two phases:
    #   1. Per file: build a document from the stored content and update the
    #      file's hash and data. A failure here affects only that file.
    #   2. Once: save all prepared documents in a single call. A failure here
    #      marks every prepared file as failed.
    results: List[BatchProcessFilesResult] = []
    errors: List[BatchProcessFilesResult] = []
    collection_name = form_data.collection_name

    # Prepare all documents first
    all_docs: List[Document] = []
    for file in form_data.files:
        try:
            text_content = file.data.get("content", "")

            document = Document(
                page_content=text_content.replace("<br/>", "\n"),
                metadata={
                    **file.meta,
                    "name": file.filename,
                    "created_by": file.user_id,
                    "file_id": file.id,
                    "source": file.filename,
                },
            )

            # The hash is taken over the stored content, before the <br/>
            # replacement applied to the document text.
            content_hash = calculate_sha256_string(text_content)
            Files.update_file_hash_by_id(file.id, content_hash)
            Files.update_file_data_by_id(file.id, {"content": text_content})

            all_docs.append(document)
            results.append(BatchProcessFilesResult(file_id=file.id, status="prepared"))
        except Exception as e:
            log.error(f"process_files_batch: Error processing file {file.id}: {str(e)}")
            errors.append(
                BatchProcessFilesResult(file_id=file.id, status="failed", error=str(e))
            )

    # Save all documents in one batch
    if all_docs:
        try:
            save_docs_to_vector_db(
                request=request,
                docs=all_docs,
                collection_name=collection_name,
                add=True,
                user=user,
            )

            # Update all files with collection name
            for result in results:
                Files.update_file_metadata_by_id(
                    result.file_id, {"collection_name": collection_name}
                )
                result.status = "completed"
        except Exception as e:
            log.error(
                f"process_files_batch: Error saving documents to vector DB: {str(e)}"
            )
            for result in results:
                result.status = "failed"
                # "status" is a required field of BatchProcessFilesResult;
                # leaving it out raises a validation error inside this
                # handler instead of reporting the failure.
                errors.append(
                    BatchProcessFilesResult(
                        file_id=result.file_id, status="failed", error=str(e)
                    )
                )

    return BatchProcessFilesResponse(results=results, errors=errors)
|