12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543 |
- import asyncio
- import inspect
- import json
- import logging
- import mimetypes
- import os
- import shutil
- import sys
- import time
- import random
- from typing import AsyncGenerator, Generator, Iterator
- from contextlib import asynccontextmanager
- from urllib.parse import urlencode, parse_qs, urlparse
- from pydantic import BaseModel
- from sqlalchemy import text
- from typing import Optional
- from aiocache import cached
- import aiohttp
- import requests
- from fastapi import (
- Depends,
- FastAPI,
- File,
- Form,
- HTTPException,
- Request,
- UploadFile,
- status,
- )
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import JSONResponse, RedirectResponse
- from fastapi.staticfiles import StaticFiles
- from starlette.exceptions import HTTPException as StarletteHTTPException
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.middleware.sessions import SessionMiddleware
- from starlette.responses import Response, StreamingResponse
- from open_webui.socket.main import (
- app as socket_app,
- periodic_usage_pool_cleanup,
- get_event_call,
- get_event_emitter,
- )
- from open_webui.routers import (
- audio,
- images,
- ollama,
- openai,
- retrieval,
- pipelines,
- tasks,
- auths,
- chats,
- folders,
- configs,
- groups,
- files,
- functions,
- memories,
- models,
- knowledge,
- prompts,
- evaluations,
- tools,
- users,
- utils,
- )
- from open_webui.routers.openai import (
- generate_chat_completion as generate_openai_chat_completion,
- )
- from open_webui.routers.ollama import (
- generate_chat_completion as generate_ollama_chat_completion,
- )
- from open_webui.routers.retrieval import (
- get_embedding_function,
- get_ef,
- get_rf,
- )
- from open_webui.routers.pipelines import (
- process_pipeline_inlet_filter,
- process_pipeline_outlet_filter,
- )
- from open_webui.retrieval.utils import get_sources_from_files
- from open_webui.internal.db import Session
- from open_webui.models.functions import Functions
- from open_webui.models.models import Models
- from open_webui.models.users import UserModel, Users
- from open_webui.constants import TASKS
- from open_webui.config import (
- # Ollama
- ENABLE_OLLAMA_API,
- OLLAMA_BASE_URLS,
- OLLAMA_API_CONFIGS,
- # OpenAI
- ENABLE_OPENAI_API,
- OPENAI_API_BASE_URLS,
- OPENAI_API_KEYS,
- OPENAI_API_CONFIGS,
- # Image
- AUTOMATIC1111_API_AUTH,
- AUTOMATIC1111_BASE_URL,
- AUTOMATIC1111_CFG_SCALE,
- AUTOMATIC1111_SAMPLER,
- AUTOMATIC1111_SCHEDULER,
- COMFYUI_BASE_URL,
- COMFYUI_WORKFLOW,
- COMFYUI_WORKFLOW_NODES,
- ENABLE_IMAGE_GENERATION,
- IMAGE_GENERATION_ENGINE,
- IMAGE_GENERATION_MODEL,
- IMAGE_SIZE,
- IMAGE_STEPS,
- IMAGES_OPENAI_API_BASE_URL,
- IMAGES_OPENAI_API_KEY,
- # Audio
- AUDIO_STT_ENGINE,
- AUDIO_STT_MODEL,
- AUDIO_STT_OPENAI_API_BASE_URL,
- AUDIO_STT_OPENAI_API_KEY,
- AUDIO_TTS_API_KEY,
- AUDIO_TTS_ENGINE,
- AUDIO_TTS_MODEL,
- AUDIO_TTS_OPENAI_API_BASE_URL,
- AUDIO_TTS_OPENAI_API_KEY,
- AUDIO_TTS_SPLIT_ON,
- AUDIO_TTS_VOICE,
- AUDIO_TTS_AZURE_SPEECH_REGION,
- AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
- WHISPER_MODEL,
- WHISPER_MODEL_AUTO_UPDATE,
- WHISPER_MODEL_DIR,
- # Retrieval
- RAG_TEMPLATE,
- DEFAULT_RAG_TEMPLATE,
- RAG_EMBEDDING_MODEL,
- RAG_EMBEDDING_MODEL_AUTO_UPDATE,
- RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
- RAG_RERANKING_MODEL,
- RAG_RERANKING_MODEL_AUTO_UPDATE,
- RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
- RAG_EMBEDDING_ENGINE,
- RAG_EMBEDDING_BATCH_SIZE,
- RAG_RELEVANCE_THRESHOLD,
- RAG_FILE_MAX_COUNT,
- RAG_FILE_MAX_SIZE,
- RAG_OPENAI_API_BASE_URL,
- RAG_OPENAI_API_KEY,
- RAG_OLLAMA_BASE_URL,
- RAG_OLLAMA_API_KEY,
- CHUNK_OVERLAP,
- CHUNK_SIZE,
- CONTENT_EXTRACTION_ENGINE,
- TIKA_SERVER_URL,
- RAG_TOP_K,
- RAG_TEXT_SPLITTER,
- TIKTOKEN_ENCODING_NAME,
- PDF_EXTRACT_IMAGES,
- YOUTUBE_LOADER_LANGUAGE,
- YOUTUBE_LOADER_PROXY_URL,
- # Retrieval (Web Search)
- RAG_WEB_SEARCH_ENGINE,
- RAG_WEB_SEARCH_RESULT_COUNT,
- RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
- RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
- JINA_API_KEY,
- SEARCHAPI_API_KEY,
- SEARCHAPI_ENGINE,
- SEARXNG_QUERY_URL,
- SERPER_API_KEY,
- SERPLY_API_KEY,
- SERPSTACK_API_KEY,
- SERPSTACK_HTTPS,
- TAVILY_API_KEY,
- BING_SEARCH_V7_ENDPOINT,
- BING_SEARCH_V7_SUBSCRIPTION_KEY,
- BRAVE_SEARCH_API_KEY,
- KAGI_SEARCH_API_KEY,
- MOJEEK_SEARCH_API_KEY,
- GOOGLE_PSE_API_KEY,
- GOOGLE_PSE_ENGINE_ID,
- ENABLE_RAG_HYBRID_SEARCH,
- ENABLE_RAG_LOCAL_WEB_FETCH,
- ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
- ENABLE_RAG_WEB_SEARCH,
- UPLOAD_DIR,
- # WebUI
- WEBUI_AUTH,
- WEBUI_NAME,
- WEBUI_BANNERS,
- WEBHOOK_URL,
- ADMIN_EMAIL,
- SHOW_ADMIN_DETAILS,
- JWT_EXPIRES_IN,
- ENABLE_SIGNUP,
- ENABLE_LOGIN_FORM,
- ENABLE_API_KEY,
- ENABLE_COMMUNITY_SHARING,
- ENABLE_MESSAGE_RATING,
- ENABLE_EVALUATION_ARENA_MODELS,
- USER_PERMISSIONS,
- DEFAULT_USER_ROLE,
- DEFAULT_PROMPT_SUGGESTIONS,
- DEFAULT_MODELS,
- DEFAULT_ARENA_MODEL,
- MODEL_ORDER_LIST,
- EVALUATION_ARENA_MODELS,
- # WebUI (OAuth)
- ENABLE_OAUTH_ROLE_MANAGEMENT,
- OAUTH_ROLES_CLAIM,
- OAUTH_EMAIL_CLAIM,
- OAUTH_PICTURE_CLAIM,
- OAUTH_USERNAME_CLAIM,
- OAUTH_ALLOWED_ROLES,
- OAUTH_ADMIN_ROLES,
- # WebUI (LDAP)
- ENABLE_LDAP,
- LDAP_SERVER_LABEL,
- LDAP_SERVER_HOST,
- LDAP_SERVER_PORT,
- LDAP_ATTRIBUTE_FOR_USERNAME,
- LDAP_SEARCH_FILTERS,
- LDAP_SEARCH_BASE,
- LDAP_APP_DN,
- LDAP_APP_PASSWORD,
- LDAP_USE_TLS,
- LDAP_CA_CERT_FILE,
- LDAP_CIPHERS,
- # Misc
- ENV,
- CACHE_DIR,
- STATIC_DIR,
- FRONTEND_BUILD_DIR,
- CORS_ALLOW_ORIGIN,
- DEFAULT_LOCALE,
- OAUTH_PROVIDERS,
- # Admin
- ENABLE_ADMIN_CHAT_ACCESS,
- ENABLE_ADMIN_EXPORT,
- # Tasks
- TASK_MODEL,
- TASK_MODEL_EXTERNAL,
- ENABLE_TAGS_GENERATION,
- ENABLE_SEARCH_QUERY_GENERATION,
- ENABLE_RETRIEVAL_QUERY_GENERATION,
- ENABLE_AUTOCOMPLETE_GENERATION,
- TITLE_GENERATION_PROMPT_TEMPLATE,
- TAGS_GENERATION_PROMPT_TEMPLATE,
- TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
- QUERY_GENERATION_PROMPT_TEMPLATE,
- AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
- AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
- AppConfig,
- reset_config,
- )
- from open_webui.env import (
- CHANGELOG,
- GLOBAL_LOG_LEVEL,
- SAFE_MODE,
- SRC_LOG_LEVELS,
- VERSION,
- WEBUI_URL,
- WEBUI_BUILD_HASH,
- WEBUI_SECRET_KEY,
- WEBUI_SESSION_COOKIE_SAME_SITE,
- WEBUI_SESSION_COOKIE_SECURE,
- WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
- WEBUI_AUTH_TRUSTED_NAME_HEADER,
- BYPASS_MODEL_ACCESS_CONTROL,
- RESET_CONFIG_ON_START,
- OFFLINE_MODE,
- )
- from open_webui.utils.plugin import load_function_module_by_id
- from open_webui.utils.misc import (
- add_or_update_system_message,
- get_last_user_message,
- prepend_to_first_user_message_content,
- openai_chat_chunk_message_template,
- openai_chat_completion_message_template,
- )
- from open_webui.utils.payload import (
- apply_model_params_to_body_openai,
- apply_model_system_prompt_to_body,
- )
- from open_webui.utils.payload import convert_payload_openai_to_ollama
- from open_webui.utils.response import (
- convert_response_ollama_to_openai,
- convert_streaming_response_ollama_to_openai,
- )
- from open_webui.utils.task import (
- get_task_model_id,
- rag_template,
- tools_function_calling_generation_template,
- )
- from open_webui.utils.tools import get_tools
- from open_webui.utils.access_control import has_access
- from open_webui.utils.auth import (
- decode_token,
- get_admin_user,
- get_current_user,
- get_http_authorization_cred,
- get_verified_user,
- )
- from open_webui.utils.oauth import oauth_manager
- from open_webui.utils.security_headers import SecurityHeadersMiddleware
# Safe mode: announce it and call Functions.deactivate_all_functions() before
# the application object is created further down.
# print() is used rather than `log` because the logger is only created below.
if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()

# Route log records to stdout at the global level, then give this module's
# logger the level configured under the "MAIN" key.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class SPAStaticFiles(StaticFiles):
    """Static file handler with a single-page-application fallback.

    A lookup that the parent ``StaticFiles`` answers with HTTP 404 is retried
    with ``index.html``; any other HTTP error is re-raised unchanged.
    """

    async def get_response(self, path: str, scope):
        """Return the response for ``path``, falling back to ``index.html`` on 404."""
        try:
            response = await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            # Only a missing file falls back to the SPA entry point.
            if ex.status_code != 404:
                raise ex
            response = await super().get_response("index.html", scope)
        return response
# Startup banner: ASCII logo, the running VERSION and, unless the build hash is
# the "dev-build" placeholder, the commit the build was made from.
# NOTE(review): the logo's internal spacing looks collapsed in this copy of the
# file — verify against the canonical art before relying on it.
print(
    rf"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run start-up work for the application and undo it on shutdown.

    On entry, ``reset_config()`` is called when ``RESET_CONFIG_ON_START`` is
    set, and ``periodic_usage_pool_cleanup()`` is scheduled as a background
    task. On exit the background task is cancelled.

    Args:
        app: The FastAPI application whose lifespan is being managed.

    Yields:
        None; control returns to the server while the application runs.
    """
    if RESET_CONFIG_ON_START:
        reset_config()

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference for as long as the application runs; otherwise the task can be
    # garbage-collected mid-execution.
    cleanup_task = asyncio.create_task(periodic_usage_pool_cleanup())
    try:
        yield
    finally:
        # Request cancellation instead of leaving the loop pending at shutdown.
        # Deliberately not awaited, so an error inside the task cannot turn a
        # clean shutdown into a failure.
        cleanup_task.cancel()
# Application instance. The interactive docs and the OpenAPI schema are only
# exposed when ENV == "dev"; ReDoc is always disabled.
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
    lifespan=lifespan,
)

# Configuration container that the sections below populate attribute by
# attribute from the values imported from open_webui.config.
app.state.config = AppConfig()
########################################
#
# OLLAMA
#
########################################
# Ollama backend settings, copied from open_webui.config.
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS

# Starts empty; presumably filled with discovered Ollama models elsewhere —
# the writer is not visible in this part of the file.
app.state.OLLAMA_MODELS = {}
########################################
#
# OPENAI
#
########################################
# OpenAI-compatible backend settings, copied from open_webui.config.
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS

# Starts empty; presumably filled with discovered OpenAI models elsewhere —
# the writer is not visible in this part of the file.
app.state.OPENAI_MODELS = {}
########################################
#
# WEBUI
#
########################################
# Sign-up, login and general UI settings.
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
# Note the rename: the config attribute is BANNERS, the source is WEBUI_BANNERS.
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS

# OAuth claim names and role management.
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES

# LDAP connection and search settings.
app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS

# Trusted-header names are stored directly on app.state, not on app.state.config.
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER

# TOOLS starts empty; its writer is not visible in this part of the file.
app.state.TOOLS = {}
# Cache of loaded function modules keyed by function id; filled lazily by
# chat_completion_filter_functions_handler.
app.state.FUNCTIONS = {}
########################################
#
# RETRIEVAL
#
########################################
# Several attributes drop the RAG_ prefix of the constant they are copied from
# (TOP_K, RELEVANCE_THRESHOLD, FILE_MAX_SIZE, FILE_MAX_COUNT, TEXT_SPLITTER).
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)

# Content extraction and text splitting.
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

# Embedding / reranking models and the endpoints used to reach them.
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL

# Web search: engine selection, per-provider credentials, and limits.
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
# Runtime retrieval state. Everything starts as None so the attributes exist
# even when model loading below fails.
app.state.EMBEDDING_FUNCTION = None
app.state.ef = None
app.state.rf = None
app.state.YOUTUBE_LOADER_TRANSLATION = None

# Load the embedding and reranking models first: the embedding function built
# below receives app.state.ef as an argument, so it must not still be None.
# Loading stays best-effort — a failure is logged and start-up continues.
try:
    app.state.ef = get_ef(
        app.state.config.RAG_EMBEDDING_ENGINE,
        app.state.config.RAG_EMBEDDING_MODEL,
        RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    )
    app.state.rf = get_rf(
        app.state.config.RAG_RERANKING_MODEL,
        RAG_RERANKING_MODEL_AUTO_UPDATE,
    )
except Exception as e:
    log.error(f"Error updating models: {e}")

# Build the embedding callable from the loaded model; the base URL and API key
# come from the OpenAI settings when the engine is "openai", otherwise from the
# Ollama settings.
app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.config.RAG_EMBEDDING_ENGINE,
    app.state.config.RAG_EMBEDDING_MODEL,
    app.state.ef,
    (
        app.state.config.RAG_OPENAI_API_BASE_URL
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_BASE_URL
    ),
    (
        app.state.config.RAG_OPENAI_API_KEY
        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
        else app.state.config.RAG_OLLAMA_API_KEY
    ),
    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
)
########################################
#
# IMAGES
#
########################################
# Image generation: engine selection and the OpenAI-compatible endpoint.
app.state.config.IMAGE_GENERATION_ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLE_IMAGE_GENERATION = ENABLE_IMAGE_GENERATION
app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.IMAGE_GENERATION_MODEL = IMAGE_GENERATION_MODEL

# AUTOMATIC1111 settings.
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER

# ComfyUI settings.
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES

# Output size and step count.
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
########################################
#
# AUDIO
#
########################################
# The config attributes drop the AUDIO_ prefix of the constants they copy.
# Speech-to-text settings.
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL

# Text-to-speech settings.
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT

# Runtime audio objects start as None; presumably created on first use
# elsewhere — the writers are not visible in this part of the file.
app.state.faster_whisper_model = None
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None
########################################
#
# TASKS
#
########################################
# Models used for background tasks; chat_completion_tools_handler passes both
# to get_task_model_id().
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL

# Feature switches for the individual generation tasks.
app.state.config.ENABLE_SEARCH_QUERY_GENERATION = ENABLE_SEARCH_QUERY_GENERATION
app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENERATION
app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION

# Prompt templates. chat_completion_tools_handler falls back to a built-in
# template when TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE is the empty string.
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = QUERY_GENERATION_PROMPT_TEMPLATE
app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = (
    AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
)
########################################
#
# WEBUI
#
########################################
# Model registry on app.state; starts empty. Presumably filled when the model
# list is assembled — the writer is not visible in this part of the file.
app.state.MODELS = {}
- ##################################
- #
- # ChatCompletion Middleware
- #
- ##################################
async def chat_completion_filter_functions_handler(body, model, extra_params):
    """Pass ``body`` through the ``inlet`` hook of every applicable filter.

    Applicable filters are the global filter functions plus the ids listed in
    ``model["info"]["meta"]["filterIds"]``, restricted to active functions of
    type ``"filter"`` and run in ascending order of their ``priority`` valve
    (default 0). Filters with equal priority keep the order in which they were
    collected.

    Args:
        body: Chat-completion request payload (a dict). Each inlet receives the
            current payload and its return value becomes the new payload.
        model: Model description dict; may carry ``info.meta.filterIds``.
        extra_params: Extra keyword arguments offered to each inlet. Only the
            ones named in the inlet's signature are passed, together with
            ``__model__`` and ``__id__`` when requested.

    Returns:
        A ``(body, {})`` tuple: the filtered payload and an empty dict.

    Raises:
        Exception: Anything raised while preparing or running an inlet is
            logged and re-raised unchanged.
    """
    skip_files = None

    def get_filter_function_ids(model):
        def get_priority(function_id):
            function = Functions.get_function_by_id(function_id)
            if function is not None and hasattr(function, "valves"):
                # TODO: Fix FunctionModel
                return (function.valves if function.valves else {}).get("priority", 0)
            return 0

        filter_ids = [
            function.id for function in Functions.get_global_filter_functions()
        ]
        if "info" in model and "meta" in model["info"]:
            filter_ids.extend(model["info"]["meta"].get("filterIds", []))

        # De-duplicate while keeping first-seen order, so that filters sharing a
        # priority run in a reproducible order (a set would make it hash-dependent).
        filter_ids = list(dict.fromkeys(filter_ids))

        enabled_filter_ids = {
            function.id
            for function in Functions.get_functions_by_type("filter", active_only=True)
        }
        filter_ids = [
            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
        ]

        # list.sort is stable, so ties keep the collection order from above.
        filter_ids.sort(key=get_priority)
        return filter_ids

    for filter_id in get_filter_function_ids(model):
        # Skip ids whose function record no longer exists.
        if not Functions.get_function_by_id(filter_id):
            continue

        # Load each function module once and reuse it from the app-wide cache.
        if filter_id in app.state.FUNCTIONS:
            function_module = app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            app.state.FUNCTIONS[filter_id] = function_module

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        # Refresh the module's valves from storage before calling it.
        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet

            # Pass only the optional arguments the inlet actually declares.
            sig = inspect.signature(inlet)
            params = {"body": body} | {
                k: v
                for k, v in {
                    **extra_params,
                    "__model__": model,
                    "__id__": filter_id,
                }.items()
                if k in sig.parameters
            }

            if "__user__" in params and hasattr(function_module, "UserValves"):
                # Best effort: a failure to load user valves is logged and the
                # inlet still runs without them.
                try:
                    params["__user__"]["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            filter_id, params["__user__"]["id"]
                        )
                    )
                except Exception as e:
                    log.exception(e)

            if inspect.iscoroutinefunction(inlet):
                body = await inlet(**params)
            else:
                body = inlet(**params)

        except Exception as e:
            log.exception(f"Error: {e}")
            raise

    # A filter that declared a truthy file_handler takes over file handling, so
    # the files are removed from the request metadata.
    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {}
def get_tools_function_calling_payload(messages, task_model_id, content):
    """Build the non-streaming request sent to the task model for tool selection.

    Args:
        messages: Chat messages, each a mapping with ``role`` and ``content``.
        task_model_id: Id of the model that will receive the request.
        content: Text used as the system message.

    Returns:
        A payload dict with ``model``, ``messages`` (system + user), ``stream``
        set to ``False`` and ``metadata`` tagging the function-calling task.
    """
    last_user_message = get_last_user_message(messages)

    # At most the four most recent messages, newest first.
    history_lines = []
    for entry in reversed(messages[-4:]):
        role = entry["role"].upper()
        text = entry["content"]
        history_lines.append(f"{role}: \"\"\"{text}\"\"\"")
    history = "\n".join(history_lines)

    prompt = f"History:\n{history}\nQuery: {last_user_message}"
    system_message = {"role": "system", "content": content}
    user_message = {"role": "user", "content": f"Query: {prompt}"}

    return {
        "model": task_model_id,
        "messages": [system_message, user_message],
        "stream": False,
        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
    }
async def get_content_from_response(response) -> Optional[str]:
    """Extract the first choice's message content from a completion response.

    Args:
        response: Either an object exposing ``body_iterator`` (an async
            iterable of JSON-encoded chunks, ``bytes`` or ``str``) and
            ``background``, or an already-decoded mapping shaped like
            ``{"choices": [{"message": {"content": ...}}]}``.

    Returns:
        The ``content`` of the first choice. For an iterated body this is the
        value from the last chunk, or ``None`` when no chunk is produced.

    Raises:
        ValueError: A chunk is not valid JSON (``json.JSONDecodeError``).
        KeyError, IndexError: The payload lacks the expected structure.
    """
    if not hasattr(response, "body_iterator"):
        return response["choices"][0]["message"]["content"]

    content = None
    try:
        async for chunk in response.body_iterator:
            # Body iterators may yield bytes or str; only bytes need decoding.
            if isinstance(chunk, (bytes, bytearray)):
                chunk = chunk.decode("utf-8")
            data = json.loads(chunk)
            content = data["choices"][0]["message"]["content"]
    finally:
        # Cleanup any remaining background tasks if necessary. This runs even
        # when a chunk fails to parse, so the attached cleanup is never skipped.
        if response.background is not None:
            await response.background()
    return content
async def chat_completion_tools_handler(
    body: dict, user: UserModel, models, extra_params: dict
) -> tuple[dict, dict]:
    """Ask the task model which tool to call, run it, and collect its output.

    Args:
        body: Chat completion request body; ``body["metadata"]`` may carry
            ``tool_ids`` and ``files``.
        user: The requesting user, forwarded to tool loading and the task model.
        models: Mapping of model id to model dict.
        extra_params: Extra keyword context merged into the tool parameters.

    Returns:
        ``(body, flags)``. ``flags`` is ``{}`` when no tool ids are present, when
        the task model returns nothing, or when it names an unknown tool;
        otherwise ``{"sources": [...]}`` with one entry per string tool output.
        When the invoked tool is flagged as a file handler, ``metadata["files"]``
        is removed from ``body``.
    """
    # If tool_ids field is present, call the functions
    metadata = body.get("metadata", {})
    tool_ids = metadata.get("tool_ids", None)
    log.debug(f"{tool_ids=}")
    if not tool_ids:
        return body, {}

    skip_files = False
    sources = []

    task_model_id = get_task_model_id(
        body["model"],
        app.state.config.TASK_MODEL,
        app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )
    tools = get_tools(
        app,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models[task_model_id],
            "__messages__": body["messages"],
            "__files__": metadata.get("files", []),
        },
    )
    log.info(f"{tools=}")

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    # NOTE(review): `request` is not a parameter or local of this function, and
    # `generate_chat_completions` below is called without a request as well.
    # Fixing that needs a signature change coordinated with the caller, so it
    # is only flagged here — confirm how the request should be threaded in.
    payload = process_pipeline_inlet_filter(request, payload, user, models)

    try:
        response = await generate_chat_completions(form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            # Keep only the outermost {...} span of the model's reply.
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            tool_function_name = result.get("name", None)
            if tool_function_name not in tools:
                return body, {}

            tool = tools[tool_function_name]
            tool_function_params = result.get("parameters", {})

            try:
                spec_parameters = tool.get("spec", {}).get("parameters", {})
                # Accept every declared parameter, not just the required ones,
                # so optional arguments chosen by the model are not dropped.
                # assumes the spec lists parameters under "properties" — TODO confirm
                allowed_params = set(spec_parameters.get("properties", {})) | set(
                    spec_parameters.get("required", [])
                )
                tool_function = tool["callable"]
                tool_function_params = {
                    k: v
                    for k, v in tool_function_params.items()
                    if k in allowed_params
                }
                tool_output = await tool_function(**tool_function_params)
            except Exception as e:
                # A failing tool reports its error text as its output.
                tool_output = str(e)

            if isinstance(tool_output, str):
                tool_source_id = f"TOOL:{tool['toolkit_id']}/{tool_function_name}"
                # Only citation-enabled tools get a named source.
                sources.append(
                    {
                        "source": (
                            {"name": tool_source_id} if tool["citation"] else {}
                        ),
                        "document": [tool_output],
                        "metadata": [{"source": tool_source_id}],
                    }
                )

                if tool["file_handler"]:
                    skip_files = True
        except Exception as e:
            log.exception(f"Error: {e}")
    except Exception as e:
        log.exception(f"Error: {e}")

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
async def chat_completion_files_handler(
    body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve context sources for the files attached to a chat request.

    Args:
        body: Chat completion request body; files are read from
            ``body["metadata"]["files"]``.
        user: The requesting user, forwarded to query generation.

    Returns:
        ``(body, {"sources": sources})``; ``sources`` is empty when the request
        carries no files.
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        try:
            queries_response = await generate_queries(
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            response_text = queries_response["choices"][0]["message"]["content"]

            # Text used as the single query if no JSON object can be parsed.
            fallback_query = response_text
            try:
                bracket_start = response_text.find("{")
                bracket_end = response_text.rfind("}") + 1

                # rfind() + 1 is 0 (never -1) when "}" is missing, and a "}"
                # before the first "{" gives an empty slice; treat both as
                # "no JSON object" while the full text is still available.
                if bracket_start == -1 or bracket_end <= bracket_start:
                    raise Exception("No JSON object found in the response")

                fallback_query = response_text[bracket_start:bracket_end]
                parsed_response = json.loads(fallback_query)
            except Exception:
                parsed_response = {"queries": [fallback_query]}

            queries = parsed_response.get("queries", [])
        except Exception:
            queries = []

        # No usable generated query: search with the last user message.
        if not queries:
            queries = [get_last_user_message(body["messages"])]

        sources = get_sources_from_files(
            files=files,
            queries=queries,
            embedding_function=app.state.EMBEDDING_FUNCTION,
            k=app.state.config.TOP_K,
            reranking_function=app.state.rf,
            r=app.state.config.RELEVANCE_THRESHOLD,
            hybrid_search=app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        )

        log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
async def get_body_and_model_and_user(request, models):
    """Parse the request body and resolve the model and user it refers to.

    Args:
        request: Incoming request; its body is read and decoded as UTF-8 JSON.
        models: Mapping of model id to model dict.

    Returns:
        ``(body, model, user)``.

    Raises:
        KeyError: If the parsed body has no ``"model"`` key.
        Exception: With message ``"Model not found"`` for an unknown model id.
    """
    raw_bytes = await request.body()
    text = raw_bytes.decode("utf-8")
    # An empty body is treated as an empty JSON object.
    payload = json.loads(text) if text else {}

    requested_model_id = payload["model"]
    if requested_model_id in models:
        selected_model = models[requested_model_id]
    else:
        raise Exception("Model not found")

    credentials = get_http_authorization_cred(request.headers.get("Authorization"))
    current_user = get_current_user(request, credentials)
    return payload, selected_model, current_user
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    """Pre-process chat completion POSTs and prepend extra items to streams.

    For POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions`` the middleware checks model access, runs the filter,
    tool and file handlers, injects collected context into the messages,
    rewrites the request body, and prepends collected sources to a streaming
    response.
    """

    async def dispatch(self, request: Request, call_next):
        """Handle one request; non-chat requests pass straight through.

        Returns:
            The downstream response, a wrapped ``StreamingResponse`` when there
            is an SSE/NDJSON stream, or a ``JSONResponse`` describing a
            validation or access error.
        """
        if not (
            request.method == "POST"
            and any(
                endpoint in request.url.path
                for endpoint in ["/ollama/api/chat", "/chat/completions"]
            )
        ):
            return await call_next(request)
        log.debug(f"request.url.path: {request.url.path}")

        await get_all_models(request)
        models = app.state.MODELS

        try:
            body, model, user = await get_body_and_model_and_user(request, models)
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        model_info = Models.get_model_by_id(model["id"])
        if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
            if model.get("arena"):
                if not has_access(
                    user.id,
                    type="read",
                    access_control=model.get("info", {})
                    .get("meta", {})
                    .get("access_control", {}),
                ):
                    # Return the error instead of raising: an exception raised
                    # from middleware bypasses the app's HTTPException handling
                    # and would surface as a 500.
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={"detail": "Model not found"},
                    )
            else:
                if not model_info:
                    return JSONResponse(
                        status_code=status.HTTP_404_NOT_FOUND,
                        content={"detail": "Model not found"},
                    )
                elif not (
                    user.id == model_info.user_id
                    or has_access(
                        user.id, type="read", access_control=model_info.access_control
                    )
                ):
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={"detail": "User does not have access to the model"},
                    )

        metadata = {
            "chat_id": body.pop("chat_id", None),
            "message_id": body.pop("id", None),
            "session_id": body.pop("session_id", None),
            "tool_ids": body.get("tool_ids", None),
            "files": body.get("files", None),
        }
        body["metadata"] = metadata

        extra_params = {
            "__event_emitter__": get_event_emitter(metadata),
            "__event_call__": get_event_call(metadata),
            "__user__": {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            },
            "__metadata__": metadata,
        }

        # data_items: extra payloads sent to the client ahead of the stream.
        # sources: context/citation entries collected from tools and files.
        data_items = []
        sources = []

        try:
            body, flags = await chat_completion_filter_functions_handler(
                body, model, extra_params
            )
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        # Filters may have changed tool_ids/files; move them into metadata.
        tool_ids = body.pop("tool_ids", None)
        files = body.pop("files", None)

        metadata = {
            **metadata,
            "tool_ids": tool_ids,
            "files": files,
        }
        body["metadata"] = metadata

        try:
            body, flags = await chat_completion_tools_handler(
                body, user, models, extra_params
            )
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        try:
            body, flags = await chat_completion_files_handler(body, user)
            sources.extend(flags.get("sources", []))
        except Exception as e:
            log.exception(e)

        # If context is not empty, insert it into the messages
        if len(sources) > 0:
            context_parts = []
            for source in sources:
                source_id = source.get("source", {}).get("name", "")
                if "document" not in source:
                    continue

                # Per-document metadata; named so it does not shadow the
                # request-level `metadata` dict above.
                doc_metadata = source.get("metadata")
                for doc_idx, doc_context in enumerate(source["document"]):
                    doc_source_id = None
                    # Guard against fewer metadata entries than documents.
                    if doc_metadata and doc_idx < len(doc_metadata):
                        doc_source_id = doc_metadata[doc_idx].get("source", source_id)

                    if source_id:
                        shown_id = (
                            doc_source_id if doc_source_id is not None else source_id
                        )
                        context_parts.append(
                            f"<source><source_id>{shown_id}</source_id><source_context>{doc_context}</source_context></source>\n"
                        )
                    else:
                        # If there is no source_id, then do not include the source_id tag
                        context_parts.append(
                            f"<source><source_context>{doc_context}</source_context></source>\n"
                        )

            context_string = "".join(context_parts).strip()
            prompt = get_last_user_message(body["messages"])

            if prompt is None:
                # Reported as a client error rather than raised (see above).
                return JSONResponse(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    content={"detail": "No user message found"},
                )

            if app.state.config.RELEVANCE_THRESHOLD == 0 and context_string == "":
                log.debug(
                    "With a 0 relevancy threshold for RAG, the context cannot be empty"
                )

            # Workaround for Ollama 2.0+ system prompt issue
            # TODO: replace with add_or_update_system_message
            if model["owned_by"] == "ollama":
                body["messages"] = prepend_to_first_user_message_content(
                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                    body["messages"],
                )
            else:
                body["messages"] = add_or_update_system_message(
                    rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                    body["messages"],
                )

        # If there are citations, add them to the data_items
        sources = [
            source for source in sources if source.get("source", {}).get("name", "")
        ]
        if len(sources) > 0:
            data_items.append({"sources": sources})

        modified_body_bytes = json.dumps(body).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]

        response = await call_next(request)
        if not isinstance(response, StreamingResponse):
            return response

        # A stream without a Content-Type is passed through unchanged.
        content_type = response.headers.get("Content-Type", "")
        is_openai = "text/event-stream" in content_type
        is_ollama = "application/x-ndjson" in content_type
        if not is_openai and not is_ollama:
            return response

        def wrap_item(item):
            # SSE framing for event streams, newline-delimited otherwise.
            return f"data: {item}\n\n" if is_openai else f"{item}\n"

        async def stream_wrapper(original_generator, data_items):
            # Emit the extra items first, then relay the original stream.
            for item in data_items:
                yield wrap_item(json.dumps(item))

            async for data in original_generator:
                yield data

        return StreamingResponse(
            stream_wrapper(response.body_iterator, data_items),
            headers=dict(response.headers),
        )

    async def _receive(self, body: bytes):
        """Return a single, complete ``http.request`` ASGI message for ``body``."""
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(ChatCompletionMiddleware)
class PipelineMiddleware(BaseHTTPMiddleware):
    """Run the pipeline inlet filter over chat completion request bodies."""

    async def dispatch(self, request: Request, call_next):
        """Filter the body of chat POSTs; pass every other request through.

        Returns:
            The downstream response, or a ``JSONResponse`` when authentication
            or the inlet filter fails.
        """
        chat_paths = ["/ollama/api/chat", "/chat/completions"]
        is_chat_post = request.method == "POST" and any(
            endpoint in request.url.path for endpoint in chat_paths
        )
        if not is_chat_post:
            return await call_next(request)

        log.debug(f"request.url.path: {request.url.path}")

        # Read, decode and parse the original request body.
        raw_body = await request.body()
        body_text = raw_body.decode("utf-8")
        data = json.loads(body_text) if body_text else {}

        try:
            credentials = get_http_authorization_cred(request.headers["Authorization"])
            user = get_current_user(request, credentials)
        except KeyError as exc:
            # A multi-argument error is read as (status code, detail).
            if len(exc.args) > 1:
                status_code, detail = exc.args[0], exc.args[1]
            else:
                status_code = status.HTTP_401_UNAUTHORIZED
                detail = "Not authenticated"
            return JSONResponse(status_code=status_code, content={"detail": detail})
        except HTTPException as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
            )

        await get_all_models(request)
        models = app.state.MODELS

        try:
            data = process_pipeline_inlet_filter(request, data, user, models)
        except Exception as exc:
            if len(exc.args) > 1:
                status_code, detail = exc.args[0], exc.args[1]
            else:
                status_code, detail = status.HTTP_400_BAD_REQUEST, str(exc)
            return JSONResponse(status_code=status_code, content={"detail": detail})

        # Swap in the filtered body and keep content-length consistent with it.
        new_body = json.dumps(data).encode("utf-8")
        request._body = new_body
        content_length_header = (
            b"content-length",
            str(len(new_body)).encode("utf-8"),
        )
        remaining_headers = [
            (k, v) for k, v in request.headers.raw if k.lower() != b"content-length"
        ]
        request.headers.__dict__["_list"] = [content_length_header, *remaining_headers]

        return await call_next(request)

    async def _receive(self, body: bytes):
        """Return a single, complete ``http.request`` ASGI message for ``body``."""
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)
class RedirectMiddleware(BaseHTTPMiddleware):
    """Redirect ``GET .../watch?v=<id>`` requests to ``/?youtube=<id>``."""

    async def dispatch(self, request: Request, call_next):
        """Redirect matching watch URLs; otherwise continue down the stack."""
        if request.method == "GET":
            parsed_query = dict(parse_qs(urlparse(str(request.url)).query))
            is_watch_path = request.url.path.endswith("/watch")

            # Only the watch path with a 'v' parameter is redirected.
            if is_watch_path and "v" in parsed_query:
                first_video_id = parsed_query["v"][0]
                target_query = urlencode({"youtube": first_video_id})
                return RedirectResponse(url=f"/?{target_query}")

        # Every other request follows the normal flow.
        return await call_next(request)


# Add the middleware to the app
app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
    """Call ``Session.commit()`` once the downstream handlers have responded."""
    downstream_response = await call_next(request)
    # The commit runs after the response object exists, before it is returned.
    Session.commit()
    return downstream_response
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Expose the API-key setting on the request and time its handling.

    Sets ``request.state.enable_api_key`` from the app config and adds an
    ``X-Process-Time`` header holding the elapsed whole seconds.
    """
    started_at = int(time.time())
    request.state.enable_api_key = app.state.config.ENABLE_API_KEY

    response = await call_next(request)

    # Both timestamps are truncated to whole seconds before subtracting.
    elapsed_seconds = int(time.time()) - started_at
    response.headers["X-Process-Time"] = str(elapsed_seconds)
    return response
@app.middleware("http")
async def inspect_websocket(request: Request, call_next):
    """Reject socket.io websocket-transport requests lacking upgrade headers.

    Applies only to paths containing ``/ws/socket.io`` with
    ``transport=websocket``; every other request is passed through.

    Returns:
        A 400 ``JSONResponse`` for an invalid upgrade request, otherwise the
        downstream response.
    """
    if (
        "/ws/socket.io" in request.url.path
        and request.query_params.get("transport") == "websocket"
    ):
        upgrade = (request.headers.get("Upgrade") or "").strip().lower()
        # Connection is a comma-separated list whose items may be padded with
        # whitespace (e.g. "keep-alive, Upgrade"), so strip each token.
        connection = [
            token.strip()
            for token in (request.headers.get("Connection") or "").lower().split(",")
        ]
        # Check that there's the correct headers for an upgrade, else reject the connection
        # This is to work around this upstream issue: https://github.com/miguelgrinberg/python-engineio/issues/367
        if upgrade != "websocket" or "upgrade" not in connection:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": "Invalid WebSocket upgrade request"},
            )

    return await call_next(request)
# Cross-origin policy: configured origins, credentials allowed, any method/header.
_cors_options = {
    "allow_origins": CORS_ALLOW_ORIGIN,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

app.mount("/ws", socket_app)

# (router module, URL prefix, tag) — routers are included in exactly this order.
_ROUTER_REGISTRATIONS = (
    (ollama, "/ollama", "ollama"),
    (openai, "/openai", "openai"),
    (pipelines, "/api/pipelines", "pipelines"),
    (tasks, "/api/tasks", "tasks"),
    (images, "/api/v1/images", "images"),
    (audio, "/api/v1/audio", "audio"),
    (retrieval, "/api/v1/retrieval", "retrieval"),
    (configs, "/api/v1/configs", "configs"),
    (auths, "/api/v1/auths", "auths"),
    (users, "/api/v1/users", "users"),
    (chats, "/api/v1/chats", "chats"),
    (models, "/api/v1/models", "models"),
    (knowledge, "/api/v1/knowledge", "knowledge"),
    (prompts, "/api/v1/prompts", "prompts"),
    (tools, "/api/v1/tools", "tools"),
    (memories, "/api/v1/memories", "memories"),
    (folders, "/api/v1/folders", "folders"),
    (groups, "/api/v1/groups", "groups"),
    (files, "/api/v1/files", "files"),
    (functions, "/api/v1/functions", "functions"),
    (evaluations, "/api/v1/evaluations", "evaluations"),
    (utils, "/api/v1/utils", "utils"),
)
for _router_module, _router_prefix, _router_tag in _ROUTER_REGISTRATIONS:
    app.include_router(
        _router_module.router, prefix=_router_prefix, tags=[_router_tag]
    )
del _router_module, _router_prefix, _router_tag
- ##################################
- #
- # Chat Endpoints
- #
- ##################################
def get_function_module(pipe_id: str):
    """Return the function module for ``pipe_id``, loading it on first use.

    Loaded modules are kept in ``app.state.FUNCTIONS``. When the module has
    both ``valves`` and ``Valves`` attributes, ``valves`` is rebuilt from the
    stored valve values on every call.
    """
    loaded_modules = app.state.FUNCTIONS

    if pipe_id in loaded_modules:
        module = loaded_modules[pipe_id]
    else:
        # First use: load the module and remember it for later calls.
        module, _, _ = load_function_module_by_id(pipe_id)
        loaded_modules[pipe_id] = module

    supports_valves = hasattr(module, "valves") and hasattr(module, "Valves")
    if supports_valves:
        stored_valves = Functions.get_function_valves_by_id(pipe_id)
        # Missing/empty stored values fall back to the Valves defaults.
        module.valves = module.Valves(**(stored_valves or {}))

    return module
async def get_function_models():
    """List active pipe functions as model entries.

    A function module with a ``pipes`` attribute is treated as a manifold and
    yields one entry per sub-pipe; any other module yields a single entry.

    Returns:
        A list of model dicts with ``id``, ``name``, ``object``, ``created``,
        ``owned_by`` and ``pipe`` keys.
    """

    def _model_entry(model_id, model_name, created_at, flag):
        # Common shape shared by manifold sub-pipes and single pipes.
        return {
            "id": model_id,
            "name": model_name,
            "object": "model",
            "created": created_at,
            "owned_by": "openai",
            "pipe": flag,
        }

    collected = []

    for pipe in Functions.get_functions_by_type("pipe", active_only=True):
        module = get_function_module(pipe.id)

        if not hasattr(module, "pipes"):
            log.debug(
                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
            )
            collected.append(
                _model_entry(pipe.id, pipe.name, pipe.created_at, {"type": "pipe"})
            )
            continue

        # Manifold: `pipes` is either a callable producing the list or the list.
        try:
            sub_pipes = module.pipes() if callable(module.pipes) else module.pipes
        except Exception as e:
            log.exception(e)
            sub_pipes = []

        log.debug(
            f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
        )

        for sub_pipe in sub_pipes:
            entry_id = f'{pipe.id}.{sub_pipe["id"]}'
            entry_name = sub_pipe["name"]
            # A module-level name is prepended to each sub-pipe name.
            if hasattr(module, "name"):
                entry_name = f"{module.name}{entry_name}"

            collected.append(
                _model_entry(entry_id, entry_name, pipe.created_at, {"type": pipe.type})
            )

    return collected
async def generate_function_chat_completion(
    form_data, user, models: dict | None = None
):
    """Run a pipe function as a chat completion backend.

    Args:
        form_data: Chat completion payload; ``metadata`` is popped from it and
            ``model`` selects the pipe (the part before the first ``.``).
        user: The requesting user.
        models: Optional mapping of model id to model dict, used to resolve the
            ``__model__`` value handed to tools. Defaults to an empty mapping.

    Returns:
        A ``StreamingResponse`` of SSE lines when ``form_data["stream"]`` is
        truthy; otherwise the pipe's dict/``StreamingResponse`` result, a
        dumped ``BaseModel``, an OpenAI-style completion message, or
        ``{"error": {"detail": ...}}`` if the pipe raised.
    """
    # A fresh dict per call; a `{}` default would be shared across calls.
    if models is None:
        models = {}

    async def execute_pipe(pipe, params):
        # Pipes may be coroutine functions or plain callables.
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        else:
            return pipe(**params)

    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
        # Collapse a string or (async) generator result into a single string.
        if isinstance(res, str):
            return res
        if isinstance(res, Generator):
            return "".join(map(str, res))
        if isinstance(res, AsyncGenerator):
            return "".join([str(stream) async for stream in res])

    def process_line(form_data: dict, line):
        # Normalise one streamed item into an SSE "data:" frame.
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
            line = f"data: {line}"
        if isinstance(line, dict):
            line = f"data: {json.dumps(line)}"

        try:
            line = line.decode("utf-8")
        except Exception:
            # Not bytes: keep the value as it is.
            pass

        if line.startswith("data:"):
            return f"{line}\n\n"
        else:
            line = openai_chat_chunk_message_template(form_data["model"], line)
            return f"data: {json.dumps(line)}\n\n"

    def get_pipe_id(form_data: dict) -> str:
        # "manifold.sub" style ids resolve to the part before the first dot.
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, _ = pipe_id.split(".", 1)
        return pipe_id

    def get_function_params(function_module, form_data, user, extra_params=None):
        if extra_params is None:
            extra_params = {}

        pipe_id = get_pipe_id(form_data)

        # Only pass the extra params that the pipe's signature declares.
        sig = inspect.signature(function_module.pipe)
        params = {"body": form_data} | {
            k: v for k, v in extra_params.items() if k in sig.parameters
        }

        if "__user__" in params and hasattr(function_module, "UserValves"):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
                # Fall back to default user valves when stored ones are unusable.
                params["__user__"]["valves"] = function_module.UserValves()

        return params

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", {})

    files = metadata.get("files", [])
    tool_ids = metadata.get("tool_ids", [])
    # Check if tool_ids is None
    if tool_ids is None:
        tool_ids = []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        # Event hooks are only created when all routing ids are present.
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)
        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
    }
    extra_params["__tools__"] = get_tools(
        app,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        # Custom models run on their base model with their stored parameters.
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module(pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)

    if form_data.get("stream", False):

        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            # String and sync-generator results get an explicit stop chunk.
            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)

        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)
async def get_all_base_models(request):
    """Collect function, OpenAI, Ollama and (optionally) arena models.

    Returns:
        A list of model dicts ordered as function models, OpenAI models,
        Ollama models, then arena models when arena evaluation is enabled.
    """
    config = app.state.config

    openai_models = []
    if config.ENABLE_OPENAI_API:
        openai_response = await openai.get_all_models(request)
        openai_models = openai_response["data"]

    ollama_models = []
    if config.ENABLE_OLLAMA_API:
        ollama_response = await ollama.get_all_models(request)
        for entry in ollama_response["models"]:
            ollama_models.append(
                {
                    "id": entry["model"],
                    "name": entry["name"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "ollama",
                    "ollama": entry,
                }
            )

    function_models = await get_function_models()
    models = function_models + openai_models + ollama_models

    # Add arena models
    if config.ENABLE_EVALUATION_ARENA_MODELS:
        configured_arena = config.EVALUATION_ARENA_MODELS
        # With nothing configured, the default arena model is used instead.
        if len(configured_arena) > 0:
            arena_definitions = configured_arena
        else:
            arena_definitions = [DEFAULT_ARENA_MODEL]

        arena_models = [
            {
                "id": definition["id"],
                "name": definition["name"],
                "info": {
                    "meta": definition["meta"],
                },
                "object": "model",
                "created": int(time.time()),
                "owned_by": "arena",
                "arena": True,
            }
            for definition in arena_definitions
        ]
        models = models + arena_models

    return models
@cached(ttl=3)
async def get_all_models(request):
    """Build the full model list and publish it on ``app.state.MODELS``.

    Base models are merged with stored custom models (overrides and presets),
    and each model gets an ``actions`` list resolved from its own and the
    global action ids that are currently enabled.

    Returns:
        The list of model dicts, or ``[]`` when only arena models (or none)
        are available.

    Raises:
        Exception: If an enabled action id has no stored function.
    """
    models = await get_all_base_models(request)

    # If there are no models, return an empty list
    if len([model for model in models if not model.get("arena", False)]) == 0:
        return []

    global_action_ids = [
        function.id for function in Functions.get_global_action_functions()
    ]
    # A set gives constant-time membership checks in the per-model loop below.
    enabled_action_ids = {
        function.id
        for function in Functions.get_functions_by_type("action", active_only=True)
    }

    custom_models = Models.get_all_models()
    for custom_model in custom_models:
        if custom_model.base_model_id is None:
            # Override of a base model. Iterate over a snapshot because
            # inactive matches are removed from `models` inside the loop.
            for model in list(models):
                if (
                    custom_model.id == model["id"]
                    or custom_model.id == model["id"].split(":")[0]
                ):
                    if custom_model.is_active:
                        model["name"] = custom_model.name
                        model["info"] = custom_model.model_dump()

                        action_ids = []
                        if "info" in model and "meta" in model["info"]:
                            action_ids.extend(
                                model["info"]["meta"].get("actionIds", [])
                            )

                        model["action_ids"] = action_ids
                    else:
                        models.remove(model)

        elif custom_model.is_active and (
            custom_model.id not in [model["id"] for model in models]
        ):
            # Preset built on top of a base model: inherit owner and pipe flag.
            owned_by = "openai"
            pipe = None
            action_ids = []

            for model in models:
                if (
                    custom_model.base_model_id == model["id"]
                    or custom_model.base_model_id == model["id"].split(":")[0]
                ):
                    owned_by = model["owned_by"]
                    if "pipe" in model:
                        pipe = model["pipe"]
                    break

            if custom_model.meta:
                meta = custom_model.meta.model_dump()
                if "actionIds" in meta:
                    action_ids.extend(meta["actionIds"])

            models.append(
                {
                    "id": f"{custom_model.id}",
                    "name": custom_model.name,
                    "object": "model",
                    "created": custom_model.created_at,
                    "owned_by": owned_by,
                    "info": custom_model.model_dump(),
                    "preset": True,
                    **({"pipe": pipe} if pipe is not None else {}),
                    "action_ids": action_ids,
                }
            )

    # Process action_ids to get the actions
    def get_action_items_from_module(function, module):
        # A module exposing `actions` contributes one item per action;
        # otherwise the function itself is the single action item.
        actions = []
        if hasattr(module, "actions"):
            actions = module.actions
            return [
                {
                    "id": f"{function.id}.{action['id']}",
                    "name": action.get("name", f"{function.name} ({action['id']})"),
                    "description": function.meta.description,
                    "icon_url": action.get(
                        "icon_url", function.meta.manifest.get("icon_url", None)
                    ),
                }
                for action in actions
            ]
        else:
            return [
                {
                    "id": function.id,
                    "name": function.name,
                    "description": function.meta.description,
                    "icon_url": function.meta.manifest.get("icon_url", None),
                }
            ]

    def get_function_module_by_id(function_id):
        # Load on first use and keep the module in app.state.FUNCTIONS.
        if function_id in app.state.FUNCTIONS:
            function_module = app.state.FUNCTIONS[function_id]
        else:
            function_module, _, _ = load_function_module_by_id(function_id)
            app.state.FUNCTIONS[function_id] = function_module
        # Without this return the caller always received None and per-action
        # expansion in get_action_items_from_module could never run.
        return function_module

    for model in models:
        # De-duplicate while keeping a stable order (model ids, then global).
        candidate_ids = dict.fromkeys(
            model.pop("action_ids", []) + global_action_ids
        )
        action_ids = [
            action_id for action_id in candidate_ids if action_id in enabled_action_ids
        ]

        model["actions"] = []
        for action_id in action_ids:
            action_function = Functions.get_function_by_id(action_id)
            if action_function is None:
                raise Exception(f"Action not found: {action_id}")

            function_module = get_function_module_by_id(action_id)
            model["actions"].extend(
                get_action_items_from_module(action_function, function_module)
            )

    log.debug(f"get_all_models() returned {len(models)} models")

    app.state.MODELS = {model["id"]: model for model in models}
    return models
@app.get("/api/models")
async def get_models(request: Request, user=Depends(get_verified_user)):
    """Return the models visible to the requesting user.

    Filter pipelines are dropped, the configured model order is applied, and
    for users with the ``"user"`` role (unless access control is bypassed)
    only readable models are kept.

    Returns:
        ``{"data": [...]}`` with the resulting model dicts.
    """

    def _is_filter_pipeline(entry):
        return "pipeline" in entry and entry["pipeline"].get("type", None) == "filter"

    def _user_can_read(entry):
        # Arena models carry their access control inside info.meta.
        if entry.get("arena"):
            return has_access(
                user.id,
                type="read",
                access_control=entry.get("info", {})
                .get("meta", {})
                .get("access_control", {}),
            )

        # Other models need a stored record the user owns or may read.
        record = Models.get_model_by_id(entry["id"])
        if not record:
            return False
        return user.id == record.user_id or has_access(
            user.id, type="read", access_control=record.access_control
        )

    all_models = await get_all_models(request)

    # Filter out filter pipelines
    models = [entry for entry in all_models if not _is_filter_pipeline(entry)]

    preferred_order = app.state.config.MODEL_ORDER_LIST
    if preferred_order:
        rank_by_id = {model_id: rank for rank, model_id in enumerate(preferred_order)}

        def _sort_key(entry):
            # Listed models first (by rank), the rest by name.
            return (rank_by_id.get(entry["id"], float("inf")), entry["name"])

        models = sorted(models, key=_sort_key)

    # Filter out models that the user does not have access to
    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        models = [entry for entry in models if _user_can_read(entry)]

    log.debug(
        f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}"
    )

    return {"data": models}
@app.get("/api/models/base")
async def get_base_models(request: Request, user=Depends(get_admin_user)):
    """Return every base model except arena models, wrapped as ``{"data": ...}``."""
    base_models = await get_all_base_models(request)

    # Filter out arena models
    non_arena_models = []
    for entry in base_models:
        if entry.get("arena", False):
            continue
        non_arena_models.append(entry)

    return {"data": non_arena_models}
@app.post("/api/chat/completions")
async def generate_chat_completions(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: bool = False,
):
    """Route a chat completion to the arena, pipe, Ollama or OpenAI backend.

    Args:
        request: The incoming request, forwarded to the backends.
        form_data: Chat completion payload; ``form_data["model"]`` selects the
            model.
        user: The verified requesting user.
        bypass_filter: Skip the per-user model access check when true. Forced
            to true when model access control is globally bypassed.

    Returns:
        The backend response; arena requests additionally report the
        ``selected_model_id``.

    Raises:
        HTTPException: 404 for an unknown model, 403/404 when access is denied.
    """
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    models = app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    model = models[model_id]

    # Check if user has access to the model
    if not bypass_filter and user.role == "user":
        if model.get("arena"):
            if not has_access(
                user.id,
                type="read",
                access_control=model.get("info", {})
                .get("meta", {})
                .get("access_control", {}),
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
        else:
            model_info = Models.get_model_by_id(model_id)
            if not model_info:
                raise HTTPException(
                    status_code=404,
                    detail="Model not found",
                )
            elif not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )

    if model["owned_by"] == "arena":
        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
        if model_ids and filter_mode == "exclude":
            # "exclude" mode: choose among every non-arena model not listed.
            model_ids = [
                candidate["id"]
                for candidate in await get_all_models(request)
                if candidate.get("owned_by") != "arena"
                and candidate["id"] not in model_ids
            ]

        selected_model_id = None
        if isinstance(model_ids, list) and model_ids:
            selected_model_id = random.choice(model_ids)
        else:
            # No usable candidate list: choose among all non-arena models.
            model_ids = [
                candidate["id"]
                for candidate in await get_all_models(request)
                if candidate.get("owned_by") != "arena"
            ]
            selected_model_id = random.choice(model_ids)

        form_data["model"] = selected_model_id

        if form_data.get("stream") == True:

            async def stream_wrapper(stream):
                # Tell the client which model was picked before relaying chunks.
                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
                async for chunk in stream:
                    yield chunk

            # `request` must be passed explicitly; without it every positional
            # argument shifts by one parameter.
            response = await generate_chat_completions(
                request, form_data, user, bypass_filter=True
            )
            return StreamingResponse(
                stream_wrapper(response.body_iterator), media_type="text/event-stream"
            )
        else:
            return {
                **(
                    await generate_chat_completions(
                        request, form_data, user, bypass_filter=True
                    )
                ),
                "selected_model_id": selected_model_id,
            }

    if model.get("pipe"):
        # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
        return await generate_function_chat_completion(
            form_data, user=user, models=models
        )

    if model["owned_by"] == "ollama":
        # Read the flag from the dict payload before it is converted;
        # attribute access (`form_data.stream`) fails on a dict.
        stream_requested = form_data.get("stream", False)

        # Using /ollama/api/chat endpoint
        form_data = convert_payload_openai_to_ollama(form_data)
        response = await generate_ollama_chat_completion(
            request=request,
            form_data=form_data,
            user=user,
            bypass_filter=bypass_filter,
        )
        if stream_requested:
            response.headers["content-type"] = "text/event-stream"
            return StreamingResponse(
                convert_streaming_response_ollama_to_openai(response),
                headers=dict(response.headers),
            )
        else:
            return convert_response_ollama_to_openai(response)
    else:
        # All arguments are passed by keyword; a positional argument after
        # `request=request` is a syntax error.
        return await generate_openai_chat_completion(
            request=request,
            form_data=form_data,
            user=user,
            bypass_filter=bypass_filter,
        )
@app.post("/api/chat/completed")
async def chat_completed(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """Run outlet processing on a completed chat message.

    The payload first goes through ``process_pipeline_outlet_filter`` and is
    then handed to the ``outlet`` hook of every enabled filter function
    (global filters plus the model's ``info.meta.filterIds``), ordered by the
    ``priority`` valve.

    Args:
        request: Incoming request.
        form_data: Payload; ``"model"``, ``"chat_id"``, ``"id"`` and
            ``"session_id"`` are read from it.
        user: Authenticated user resolved by ``get_verified_user``.

    Returns:
        The payload as returned by the last outlet, or a 400 ``JSONResponse``
        carrying the error text when an outlet raises.

    Raises:
        HTTPException: 404 when the model is unknown, 400 when the pipeline
            outlet filter fails.
    """
    await get_all_models(request)
    models = app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = models[model_id]

    try:
        data = process_pipeline_outlet_filter(request, data, user, models)
    except Exception as e:
        # The exception has to be raised: returning it would make FastAPI
        # serialize the object as an ordinary (successful) response body.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e

    event_info = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "session_id": data["session_id"],
    }
    # Each factory gets its own dict, as before.
    __event_emitter__ = get_event_emitter(dict(event_info))
    __event_call__ = get_event_call(dict(event_info))

    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            # TODO: Fix FunctionModel to include valves
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    if "info" in model and "meta" in model["info"]:
        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
    # Order-preserving de-duplication keeps the order of equal-priority
    # filters deterministic (the sort below is stable).
    filter_ids = list(dict.fromkeys(filter_ids))

    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }
    filter_ids = [
        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
    ]

    # Sort filter_ids by priority, using the get_priority function
    filter_ids.sort(key=get_priority)

    for filter_id in filter_ids:
        filter_record = Functions.get_function_by_id(filter_id)
        if not filter_record:
            continue

        # Loaded modules are cached on app.state.FUNCTIONS.
        if filter_id in app.state.FUNCTIONS:
            function_module = app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            app.state.FUNCTIONS[filter_id] = function_module

        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "outlet"):
            continue

        try:
            outlet = function_module.outlet

            # Get the signature of the function
            sig = inspect.signature(outlet)
            params = {"body": data}

            # Extra parameters to be passed to the function
            extra_params = {
                "__model__": model,
                "__id__": filter_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }

            # Pass only the extra params that the outlet's signature declares
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }

                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, user.id
                            )
                        )
                except Exception as e:
                    # Best effort: the outlet still runs without user valves.
                    log.exception(e)

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(outlet):
                data = await outlet(**params)
            else:
                data = outlet(**params)

        except Exception as e:
            log.exception(f"Error: {e}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

    return data
@app.post("/api/chat/actions/{action_id}")
async def chat_action(
    request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Invoke the ``action`` hook of an action function on a chat message.

    ``action_id`` may have the form ``"<function id>.<sub action id>"``. Only
    the first dot separates the two parts; when a sub action id is present it
    is what the hook receives as ``__id__``.

    Args:
        request: Incoming request.
        action_id: Function id, optionally followed by ``.`` and a sub id.
        form_data: Payload; ``"model"``, ``"chat_id"``, ``"id"`` and
            ``"session_id"`` are read from it.
        user: Authenticated user resolved by ``get_verified_user``.

    Returns:
        Whatever the hook returns, the unchanged payload when the module has
        no ``action`` attribute, or a 400 ``JSONResponse`` carrying the error
        text when the hook raises.

    Raises:
        HTTPException: 404 when the action function or the model is unknown.
    """
    if "." in action_id:
        # maxsplit=1 so an id containing several dots no longer fails to unpack.
        action_id, sub_action_id = action_id.split(".", 1)
    else:
        sub_action_id = None

    action_record = Functions.get_function_by_id(action_id)
    if not action_record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Action not found",
        )

    await get_all_models(request)
    models = app.state.MODELS

    data = form_data
    model_id = data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = models[model_id]

    event_info = {
        "chat_id": data["chat_id"],
        "message_id": data["id"],
        "session_id": data["session_id"],
    }
    # Each factory gets its own dict, as before.
    __event_emitter__ = get_event_emitter(dict(event_info))
    __event_call__ = get_event_call(dict(event_info))

    # Loaded modules are cached on app.state.FUNCTIONS.
    if action_id in app.state.FUNCTIONS:
        function_module = app.state.FUNCTIONS[action_id]
    else:
        function_module, _, _ = load_function_module_by_id(action_id)
        app.state.FUNCTIONS[action_id] = function_module

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))

    if hasattr(function_module, "action"):
        try:
            action = function_module.action

            # Get the signature of the function
            sig = inspect.signature(action)
            params = {"body": data}

            # Extra parameters to be passed to the function
            extra_params = {
                "__model__": model,
                "__id__": sub_action_id if sub_action_id is not None else action_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }

            # Pass only the extra params that the hook's signature declares
            for key, value in extra_params.items():
                if key in sig.parameters:
                    params[key] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }

                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                action_id, user.id
                            )
                        )
                except Exception as e:
                    # Best effort: the hook still runs without user valves.
                    log.exception(e)

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(action):
                data = await action(**params)
            else:
                data = action(**params)

        except Exception as e:
            log.exception(f"Error: {e}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

    return data
- ##################################
- #
- # Config Endpoints
- #
- ##################################
@app.get("/api/config")
async def get_app_config(request: Request):
    """Return the application configuration visible to the caller.

    A ``token`` cookie, when present, is decoded to look up the user; an
    undecodable token yields a 401. Requests without a resolved user get the
    basic fields only (plus ``onboarding`` when no users exist yet); requests
    with a user also get the extra feature flags and default settings.
    """
    user = None
    if "token" in request.cookies:
        try:
            token_payload = decode_token(request.cookies.get("token"))
        except Exception as e:
            log.debug(e)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token",
            )
        if token_payload is not None and "id" in token_payload:
            user = Users.get_user_by_id(token_payload["id"])

    has_user = user is not None
    # The user count is only consulted when no user was resolved.
    onboarding = (not has_user) and Users.get_num_users() == 0

    settings = app.state.config

    features = {
        "auth": WEBUI_AUTH,
        "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER),
        "enable_ldap": settings.ENABLE_LDAP,
        "enable_api_key": settings.ENABLE_API_KEY,
        "enable_signup": settings.ENABLE_SIGNUP,
        "enable_login_form": settings.ENABLE_LOGIN_FORM,
    }
    if has_user:
        features["enable_web_search"] = settings.ENABLE_RAG_WEB_SEARCH
        features["enable_image_generation"] = settings.ENABLE_IMAGE_GENERATION
        features["enable_community_sharing"] = settings.ENABLE_COMMUNITY_SHARING
        features["enable_message_rating"] = settings.ENABLE_MESSAGE_RATING
        features["enable_admin_export"] = ENABLE_ADMIN_EXPORT
        features["enable_admin_chat_access"] = ENABLE_ADMIN_CHAT_ACCESS

    # Keys are inserted in the same order as the response always had.
    payload = {}
    if onboarding:
        payload["onboarding"] = True
    payload["status"] = True
    payload["name"] = WEBUI_NAME
    payload["version"] = VERSION
    payload["default_locale"] = str(DEFAULT_LOCALE)
    payload["oauth"] = {
        "providers": {
            name: config.get("name", name)
            for name, config in OAUTH_PROVIDERS.items()
        }
    }
    payload["features"] = features

    if has_user:
        payload["default_models"] = settings.DEFAULT_MODELS
        payload["default_prompt_suggestions"] = settings.DEFAULT_PROMPT_SUGGESTIONS
        payload["audio"] = {
            "tts": {
                "engine": settings.TTS_ENGINE,
                "voice": settings.TTS_VOICE,
                "split_on": settings.TTS_SPLIT_ON,
            },
            "stt": {
                "engine": settings.STT_ENGINE,
            },
        }
        payload["file"] = {
            "max_size": settings.FILE_MAX_SIZE,
            "max_count": settings.FILE_MAX_COUNT,
        }
        payload["permissions"] = {**settings.USER_PERMISSIONS}

    return payload
class UrlForm(BaseModel):
    """Request body holding a single ``url`` string."""

    url: str
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    """Return the webhook URL currently stored in the app config."""
    current_url = app.state.config.WEBHOOK_URL
    return {"url": current_url}
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    """Store a new webhook URL and return the value now held by the config."""
    settings = app.state.config
    settings.WEBHOOK_URL = form_data.url
    # The value is read back from the config and mirrored onto app.state.
    app.state.WEBHOOK_URL = settings.WEBHOOK_URL
    return {"url": settings.WEBHOOK_URL}
@app.get("/api/version")
async def get_app_version():
    """Return the running application version."""
    return dict(version=VERSION)
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    """Report the current version alongside the latest released version.

    The latest release tag is fetched from the GitHub releases API with a
    one-second total timeout. In offline mode, or when the lookup fails for
    any reason, the current version is reported as the latest one.

    Returns:
        A dict with ``"current"`` and ``"latest"`` version strings.
    """
    if OFFLINE_MODE:
        log.debug(
            "Offline mode is enabled, returning current version as latest version"
        )
        return {"current": VERSION, "latest": VERSION}

    try:
        timeout = aiohttp.ClientTimeout(total=1)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]

                # Drop a leading "v" only when it is there, so a tag such as
                # "1.2.3" is not truncated to ".2.3".
                if latest_version.startswith("v"):
                    latest_version = latest_version[1:]

                return {"current": VERSION, "latest": latest_version}
    except Exception as e:
        # Best effort: any failure falls back to the current version.
        log.debug(e)
        return {"current": VERSION, "latest": VERSION}
@app.get("/api/changelog")
async def get_app_changelog():
    """Return the first five entries of ``CHANGELOG``, in its own order."""
    leading_keys = list(CHANGELOG)[:5]
    return {entry_key: CHANGELOG[entry_key] for entry_key in leading_keys}
- ############################
- # OAuth Login & Callback
- ############################
# authlib needs SessionMiddleware for OAuth, so it is only registered when at
# least one OAuth provider is configured.
if len(OAUTH_PROVIDERS) > 0:
    _session_middleware_options = dict(
        secret_key=WEBUI_SECRET_KEY,
        session_cookie="oui-session",
        same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
        https_only=WEBUI_SESSION_COOKIE_SECURE,
    )
    app.add_middleware(SessionMiddleware, **_session_middleware_options)
    del _session_middleware_options
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
    """Hand the login request for ``provider`` to ``oauth_manager``."""
    login_response = await oauth_manager.handle_login(provider, request)
    return login_response
# OAuth login logic is as follows:
# 1. Attempt to find a user with matching subject ID, tied to the provider
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
#    - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
#    - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
    """Hand the OAuth callback for ``provider`` to ``oauth_manager``."""
    callback_result = await oauth_manager.handle_callback(provider, request, response)
    return callback_result
@app.get("/manifest.json")
async def get_manifest_json():
    """Return the web app manifest, named after ``WEBUI_NAME``."""
    # Both icon entries point at the same image and differ only in "purpose".
    icons = [
        {
            "src": "/static/logo.png",
            "type": "image/png",
            "sizes": "500x500",
            "purpose": purpose,
        }
        for purpose in ("any", "maskable")
    ]

    manifest = {}
    manifest["name"] = WEBUI_NAME
    manifest["short_name"] = WEBUI_NAME
    manifest["description"] = (
        "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow."
    )
    manifest["start_url"] = "/"
    manifest["display"] = "standalone"
    manifest["background_color"] = "#343541"
    manifest["orientation"] = "natural"
    manifest["icons"] = icons
    return manifest
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    """Return the OpenSearch description document for this instance.

    ``WEBUI_NAME`` and ``WEBUI_URL`` are XML-escaped before being placed in
    the document, so characters such as ``&``, ``<`` or ``"`` in either value
    no longer produce malformed XML.
    """
    from xml.sax.saxutils import escape

    # Double quotes are escaped as well because the URL is also used inside
    # a double-quoted attribute value.
    name = escape(str(WEBUI_NAME), {'"': "&quot;"})
    base_url = escape(str(WEBUI_URL), {'"': "&quot;"})

    # "{{searchTerms}}" renders as the literal OpenSearch placeholder
    # "{searchTerms}".
    xml_content = f"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{name}</ShortName>
<Description>Search {name}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{base_url}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{base_url}/?q={{searchTerms}}"/>
<moz:SearchForm>{base_url}</moz:SearchForm>
</OpenSearchDescription>
"""
    return Response(content=xml_content, media_type="application/xml")
@app.get("/health")
async def healthcheck():
    """Answer with a constant positive status; no other component is touched."""
    return dict(status=True)
@app.get("/health/db")
async def healthcheck_with_db():
    """Run ``SELECT 1;`` through ``Session`` and then report a positive status.

    Any error raised by the query propagates to the caller.
    """
    probe_statement = text("SELECT 1;")
    Session.execute(probe_statement).all()
    return {"status": True}
# Static assets and cached files are always served.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

# The SPA is mounted at "/" only when the frontend build directory exists;
# otherwise a warning is logged and only the API is served.
if not os.path.exists(FRONTEND_BUILD_DIR):
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )
else:
    mimetypes.add_type("text/javascript", ".js")
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
|