12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308 |
- import base64
- import uuid
- from contextlib import asynccontextmanager
- from authlib.integrations.starlette_client import OAuth
- from authlib.oidc.core import UserInfo
- import json
- import time
- import os
- import sys
- import logging
- import aiohttp
- import requests
- import mimetypes
- import shutil
- import os
- import uuid
- import inspect
- from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
- from fastapi.staticfiles import StaticFiles
- from fastapi.responses import JSONResponse
- from fastapi import HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from sqlalchemy import text
- from starlette.exceptions import HTTPException as StarletteHTTPException
- from starlette.middleware.base import BaseHTTPMiddleware
- from starlette.middleware.sessions import SessionMiddleware
- from starlette.responses import StreamingResponse, Response, RedirectResponse
- from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
- from apps.ollama.main import (
- app as ollama_app,
- get_all_models as get_ollama_models,
- generate_openai_chat_completion as generate_ollama_chat_completion,
- )
- from apps.openai.main import (
- app as openai_app,
- get_all_models as get_openai_models,
- generate_chat_completion as generate_openai_chat_completion,
- )
- from apps.audio.main import app as audio_app
- from apps.images.main import app as images_app
- from apps.rag.main import app as rag_app
- from apps.webui.main import (
- app as webui_app,
- get_pipe_models,
- generate_function_chat_completion,
- )
- from apps.webui.internal.db import Session
- from pydantic import BaseModel
- from typing import List, Optional
- from apps.webui.models.auths import Auths
- from apps.webui.models.models import Models
- from apps.webui.models.tools import Tools
- from apps.webui.models.functions import Functions
- from apps.webui.models.users import Users
- from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
- from utils.utils import (
- get_admin_user,
- get_verified_user,
- get_current_user,
- get_http_authorization_cred,
- get_password_hash,
- create_token,
- )
- from utils.task import (
- title_generation_template,
- search_query_generation_template,
- tools_function_calling_generation_template,
- )
- from utils.misc import (
- get_last_user_message,
- add_or_update_system_message,
- prepend_to_first_user_message_content,
- parse_duration,
- )
- from apps.rag.utils import get_rag_context, rag_template
- from config import (
- WEBUI_NAME,
- WEBUI_URL,
- WEBUI_AUTH,
- ENV,
- VERSION,
- CHANGELOG,
- FRONTEND_BUILD_DIR,
- CACHE_DIR,
- STATIC_DIR,
- DEFAULT_LOCALE,
- ENABLE_OPENAI_API,
- ENABLE_OLLAMA_API,
- ENABLE_MODEL_FILTER,
- MODEL_FILTER_LIST,
- GLOBAL_LOG_LEVEL,
- SRC_LOG_LEVELS,
- WEBHOOK_URL,
- ENABLE_ADMIN_EXPORT,
- WEBUI_BUILD_HASH,
- TASK_MODEL,
- TASK_MODEL_EXTERNAL,
- TITLE_GENERATION_PROMPT_TEMPLATE,
- SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
- SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
- TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
- SAFE_MODE,
- OAUTH_PROVIDERS,
- ENABLE_OAUTH_SIGNUP,
- OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
- WEBUI_SECRET_KEY,
- WEBUI_SESSION_COOKIE_SAME_SITE,
- WEBUI_SESSION_COOKIE_SECURE,
- AppConfig,
- )
- from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
- from utils.webhook import post_webhook
# In safe mode every function plugin is switched off before anything loads it.
if SAFE_MODE:
    print("SAFE MODE ENABLED")
    Functions.deactivate_all_functions()

# The root handler writes to stdout at the global level; this module's own
# logger then gets the level configured under SRC_LOG_LEVELS["MAIN"].
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
class SPAStaticFiles(StaticFiles):
    """Static file server with a single-page-app fallback.

    A request that would produce a 404 is answered with ``index.html``
    instead (presumably so the frontend's client-side router can resolve the
    path). Any other HTTP error is re-raised unchanged.
    """

    async def get_response(self, path: str, scope):
        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            if ex.status_code != 404:
                raise ex
            # Unknown path: serve the app shell instead of a 404 page.
            return await super().get_response("index.html", scope)
# Startup banner: logo, version and project URL. The "Commit:" line is only
# filled in when WEBUI_BUILD_HASH is something other than the literal
# "dev-build"; otherwise that banner line is left empty.
print(
    rf"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|

v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
def run_migrations(config_path: str = "alembic.ini"):
    """Upgrade the database schema to the latest Alembic revision ("head").

    Failures are logged with their traceback rather than raised, so startup
    continues (the same best-effort behaviour as before, but diagnosable).

    :param config_path: Path to the Alembic ini file. Defaults to
        ``"alembic.ini"``, which is resolved against the current working
        directory.
    """
    try:
        # The imports live inside the try, so a missing/broken Alembic install
        # is reported through the same path as a failed migration.
        from alembic.config import Config
        from alembic import command

        alembic_cfg = Config(config_path)
        command.upgrade(alembic_cfg, "head")
    except Exception:
        # Previously only str(e) was printed to stdout, hiding the traceback.
        log.exception("Error running database migrations")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    On startup the database migrations are run; then control is handed to
    the server. Nothing is done on shutdown.
    """
    run_migrations()
    yield
# Interactive docs are exposed only when ENV == "dev"; ReDoc is always off.
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    redoc_url=None,
    lifespan=lifespan,
)

# Seed the runtime configuration object with the values imported from
# `config`. setattr(obj, name, value) is exactly `obj.name = value`, and the
# tuple keeps the original assignment order.
app.state.config = AppConfig()
for _config_name, _config_value in (
    ("ENABLE_OPENAI_API", ENABLE_OPENAI_API),
    ("ENABLE_OLLAMA_API", ENABLE_OLLAMA_API),
    ("ENABLE_MODEL_FILTER", ENABLE_MODEL_FILTER),
    ("MODEL_FILTER_LIST", MODEL_FILTER_LIST),
    ("WEBHOOK_URL", WEBHOOK_URL),
    ("TASK_MODEL", TASK_MODEL),
    ("TASK_MODEL_EXTERNAL", TASK_MODEL_EXTERNAL),
    ("TITLE_GENERATION_PROMPT_TEMPLATE", TITLE_GENERATION_PROMPT_TEMPLATE),
    (
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
    ),
    (
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
        SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
    ),
    (
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    ),
):
    setattr(app.state.config, _config_name, _config_value)
del _config_name, _config_value

# Model registry keyed by model id. It starts empty; check_url awaits
# get_all_models() while it is empty (presumably that call fills it).
app.state.MODELS = {}

origins = ["*"]

##################################
#
# ChatCompletion Middleware
#
##################################
async def get_body_and_model_and_user(request):
    """Parse a chat request and resolve its model entry and user.

    :param request: The incoming request.
    :returns: ``(body, model, user)``:
        - ``body``: the decoded JSON body (``{}`` for an empty body);
        - ``model``: the ``app.state.MODELS`` entry named by ``body["model"]``;
        - ``user``: the value returned by ``get_current_user``.
    :raises Exception: if the body is not a JSON object, lacks ``model``, or
        names a model that is not in ``app.state.MODELS``. The middleware
        turns any exception into a 400 whose detail is ``str(e)``.
    """
    # Read the original request body
    raw_body = await request.body()
    body_str = raw_body.decode("utf-8")
    body = json.loads(body_str) if body_str else {}

    # Without this check, a missing key surfaced to the client as the
    # unhelpful detail "'model'" (str of a KeyError).
    if not isinstance(body, dict) or "model" not in body:
        raise Exception("Model not specified")

    model_id = body["model"]
    if model_id not in app.state.MODELS:
        raise Exception("Model not found")
    model = app.state.MODELS[model_id]

    user = get_current_user(
        request,
        get_http_authorization_cred(request.headers.get("Authorization")),
    )
    return body, model, user
def get_task_model_id(default_model_id):
    """Choose the model used for background tasks such as tool selection.

    Models owned by "ollama" can be overridden by ``TASK_MODEL``; all other
    models can be overridden by ``TASK_MODEL_EXTERNAL``. An override is used
    only when it is set and is a key of ``app.state.MODELS``; otherwise
    ``default_model_id`` is returned.

    :raises KeyError: if ``default_model_id`` is not in ``app.state.MODELS``.
    """
    owned_by_ollama = app.state.MODELS[default_model_id]["owned_by"] == "ollama"
    candidate = (
        app.state.config.TASK_MODEL
        if owned_by_ollama
        else app.state.config.TASK_MODEL_EXTERNAL
    )
    if candidate and candidate in app.state.MODELS:
        return candidate
    return default_model_id
def get_filter_function_ids(model):
    """Return ids of the active filter functions that apply to ``model``.

    Candidates are every global filter plus the ids listed in
    ``model["info"]["meta"]["filterIds"]`` (when present). Candidates that
    are not currently active filters are dropped. Duplicates are removed,
    and the result is sorted by each function's ``priority`` valve (a missing
    priority counts as 0).

    :param model: A model entry from ``app.state.MODELS``.
    :returns: List of function ids, lowest priority first.
    """

    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    candidate_ids = [
        function.id for function in Functions.get_global_filter_functions()
    ]
    if "info" in model and "meta" in model["info"]:
        # `or []` also covers an explicit null filterIds.
        candidate_ids.extend(model["info"]["meta"].get("filterIds") or [])

    # dict.fromkeys de-duplicates while keeping first-seen order. The old
    # list(set(...)) ordering varied between processes (str hashes are
    # randomized), so equal-priority filters ran in an arbitrary order.
    candidate_ids = list(dict.fromkeys(candidate_ids))

    # Build the active-id set once for O(1) membership tests.
    enabled_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }
    filter_ids = [fid for fid in candidate_ids if fid in enabled_ids]

    # list.sort is stable: equal priorities keep the order established above.
    filter_ids.sort(key=get_priority)
    return filter_ids
async def get_function_call_response(
    messages,
    files,
    tool_id,
    template,
    task_model_id,
    user,
    __event_emitter__=None,
    __event_call__=None,
):
    """Let the task model pick a function of a tool, then call it.

    The tool's specs are rendered into ``template`` as the system prompt. The
    user prompt holds the last four messages (newest first) and the latest
    user message. The task model's answer is parsed as JSON with a ``name``
    and optional ``parameters``.

    :param messages: Chat messages (dicts with ``role`` and ``content``).
    :param files: Chat files, forwarded as ``__files__`` when requested.
    :param tool_id: Id of the tool whose functions may be called.
    :param template: Function-calling prompt template.
    :param task_model_id: Model used to choose the function.
    :param user: Authenticated user object.
    :param __event_emitter__: Forwarded to the function when it declares it.
    :param __event_call__: Forwarded to the function when it declares it.
    :returns: ``(function_result, citation, file_handler)``. Returns
        ``(None, None, False)`` when the tool is missing, no valid function
        was chosen, the call failed, or the function returned ``None``.
    :raises Exception: propagated from ``filter_pipeline`` or when
        ``task_model_id`` is not in ``app.state.MODELS``.
    """
    tool = Tools.get_tool_by_id(tool_id)
    if tool is None:
        log.warning(f"Tool not found: {tool_id}")
        return None, None, False

    tools_specs = json.dumps(tool.specs, indent=2)
    system_prompt = tools_function_calling_generation_template(template, tools_specs)

    user_message = get_last_user_message(messages)
    history = "\n".join(
        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
        for message in messages[::-1][:4]
    )
    prompt = f"History:\n{history}\nQuery: {user_message}"
    # The prompt contains user chat content: keep it out of default output.
    log.debug(prompt)

    payload = {
        "model": task_model_id,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Query: {prompt}"},
        ],
        "stream": False,
        "task": str(TASKS.FUNCTION_CALLING),
    }
    payload = filter_pipeline(payload, user)
    model = app.state.MODELS[task_model_id]

    try:
        response = await generate_chat_completions(form_data=payload, user=user)

        content = None
        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]
            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]

        if content is None:
            return None, None, False

        # Parse the function response
        log.debug(f"content: {content}")
        result = json.loads(content)
        log.debug(result)

        if not isinstance(result, dict) or "name" not in result:
            return None, None, False
        function_name = result["name"]

        # Resolve the toolkit module, caching it on webui_app.state.TOOLS.
        if tool_id in webui_app.state.TOOLS:
            toolkit_module = webui_app.state.TOOLS[tool_id]
        else:
            toolkit_module, _ = load_toolkit_module_by_id(tool_id)
            webui_app.state.TOOLS[tool_id] = toolkit_module

        # Honour the attribute's value, not merely its presence: a tool that
        # sets file_handler = False must not cause the chat files to be dropped.
        file_handler = bool(getattr(toolkit_module, "file_handler", False))
        log.debug(f"file_handler: {file_handler}")

        if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
            valves = Tools.get_tool_valves_by_id(tool_id)
            toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))

        # The name comes from model output, which users can steer. Never let it
        # reach private/dunder attributes or non-callables on the module.
        function = None
        if isinstance(function_name, str) and not function_name.startswith("_"):
            function = getattr(toolkit_module, function_name, None)
        if not callable(function):
            log.warning(f"Tool {tool_id} has no callable function {function_name!r}")
            return None, None, False

        citation = None
        function_result = None
        try:
            sig = inspect.signature(function)
            # Models may omit "parameters" for functions that take no arguments.
            params = dict(result.get("parameters") or {})

            # Extra parameters, passed only when the function declares them.
            extra_params = {
                "__model__": model,
                "__id__": tool_id,
                "__messages__": messages,
                "__files__": files,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }
            params.update(
                {
                    key: value
                    for key, value in extra_params.items()
                    if key in sig.parameters
                }
            )

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }
                try:
                    if hasattr(toolkit_module, "UserValves"):
                        __user__["valves"] = toolkit_module.UserValves(
                            **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
                        )
                except Exception as e:
                    log.error(f"Failed to load user valves for tool {tool_id}: {e}")
                params["__user__"] = __user__

            if inspect.iscoroutinefunction(function):
                function_result = await function(**params)
            else:
                function_result = function(**params)

            if hasattr(toolkit_module, "citation") and toolkit_module.citation:
                citation = {
                    "source": {"name": f"TOOL:{tool.name}/{function_name}"},
                    "document": [function_result],
                    "metadata": [{"source": function_name}],
                }
        except Exception as e:
            log.exception(f"Error calling {tool_id}.{function_name}: {e}")

        # Add the function result to the system prompt
        if function_result is not None:
            return function_result, citation, file_handler
    except Exception as e:
        log.exception(f"Error: {e}")

    return None, None, False
async def chat_completion_functions_handler(
    body, model, user, __event_emitter__, __event_call__
):
    """Run the ``inlet`` of every applicable filter function over ``body``.

    Filters come from ``get_filter_function_ids(model)`` and run in that
    order. Each ``inlet`` receives the body returned by the previous one.
    Loaded modules are cached in ``webui_app.state.FUNCTIONS``.

    If any applicable filter module has a truthy ``file_handler`` attribute,
    the ``files`` key is removed from the final body.

    :returns: ``(body, {})``.
    :raises Exception: re-raises any error raised while calling an ``inlet``.
    """
    skip_files = False

    for filter_id in get_filter_function_ids(model):
        filter_function = Functions.get_function_by_id(filter_id)
        if not filter_function:
            continue

        if filter_id in webui_app.state.FUNCTIONS:
            function_module = webui_app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            webui_app.state.FUNCTIONS[filter_id] = function_module

        # Any filter that handles files wins. Previously the last filter that
        # merely defined the attribute (even as False) overwrote earlier claims.
        if getattr(function_module, "file_handler", False):
            skip_files = True

        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "inlet"):
            continue

        try:
            inlet = function_module.inlet
            sig = inspect.signature(inlet)

            params = {"body": body}
            # Extra parameters, passed only when the inlet declares them.
            extra_params = {
                "__model__": model,
                "__id__": filter_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }
            params.update(
                {
                    key: value
                    for key, value in extra_params.items()
                    if key in sig.parameters
                }
            )

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }
                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, user.id
                            )
                        )
                except Exception as e:
                    log.error(f"Failed to load user valves for {filter_id}: {e}")
                params["__user__"] = __user__

            if inspect.iscoroutinefunction(inlet):
                body = await inlet(**params)
            else:
                body = inlet(**params)
        except Exception as e:
            log.exception(f"Error: {e}")
            # Bare raise keeps the original traceback intact.
            raise

    if skip_files and "files" in body:
        del body["files"]

    return body, {}
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
    """Run the tools listed in ``body["tool_ids"]`` and gather their output.

    ``tool_ids`` is removed from the body. For each tool id,
    ``get_function_call_response`` lets the task model pick and call a
    function of that tool. String results become contexts and citations are
    collected; a failing tool is logged and skipped. If any tool reports
    ``file_handler``, ``files`` is removed from the body.

    :param body: Chat request body; must contain ``model`` (and ``messages``
        when tools are used).
    :param user: Authenticated user.
    :returns: ``(body, flags)``. ``flags["contexts"]`` is always a list;
        ``flags["citations"]`` is present only when a tool produced one.
    """
    skip_files = False
    contexts = []
    citations = None
    task_model_id = get_task_model_id(body["model"])

    # pop() up front: a null "tool_ids" now means "no tools". Before, it raised
    # TypeError in the loop and the key stayed in the body sent downstream.
    tool_ids = body.pop("tool_ids", None) or []

    for tool_id in tool_ids:
        log.debug(f"tool_id: {tool_id}")
        try:
            response, citation, file_handler = await get_function_call_response(
                messages=body["messages"],
                files=body.get("files", []),
                tool_id=tool_id,
                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                task_model_id=task_model_id,
                user=user,
                __event_emitter__=__event_emitter__,
                __event_call__=__event_call__,
            )
        except Exception as e:
            log.exception(f"Error: {e}")
            continue

        if isinstance(response, str):
            contexts.append(response)
        if citation:
            if citations is None:
                citations = []
            citations.append(citation)
        if file_handler:
            skip_files = True

    if tool_ids:
        log.debug(f"tool_contexts: {contexts}")

    if skip_files and "files" in body:
        del body["files"]

    flags = {"contexts": contexts}
    if citations is not None:
        flags["citations"] = citations
    return body, flags
async def chat_completion_files_handler(body):
    """Turn ``body["files"]`` into retrieval contexts and citations.

    When the body has a ``files`` key, it is removed. The files are passed to
    ``get_rag_context`` together with the RAG app's embedding function,
    top-k, reranking function, relevance threshold and hybrid-search flag.

    :returns: ``(body, flags)``. With no files, ``flags`` is
        ``{"contexts": []}``. Otherwise ``flags`` has ``"contexts"`` and/or
        ``"citations"`` for whichever of them ``get_rag_context`` returned
        as non-``None``.
    """
    if "files" not in body:
        return body, {"contexts": []}

    files = body.pop("files")
    contexts, citations = get_rag_context(
        files=files,
        messages=body["messages"],
        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
        k=rag_app.state.config.TOP_K,
        reranking_function=rag_app.state.sentence_transformer_rf,
        r=rag_app.state.config.RELEVANCE_THRESHOLD,
        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
    )
    log.debug(f"rag_contexts: {contexts}, citations: {citations}")

    flags = {}
    if contexts is not None:
        flags["contexts"] = contexts
    if citations is not None:
        flags["citations"] = citations
    return body, flags
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    """Enrich chat-completion requests before they reach the model backends.

    For POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions``, this middleware:

    1. resolves the model and user from the body (400 on failure);
    2. removes ``valves``, ``session_id``, ``chat_id`` and ``id`` from the
       body and re-attaches them under ``body["metadata"]``;
    3. runs filter-function inlets (400 on failure), then tools, then file
       retrieval, collecting contexts and citations;
    4. injects the collected contexts into the messages via the RAG template;
    5. for SSE / NDJSON streaming responses, sends the citations ahead of the
       model output.

    All other requests pass through untouched.
    """

    async def dispatch(self, request: Request, call_next):
        is_chat_completion = request.method == "POST" and any(
            endpoint in request.url.path
            for endpoint in ["/ollama/api/chat", "/chat/completions"]
        )
        if not is_chat_completion:
            # If it's not a chat completion request, just pass it through
            return await call_next(request)

        log.debug(f"request.url.path: {request.url.path}")

        try:
            body, model, user = await get_body_and_model_and_user(request)
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        # These keys are stripped from the body sent downstream and
        # re-attached under body["metadata"] further below.
        valves = body.pop("valves", None)
        session_id = body.pop("session_id", None)
        chat_id = body.pop("chat_id", None)
        message_id = body.pop("id", None)

        __event_emitter__ = await get_event_emitter(
            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
        )
        __event_call__ = await get_event_call(
            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
        )

        # Extra items streamed to the client ahead of the model's output.
        data_items = []
        contexts = []
        citations = []

        # A failing filter inlet rejects the request.
        try:
            body, _ = await chat_completion_functions_handler(
                body, model, user, __event_emitter__, __event_call__
            )
        except Exception as e:
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

        # Tools and file retrieval are best-effort: on failure the error is
        # logged and the request proceeds without their context.
        try:
            body, flags = await chat_completion_tools_handler(
                body, user, __event_emitter__, __event_call__
            )
            contexts.extend(flags.get("contexts", []))
            citations.extend(flags.get("citations", []))
        except Exception as e:
            log.exception(f"Error in tools handler: {e}")

        try:
            body, flags = await chat_completion_files_handler(body)
            contexts.extend(flags.get("contexts", []))
            citations.extend(flags.get("citations", []))
        except Exception as e:
            log.exception(f"Error in files handler: {e}")

        # If context is not empty, insert it into the messages
        if contexts:
            # Join with real newlines (this used to be the literal "/n").
            context_string = "\n".join(contexts).strip()
            prompt = get_last_user_message(body["messages"])
            rag_prompt = rag_template(
                rag_app.state.config.RAG_TEMPLATE, context_string, prompt
            )
            # Workaround for Ollama 2.0+ system prompt issue
            # TODO: replace with add_or_update_system_message
            if model["owned_by"] == "ollama":
                body["messages"] = prepend_to_first_user_message_content(
                    rag_prompt, body["messages"]
                )
            else:
                body["messages"] = add_or_update_system_message(
                    rag_prompt, body["messages"]
                )

        # If there are citations, add them to the data_items
        if citations:
            data_items.append({"citations": citations})

        body["metadata"] = {
            "session_id": session_id,
            "chat_id": chat_id,
            "message_id": message_id,
            "valves": valves,
        }

        modified_body_bytes = json.dumps(body).encode("utf-8")
        # Replace the request body with the modified one
        request._body = modified_body_bytes
        # Set custom header to ensure content-length matches new body length
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
            *[
                (k, v)
                for k, v in request.headers.raw
                if k.lower() != b"content-length"
            ],
        ]

        response = await call_next(request)
        if not isinstance(response, StreamingResponse):
            return response

        # Content-Type may be absent; `in None` used to raise TypeError here.
        content_type = response.headers.get("Content-Type") or ""
        if "text/event-stream" in content_type:
            wrapper = self.openai_stream_wrapper
        elif "application/x-ndjson" in content_type:
            wrapper = self.ollama_stream_wrapper
        else:
            return response

        # Keep the backend's status code and media type; previously the
        # re-wrapped stream was always a 200 with no Content-Type.
        return StreamingResponse(
            wrapper(response.body_iterator, data_items),
            status_code=response.status_code,
            media_type=content_type,
        )

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}

    async def openai_stream_wrapper(self, original_generator, data_items):
        """Yield each data item as an SSE ``data:`` event, then the original stream."""
        for item in data_items:
            yield f"data: {json.dumps(item)}\n\n"
        async for data in original_generator:
            yield data

    async def ollama_stream_wrapper(self, original_generator, data_items):
        """Yield each data item as one NDJSON line, then the original stream."""
        for item in data_items:
            yield f"{json.dumps(item)}\n"
        async for data in original_generator:
            yield data
# Starlette runs the most recently added middleware first (outermost), so
# middleware registered after this line sees requests before this one does.
app.add_middleware(ChatCompletionMiddleware)

##################################
#
# Pipeline Middleware
#
##################################
def get_sorted_filters(model_id):
    """Return the pipeline filter models that apply to ``model_id``.

    A model in ``app.state.MODELS`` qualifies when:
    - its ``pipeline`` entry has ``type == "filter"``; and
    - its ``pipeline["pipelines"]`` is exactly ``["*"]`` or contains
      ``model_id``.

    The result is sorted ascending by ``pipeline["priority"]``; the sort is
    stable, so ties keep registry order.
    """

    def targets_model(candidate):
        if "pipeline" not in candidate:
            return False
        pipeline = candidate["pipeline"]
        if "type" not in pipeline or pipeline["type"] != "filter":
            return False
        targets = pipeline["pipelines"]
        return targets == ["*"] or any(model_id == target for target in targets)

    matching = filter(targets_model, app.state.MODELS.values())
    return sorted(matching, key=lambda entry: entry["pipeline"]["priority"])
def filter_pipeline(payload, user):
    """Pass ``payload`` through the inlet of every applicable pipeline filter.

    Filters come from ``get_sorted_filters(payload["model"])``. If the target
    model itself has a ``pipeline`` entry, it runs last. Each filter is called
    as ``POST {url}/{filter_id}/filter/inlet``, where ``url`` and the API key
    come from the OpenAI app config at the filter's ``urlIdx``. Filters whose
    key is ``""`` are skipped. A filter's JSON response replaces the payload.

    :param payload: Chat-completion request body.
    :param user: Authenticated user; only id/email/name/role are sent.
    :returns: The (possibly rewritten) payload.
    :raises Exception: ``Exception(status_code, detail)`` when a filter
        responds with an error whose JSON body contains ``detail``. Any other
        per-filter failure is logged and that filter is skipped.
    """
    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
    model_id = payload.get("model")
    sorted_filters = get_sorted_filters(model_id)

    # Unknown model ids are left for downstream validation; previously this
    # raised an uncaught KeyError.
    model = app.state.MODELS.get(model_id)
    if model is not None and "pipeline" in model:
        sorted_filters.append(model)

    for filter_model in sorted_filters:
        r = None
        try:
            url_idx = filter_model["urlIdx"]
            url = openai_app.state.config.OPENAI_API_BASE_URLS[url_idx]
            key = openai_app.state.config.OPENAI_API_KEYS[url_idx]

            # Filters are only called on endpoints with an API key configured.
            if key == "":
                continue

            # NOTE(review): blocking HTTP call with no timeout. Callers are
            # async, so a slow pipeline server stalls the event loop.
            r = requests.post(
                f"{url}/{filter_model['id']}/filter/inlet",
                headers={"Authorization": f"Bearer {key}"},
                json={
                    "user": user,
                    "body": payload,
                },
            )
            r.raise_for_status()
            payload = r.json()
        except Exception as e:
            # Handle connection error here
            log.error(f"Connection error: {e}")
            if r is not None:
                # A non-JSON error body used to raise from inside this handler,
                # escaping without the (status_code, detail) shape.
                try:
                    res = r.json()
                except ValueError:
                    res = None
                if isinstance(res, dict) and "detail" in res:
                    raise Exception(r.status_code, res["detail"])

    return payload
class PipelineMiddleware(BaseHTTPMiddleware):
    """Run pipeline filter inlets over chat-completion request bodies.

    For POST requests whose path contains ``/ollama/api/chat`` or
    ``/chat/completions``:
    - the JSON body is passed through ``filter_pipeline``;
    - the result replaces the request body before the request continues.

    Failures become JSON error responses instead of unhandled exceptions.
    Other requests pass through unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            try:
                # Read the original request body and parse it as JSON
                body = await request.body()
                body_str = body.decode("utf-8")
                data = json.loads(body_str) if body_str else {}

                user = get_current_user(
                    request,
                    get_http_authorization_cred(request.headers.get("Authorization")),
                )
                data = filter_pipeline(data, user)
            except StarletteHTTPException as e:
                # Keep the status/detail of HTTP errors (presumably raised by
                # get_current_user for auth failures) instead of a 500.
                return JSONResponse(
                    status_code=e.status_code,
                    content={"detail": e.detail},
                )
            except Exception as e:
                # filter_pipeline reports pipeline errors as
                # Exception(status_code, detail); anything else (bad JSON,
                # unexpected errors) previously crashed on e.args[1].
                if len(e.args) >= 2 and isinstance(e.args[0], int):
                    status_code, detail = e.args[0], e.args[1]
                else:
                    status_code, detail = status.HTTP_400_BAD_REQUEST, str(e)
                return JSONResponse(
                    status_code=status_code,
                    content={"detail": detail},
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
            request._body = modified_body_bytes
            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(PipelineMiddleware)

# NOTE(review): `origins` is ["*"] and credentials are allowed. Starlette's
# CORSMiddleware then echoes the caller's Origin, so any site can make
# credentialed cross-origin requests — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
    """Commit the shared database ``Session`` after the handler returns.

    If the downstream handler raises, the exception propagates and no commit
    is attempted.
    """
    response = await call_next(request)
    log.debug("Commit session after request")
    Session.commit()
    return response
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Load the model list on demand and time every request.

    While ``app.state.MODELS`` is empty, ``get_all_models()`` is awaited
    before the request is handled. The time spent in the downstream handler,
    in whole seconds, is returned in the ``X-Process-Time`` header.
    """
    if not app.state.MODELS:
        await get_all_models()

    # perf_counter is monotonic. Wall-clock time.time() can jump (NTP sync,
    # manual changes) and yield negative or inflated durations.
    start_time = time.perf_counter()
    response = await call_next(request)
    # Header format unchanged: integer seconds rendered as a string.
    process_time = int(time.perf_counter() - start_time)
    response.headers["X-Process-Time"] = str(process_time)
    return response
@app.middleware("http")
async def update_embedding_function(request: Request, call_next):
    """Re-sync webui_app's embedding function after an embedding update.

    Once a request whose path contains ``/embedding/update`` has been
    handled, ``rag_app.state.EMBEDDING_FUNCTION`` is copied onto
    ``webui_app.state.EMBEDDING_FUNCTION``.
    """
    response = await call_next(request)
    embedding_changed = "/embedding/update" in request.url.path
    if embedding_changed:
        webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
    return response
# Mount the sub-applications under their URL prefixes.
# NOTE(review): Starlette matches routes in registration order, so re-check
# any overlapping prefixes before reordering these mounts.
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/api/v1", webui_app)

# Share the RAG app's embedding function with the webui app at startup. The
# update_embedding_function middleware re-copies it after each
# "/embedding/update" request.
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
    """Build the combined model list and refresh the shared model caches.

    Sources:
      * function pipes (``get_pipe_models``),
      * OpenAI-compatible backends when ``ENABLE_OPENAI_API`` is set,
      * Ollama when ``ENABLE_OLLAMA_API`` is set.

    Custom models are then applied:
      * A custom model without a base model overrides the name/info of a
        matching existing model. The match is on id, optionally ignoring a
        ``:tag`` suffix.
      * A custom model with a base model is appended as a new "preset" entry.

    Finally, each model gets an ``actions`` list. It holds the model's own
    action ids plus the global ones, restricted to active action functions.

    Side effects:
        Replaces ``app.state.MODELS`` and ``webui_app.state.MODELS`` with a
        dict keyed by model id.

    Returns:
        The list of model dicts.
    """
    # TODO: Optimize this function
    openai_models = []
    ollama_models = []

    pipe_models = await get_pipe_models()

    if app.state.config.ENABLE_OPENAI_API:
        openai_models = (await get_openai_models())["data"]

    if app.state.config.ENABLE_OLLAMA_API:
        ollama_response = await get_ollama_models()
        ollama_models = [
            {
                "id": model["model"],
                "name": model["name"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "ollama",
                "ollama": model,
            }
            for model in ollama_response["models"]
        ]

    models = pipe_models + openai_models + ollama_models

    global_action_ids = [
        function.id for function in Functions.get_global_action_functions()
    ]
    # A set: this is only used for membership tests in the per-model loop.
    enabled_action_ids = {
        function.id
        for function in Functions.get_functions_by_type("action", active_only=True)
    }

    for custom_model in Models.get_all_models():
        if custom_model.base_model_id is None:
            # Override an existing model with the same id (or same id minus
            # its ":tag" suffix).
            for model in models:
                if (
                    custom_model.id == model["id"]
                    or custom_model.id == model["id"].split(":")[0]
                ):
                    model["name"] = custom_model.name
                    model["info"] = custom_model.model_dump()

                    action_ids = []
                    if "info" in model and "meta" in model["info"]:
                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
                    model["action_ids"] = action_ids
        else:
            # Preset: inherit ownership / pipe / action ids from the first
            # matching base model; default to "openai" when none matches.
            owned_by = "openai"
            pipe = None
            action_ids = []

            for model in models:
                if (
                    custom_model.base_model_id == model["id"]
                    or custom_model.base_model_id == model["id"].split(":")[0]
                ):
                    owned_by = model["owned_by"]
                    if "pipe" in model:
                        pipe = model["pipe"]

                    if "info" in model and "meta" in model["info"]:
                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
                    break

            models.append(
                {
                    "id": custom_model.id,
                    "name": custom_model.name,
                    "object": "model",
                    "created": custom_model.created_at,
                    "owned_by": owned_by,
                    "info": custom_model.model_dump(),
                    "preset": True,
                    **({"pipe": pipe} if pipe is not None else {}),
                    "action_ids": action_ids,
                }
            )

    # Each action id's function record is looked up at most once, instead of
    # once per model that references it.
    action_records = {}

    for model in models:
        action_ids = model.pop("action_ids", []) + global_action_ids
        # dict.fromkeys de-duplicates while keeping a stable order;
        # list(set(...)) produced an arbitrary order.
        action_ids = [
            action_id
            for action_id in dict.fromkeys(action_ids)
            if action_id in enabled_action_ids
        ]

        model["actions"] = []
        for action_id in action_ids:
            if action_id not in action_records:
                action_records[action_id] = Functions.get_function_by_id(action_id)
            action = action_records[action_id]
            if action is None:
                # The record could not be loaded; skip it rather than fail
                # with an AttributeError on `action.name`.
                continue
            model["actions"].append(
                {
                    "id": action_id,
                    "name": action.name,
                    "description": action.meta.description,
                    "icon_url": action.meta.manifest.get("icon_url", None),
                }
            )

    app.state.MODELS = {model["id"]: model for model in models}
    webui_app.state.MODELS = app.state.MODELS

    return models
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
    """Return the models visible to ``user`` as ``{"data": [...]}``.

    Filter pipelines are never listed. When the admin model filter is
    enabled, users with role ``"user"`` only see models whose id appears in
    ``MODEL_FILTER_LIST``. Other roles see every remaining model.
    """
    all_models = await get_all_models()

    def _is_filter_pipeline(entry):
        return "pipeline" in entry and entry["pipeline"].get("type", None) == "filter"

    # Filter out filter pipelines
    visible = [entry for entry in all_models if not _is_filter_pipeline(entry)]

    if app.state.config.ENABLE_MODEL_FILTER and user.role == "user":
        visible = [
            entry
            for entry in visible
            if entry["id"] in app.state.config.MODEL_FILTER_LIST
        ]

    return {"data": visible}
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
    """Dispatch a chat-completion request to the backend that owns the model.

    A top-level ``task`` field is removed from the payload and recorded as
    ``metadata["task"]`` instead. Routing:
      * models with a ``pipe`` go to the function pipeline,
      * Ollama-owned models go to Ollama,
      * everything else goes to the OpenAI-compatible backend.

    Raises:
        HTTPException: 404 if ``form_data["model"]`` is not a known model.
    """
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = app.state.MODELS[model_id]

    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
    task = form_data.pop("task", None)
    if task:
        # Tolerate an explicit `"metadata": null` from the client instead of
        # failing with a TypeError on item assignment.
        metadata = form_data.get("metadata") or {}
        metadata["task"] = task
        form_data["metadata"] = metadata

    if model.get("pipe"):
        return await generate_function_chat_completion(form_data, user=user)
    if model["owned_by"] == "ollama":
        log.debug("generate_ollama_chat_completion")
        return await generate_ollama_chat_completion(form_data, user=user)
    return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    """Run outlet filters over a finished chat message and return the result.

    Two filter stages run in order:

    1. Pipeline filters: ``get_sorted_filters(model_id)``, preceded by the
       model itself when it is a pipeline. Each is POSTed to
       ``{url}/{id}/filter/outlet``, and a successful JSON reply replaces the
       body passed on to the next filter. Backends configured with an empty
       API key are skipped.
    2. Function filters: global filters plus the model's ``filterIds``,
       limited to active ones and ordered by their valves' ``priority``.
       Those whose module defines ``outlet`` are called in-process.

    Returns:
        The filtered body, or a ``JSONResponse`` carrying an error detail.

    Raises:
        HTTPException: 404 if ``form_data["model"]`` is not a known model.
    """
    data = form_data
    model_id = data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = app.state.MODELS[model_id]

    sorted_filters = get_sorted_filters(model_id)
    if "pipeline" in model:
        sorted_filters = [model] + sorted_filters

    # NOTE(review): `requests` is synchronous and blocks the event loop while
    # each outlet call is in flight.
    for pipeline_filter in sorted_filters:
        r = None
        try:
            urlIdx = pipeline_filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
            if key == "":
                # No API key configured for this backend: no outlet request is sent.
                continue

            headers = {"Authorization": f"Bearer {key}"}
            r = requests.post(
                f"{url}/{pipeline_filter['id']}/filter/outlet",
                headers=headers,
                json={
                    "user": {
                        "id": user.id,
                        "name": user.name,
                        "email": user.email,
                        "role": user.role,
                    },
                    "body": data,
                },
            )
            r.raise_for_status()
            data = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")
            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except (ValueError, TypeError):
                    # Non-JSON (or non-container) error body: ignore it and
                    # continue with the next filter.
                    pass

    __event_emitter__ = await get_event_emitter(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
        }
    )
    __event_call__ = await get_event_call(
        {
            "chat_id": data["chat_id"],
            "message_id": data["id"],
            "session_id": data["session_id"],
        }
    )

    # Each filter id is needed both for sorting and for execution; fetch its
    # function record once and reuse it.
    function_records = {}

    def get_function(function_id):
        if function_id not in function_records:
            function_records[function_id] = Functions.get_function_by_id(function_id)
        return function_records[function_id]

    def get_priority(function_id):
        function = get_function(function_id)
        if function is not None and hasattr(function, "valves"):
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
    if "info" in model and "meta" in model["info"]:
        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
    # De-duplicate with a stable order so equal-priority filters run in a
    # deterministic sequence.
    filter_ids = list(dict.fromkeys(filter_ids))

    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }
    filter_ids = [
        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
    ]

    # Sort filter_ids by priority, using the get_priority function
    filter_ids.sort(key=get_priority)

    for filter_id in filter_ids:
        if not get_function(filter_id):
            continue

        if filter_id in webui_app.state.FUNCTIONS:
            function_module = webui_app.state.FUNCTIONS[filter_id]
        else:
            function_module, _, _ = load_function_module_by_id(filter_id)
            webui_app.state.FUNCTIONS[filter_id] = function_module

        if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
            valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(
                **(valves if valves else {})
            )

        if not hasattr(function_module, "outlet"):
            continue

        try:
            outlet = function_module.outlet

            # Only pass the optional dunder parameters the outlet declares.
            sig = inspect.signature(outlet)
            params = {"body": data}

            extra_params = {
                "__model__": model,
                "__id__": filter_id,
                "__event_emitter__": __event_emitter__,
                "__event_call__": __event_call__,
            }
            for param_name, value in extra_params.items():
                if param_name in sig.parameters:
                    params[param_name] = value

            if "__user__" in sig.parameters:
                __user__ = {
                    "id": user.id,
                    "email": user.email,
                    "name": user.name,
                    "role": user.role,
                }
                try:
                    if hasattr(function_module, "UserValves"):
                        __user__["valves"] = function_module.UserValves(
                            **Functions.get_user_valves_by_id_and_user_id(
                                filter_id, user.id
                            )
                        )
                except Exception as e:
                    # Best effort: the outlet still runs without user valves.
                    print(e)

                params = {**params, "__user__": __user__}

            if inspect.iscoroutinefunction(outlet):
                data = await outlet(**params)
            else:
                data = outlet(**params)

        except Exception as e:
            print(f"Error: {e}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={"detail": str(e)},
            )

    return data
@app.post("/api/chat/actions/{action_id}")
async def chat_completed(
    action_id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Invoke the ``action`` hook of function ``action_id`` on a chat message.

    Only the optional ``__model__``, ``__id__``, ``__event_emitter__``,
    ``__event_call__`` and ``__user__`` arguments that the hook's signature
    declares are passed. If the function module has no ``action`` attribute,
    the request body is returned unchanged.

    NOTE(review): this handler reuses the module-level name ``chat_completed``
    and so rebinds it. Both routes still work because FastAPI registers each
    handler when its decorator runs, but a distinct name would be clearer.

    Raises:
        HTTPException: 404 if the action or the requested model is unknown.
    """
    if not Functions.get_function_by_id(action_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Action not found",
        )

    body = form_data
    model_id = body["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )
    model = app.state.MODELS[model_id]

    __event_emitter__ = await get_event_emitter(
        {
            "chat_id": body["chat_id"],
            "message_id": body["id"],
            "session_id": body["session_id"],
        }
    )
    __event_call__ = await get_event_call(
        {
            "chat_id": body["chat_id"],
            "message_id": body["id"],
            "session_id": body["session_id"],
        }
    )

    # Load the module on first use and cache it for later calls.
    loaded_modules = webui_app.state.FUNCTIONS
    if action_id not in loaded_modules:
        loaded_modules[action_id], _, _ = load_function_module_by_id(action_id)
    function_module = loaded_modules[action_id]

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        stored_valves = Functions.get_function_valves_by_id(action_id)
        function_module.valves = function_module.Valves(
            **(stored_valves if stored_valves else {})
        )

    if not hasattr(function_module, "action"):
        return body

    try:
        action_fn = function_module.action
        declared = inspect.signature(action_fn).parameters

        optional_kwargs = {
            "__model__": model,
            "__id__": action_id,
            "__event_emitter__": __event_emitter__,
            "__event_call__": __event_call__,
        }
        call_kwargs = {"body": body}
        call_kwargs.update(
            {name: value for name, value in optional_kwargs.items() if name in declared}
        )

        if "__user__" in declared:
            user_info = {
                "id": user.id,
                "email": user.email,
                "name": user.name,
                "role": user.role,
            }
            try:
                if hasattr(function_module, "UserValves"):
                    user_info["valves"] = function_module.UserValves(
                        **Functions.get_user_valves_by_id_and_user_id(
                            action_id, user.id
                        )
                    )
            except Exception as e:
                print(e)
            call_kwargs["__user__"] = user_info

        if inspect.iscoroutinefunction(action_fn):
            body = await action_fn(**call_kwargs)
        else:
            body = action_fn(**call_kwargs)
    except Exception as e:
        print(f"Error: {e}")
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content={"detail": str(e)},
        )

    return body
- ##################################
- #
- # Task Endpoints
- #
- ##################################
- # TODO: Refactor task API endpoints below into a separate file
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
    """Return the current task-model settings and prompt templates."""
    config = app.state.config
    setting_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    return {name: getattr(config, name) for name in setting_names}
# Request body for POST /api/task/config/update (update_task_config). The
# field names match the keys returned by GET /api/task/config.
class TaskConfigForm(BaseModel):
    # The two task-model ids are Optional and may be sent as null; the value
    # is stored as-is.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
    # Prompts shorter than this many characters skip search-query generation.
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
    """Store new task settings (admin only) and echo back the stored values."""
    config = app.state.config
    setting_names = (
        "TASK_MODEL",
        "TASK_MODEL_EXTERNAL",
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    )
    for name in setting_names:
        setattr(config, name, getattr(form_data, name))
    return {name: getattr(config, name) for name in setting_names}
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    """Generate a chat title for ``form_data["prompt"]`` with the task model.

    The prompt is rendered from ``TITLE_GENERATION_PROMPT_TEMPLATE`` and run
    through ``filter_pipeline``. It is then sent to the task model selected by
    ``get_task_model_id`` as a non-streaming request capped at 50 tokens.

    Raises:
        HTTPException: 404 if ``form_data["model"]`` is not a known model.
    """
    log.debug("generate_title")
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    model_id = get_task_model_id(model_id)
    log.debug(model_id)

    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
    content = title_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
        "task": str(TASKS.TITLE_GENERATION),
    }
    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        # This handler expects filter errors shaped as (status_code, detail)
        # in e.args. Any other exception used to crash here with an
        # IndexError or an invalid status code, so map it to a 500.
        if len(e.args) >= 2 and isinstance(e.args[0], int):
            status_code, detail = e.args[0], e.args[1]
        else:
            status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
            detail = str(e)
        return JSONResponse(status_code=status_code, content={"detail": detail})

    # chat_id is only for the filters; don't forward it to the model.
    payload.pop("chat_id", None)

    return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
    """Generate a web-search query for ``form_data["prompt"]`` with the task model.

    Prompts shorter than ``SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD`` characters
    are rejected. Otherwise the prompt is rendered from
    ``SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE`` and run through
    ``filter_pipeline``. It is then sent to the task model as a non-streaming
    request capped at 30 tokens.

    Raises:
        HTTPException: 400 for prompts below the length threshold; 404 if
        ``form_data["model"]`` is not a known model.
    """
    log.debug("generate_search_query")

    if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
        )

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    model_id = get_task_model_id(model_id)
    log.debug(model_id)

    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
    content = search_query_generation_template(
        template, form_data["prompt"], {"name": user.name}
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 30,
        "task": str(TASKS.QUERY_GENERATION),
    }
    # Log at debug level like the other task endpoints; print() wrote the
    # user's prompt to stdout unconditionally.
    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        # This handler expects filter errors shaped as (status_code, detail)
        # in e.args. Any other exception used to crash here with an
        # IndexError or an invalid status code, so map it to a 500.
        if len(e.args) >= 2 and isinstance(e.args[0], int):
            status_code, detail = e.args[0], e.args[1]
        else:
            status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
            detail = str(e)
        return JSONResponse(status_code=status_code, content={"detail": detail})

    # Never forward a chat_id (e.g. one added by a filter) to the model.
    payload.pop("chat_id", None)

    return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
    """Pick an emoji reflecting ``form_data["prompt"]`` using the task model.

    Uses a fixed prompt rendered via ``title_generation_template``, which
    substitutes ``{{prompt}}``. The payload is run through ``filter_pipeline``
    and sent as a non-streaming request capped at 4 tokens.

    Raises:
        HTTPException: 404 if ``form_data["model"]`` is not a known model.
    """
    log.debug("generate_emoji")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    model_id = get_task_model_id(model_id)
    log.debug(model_id)

    template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
Message: """{{prompt}}"""
'''

    content = title_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 4,
        "chat_id": form_data.get("chat_id", None),
        "task": str(TASKS.EMOJI_GENERATION),
    }
    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        # This handler expects filter errors shaped as (status_code, detail)
        # in e.args. Any other exception used to crash here with an
        # IndexError or an invalid status code, so map it to a 500.
        if len(e.args) >= 2 and isinstance(e.args[0], int):
            status_code, detail = e.args[0], e.args[1]
        else:
            status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
            detail = str(e)
        return JSONResponse(status_code=status_code, content={"detail": detail})

    # chat_id is only for the filters; don't forward it to the model.
    payload.pop("chat_id", None)

    return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
    """Run tool function-calling for ``form_data["tool_id"]`` and return its context.

    The task model chosen by ``get_task_model_id`` is prompted with
    ``TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE``. Only the first element of the
    ``get_function_call_response`` result is returned.

    Raises:
        HTTPException: 404 if ``form_data["model"]`` is not a known model.
    """
    log.debug("get_tools_function_calling")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    model_id = get_task_model_id(model_id)
    log.debug(model_id)

    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    try:
        context, _, _ = await get_function_call_response(
            form_data["messages"],
            form_data.get("files", []),
            form_data["tool_id"],
            template,
            model_id,
            user,
        )
    except Exception as e:
        # Failures here are arbitrary (a KeyError on a missing field, tool
        # errors, ...). Indexing e.args[0] / e.args[1] unconditionally crashed
        # with IndexError, so only treat e.args as (status_code, detail) when
        # it actually has that shape.
        if len(e.args) >= 2 and isinstance(e.args[0], int):
            status_code, detail = e.args[0], e.args[1]
        else:
            status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
            detail = str(e)
        return JSONResponse(status_code=status_code, content={"detail": detail})

    return context
- ##################################
- #
- # Pipelines Endpoints
- #
- ##################################
- # TODO: Refactor pipelines API endpoints below into a separate file
@app.get("/api/pipelines/list")
async def get_pipelines_list(user=Depends(get_admin_user)):
    """List the OpenAI-compatible backends that advertise pipelines (admin only).

    Returns ``{"data": [{"url": ..., "idx": ...}, ...]}``, where ``idx`` is the
    backend's index in ``OPENAI_API_BASE_URLS``.
    """
    responses = await get_openai_models(raw=True)
    print(responses)

    pipeline_servers = []
    for idx, response in enumerate(responses):
        if response is None or "pipelines" not in response:
            continue
        pipeline_servers.append(
            {
                "url": openai_app.state.config.OPENAI_API_BASE_URLS[idx],
                "idx": idx,
            }
        )

    return {"data": pipeline_servers}
@app.post("/api/pipelines/upload")
async def upload_pipeline(
    urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
    """Forward an uploaded ``.py`` pipeline to the pipelines server at ``urlIdx``.

    The upload is staged under ``{CACHE_DIR}/pipelines`` and removed again
    whether or not forwarding succeeds.

    Returns:
        The pipelines server's JSON reply.

    Raises:
        HTTPException: 400 for a non-``.py`` upload. Otherwise the upstream
        error status (404 when there is none) with the upstream ``detail``
        when one is available.
    """
    print("upload_pipeline", urlIdx, file.filename)

    # Keep only the final path component. A client-supplied name such as
    # "../../x.py" or "/abs/x.py" would otherwise escape the upload folder,
    # and the `finally` clause below would then delete that outside file.
    filename = os.path.basename(file.filename or "")

    # Check if the uploaded file is a python file
    if not filename.endswith(".py"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only Python (.py) files are allowed.",
        )

    upload_folder = f"{CACHE_DIR}/pipelines"
    os.makedirs(upload_folder, exist_ok=True)
    file_path = os.path.join(upload_folder, filename)

    r = None
    try:
        # Save the uploaded file
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
        headers = {"Authorization": f"Bearer {key}"}

        with open(file_path, "rb") as f:
            files = {"file": f}
            r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
    finally:
        # Ensure the file is deleted after the upload is completed or on failure
        if os.path.exists(file_path):
            os.remove(file_path)
# Request body for POST /api/pipelines/add.
class AddPipelineForm(BaseModel):
    url: str  # forwarded to the pipelines server as {"url": ...}
    urlIdx: int  # index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS
@app.post("/api/pipelines/add")
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
    """Ask pipelines server ``form_data.urlIdx`` to add the pipeline at ``form_data.url``.

    Returns:
        The pipelines server's JSON reply.

    Raises:
        HTTPException: the upstream error status (404 when there is none)
        with the upstream ``detail`` when one is available.
    """
    r = None
    try:
        urlIdx = form_data.urlIdx

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.post(
            f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
# Request body for DELETE /api/pipelines/delete.
class DeletePipelineForm(BaseModel):
    id: str  # pipeline id, forwarded to the pipelines server as {"id": ...}
    urlIdx: int  # index into OPENAI_API_BASE_URLS / OPENAI_API_KEYS
@app.delete("/api/pipelines/delete")
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
    """Ask pipelines server ``form_data.urlIdx`` to delete pipeline ``form_data.id``.

    Returns:
        The pipelines server's JSON reply.

    Raises:
        HTTPException: the upstream error status (404 when there is none)
        with the upstream ``detail`` when one is available.
    """
    r = None
    try:
        urlIdx = form_data.urlIdx

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.delete(
            f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
@app.get("/api/pipelines")
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
    """Return the pipelines listed by the pipelines server at ``urlIdx``.

    NOTE(review): ``urlIdx`` defaults to None, which fails the index lookup
    and is reported as a 404 "Pipeline not found".

    Raises:
        HTTPException: the upstream error status (404 when there is none)
        with the upstream ``detail`` when one is available.
    """
    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/pipelines", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
    urlIdx: Optional[int],
    pipeline_id: str,
    user=Depends(get_admin_user),
):
    """Return the current valves of ``pipeline_id`` from pipelines server ``urlIdx``.

    Raises:
        HTTPException: the upstream error status (404 when there is none)
        with the upstream ``detail`` when one is available.
    """
    # Called for its side effect: get_all_models refreshes app.state.MODELS.
    await get_all_models()

    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
    urlIdx: Optional[int],
    pipeline_id: str,
    user=Depends(get_admin_user),
):
    """Return the valves spec of ``pipeline_id`` from pipelines server ``urlIdx``.

    Raises:
        HTTPException: the upstream error status (404 when there is none)
        with the upstream ``detail`` when one is available.
    """
    # Called for its side effect: get_all_models refreshes app.state.MODELS.
    await get_all_models()

    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
@app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves(
    urlIdx: Optional[int],
    pipeline_id: str,
    form_data: dict,
    user=Depends(get_admin_user),
):
    """Send new valve values for ``pipeline_id`` to pipelines server ``urlIdx``.

    Returns:
        The pipelines server's JSON reply.

    Raises:
        HTTPException: the upstream error status (404 when there is none)
        with the upstream ``detail`` when one is available.
    """
    # Called for its side effect: get_all_models refreshes app.state.MODELS.
    await get_all_models()

    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.post(
            f"{url}/{pipeline_id}/valves/update",
            headers=headers,
            json={**form_data},
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        status_code = status.HTTP_404_NOT_FOUND
        if r is not None:
            # Reuse the upstream status only when it is an error. A 2xx reply
            # whose body failed to parse must not become HTTPException(200).
            if r.status_code >= 400:
                status_code = r.status_code
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except (ValueError, TypeError):
                pass

        raise HTTPException(
            status_code=status_code,
            detail=detail,
        )
- ##################################
- #
- # Config Endpoints
- #
- ##################################
@app.get("/api/config")
async def get_app_config():
    """Return the frontend bootstrap configuration.

    This route declares no user dependency, so it is served without
    authentication. It covers branding/version, defaults, feature flags,
    audio engines and the configured OAuth provider display names.
    """
    webui_config = webui_app.state.config

    feature_flags = {
        "auth": WEBUI_AUTH,
        "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
        "enable_signup": webui_config.ENABLE_SIGNUP,
        "enable_login_form": webui_config.ENABLE_LOGIN_FORM,
        "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
        "enable_image_generation": images_app.state.config.ENABLED,
        "enable_community_sharing": webui_config.ENABLE_COMMUNITY_SHARING,
        "enable_admin_export": ENABLE_ADMIN_EXPORT,
    }

    audio_settings = {
        "tts": {
            "engine": audio_app.state.config.TTS_ENGINE,
            "voice": audio_app.state.config.TTS_VOICE,
        },
        "stt": {
            "engine": audio_app.state.config.STT_ENGINE,
        },
    }

    # Provider display name, falling back to the provider key.
    provider_names = {
        provider: provider_config.get("name", provider)
        for provider, provider_config in OAUTH_PROVIDERS.items()
    }

    return {
        "status": True,
        "name": WEBUI_NAME,
        "version": VERSION,
        "default_locale": str(DEFAULT_LOCALE),
        "default_models": webui_config.DEFAULT_MODELS,
        "default_prompt_suggestions": webui_config.DEFAULT_PROMPT_SUGGESTIONS,
        "features": feature_flags,
        "audio": audio_settings,
        "oauth": {"providers": provider_names},
    }
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
    """Return whether the model allow-list is on, and its model ids (admin only)."""
    cfg = app.state.config
    is_enabled = cfg.ENABLE_MODEL_FILTER
    allowed_models = cfg.MODEL_FILTER_LIST
    return {"enabled": is_enabled, "models": allowed_models}
# Request body for POST /api/config/model/filter.
class ModelFilterConfigForm(BaseModel):
    enabled: bool  # stored as ENABLE_MODEL_FILTER
    # Stored as MODEL_FILTER_LIST: the model ids that users with role
    # "user" may see when the filter is enabled.
    models: List[str]
@app.post("/api/config/model/filter")
async def update_model_filter_config(
    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
    """Store the model allow-list settings (admin only) and echo them back."""
    cfg = app.state.config
    cfg.ENABLE_MODEL_FILTER = form_data.enabled
    cfg.MODEL_FILTER_LIST = form_data.models
    return {"enabled": cfg.ENABLE_MODEL_FILTER, "models": cfg.MODEL_FILTER_LIST}
- # TODO: webhook endpoint should be under config endpoints
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    """Return the configured webhook URL (admin only)."""
    current_url = app.state.config.WEBHOOK_URL
    return {"url": current_url}
# Single-URL request body; used here by POST /api/webhook.
class UrlForm(BaseModel):
    url: str
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    """Store a new webhook URL (admin only) and mirror it onto ``webui_app``."""
    config = app.state.config
    config.WEBHOOK_URL = form_data.url
    webui_app.state.WEBHOOK_URL = config.WEBHOOK_URL
    return {"url": config.WEBHOOK_URL}
@app.get("/api/version")
async def get_app_config():
    """Return the running application version as ``{"version": VERSION}``.

    NOTE(review): this reuses the module-level name ``get_app_config`` and so
    rebinds the /api/config handler's name. The routes are unaffected because
    each is registered when its decorator runs.
    """
    return dict(version=VERSION)
@app.get("/api/changelog")
async def get_app_changelog():
    """Return the first five entries of ``CHANGELOG``, in its iteration order."""
    first_keys = list(CHANGELOG)[:5]
    return {key: CHANGELOG[key] for key in first_keys}
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    """Compare the running version with the latest GitHub release.

    Returns:
        ``{"current": VERSION, "latest": <release tag without a leading "v">}``.

    Raises:
        HTTPException: 503 (reported with the rate-limit message) when the
        GitHub API cannot be reached, times out, or answers with an error
        status.
    """
    import asyncio  # local import: only needed for the timeout exception type

    try:
        # Bound the outbound call so a slow GitHub API cannot hang this request.
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]
    except (aiohttp.ClientError, asyncio.TimeoutError):
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
        )

    # Strip a leading "v" only when present. Slicing [1:] unconditionally
    # chopped the first digit from tags without the prefix.
    if latest_version.startswith("v"):
        latest_version = latest_version[1:]

    return {"current": VERSION, "latest": latest_version}
- ############################
- # OAuth Login & Callback
- ############################
# Register one authlib OAuth client per configured provider. Each
# OAUTH_PROVIDERS entry is indexed directly for client_id, client_secret,
# server_metadata_url and scope, so a missing key raises at startup.
oauth = OAuth()

for provider_name, provider_config in OAUTH_PROVIDERS.items():
    oauth.register(
        name=provider_name,
        client_id=provider_config["client_id"],
        client_secret=provider_config["client_secret"],
        server_metadata_url=provider_config["server_metadata_url"],
        client_kwargs={
            "scope": provider_config["scope"],
        },
    )
# SessionMiddleware is used by authlib for oauth: the flow's state must
# survive between the login redirect and the callback. It is only installed
# when at least one provider is configured. The session cookie
# ("oui-session") is keyed with WEBUI_SECRET_KEY and uses the configured
# SameSite / Secure settings.
if len(OAUTH_PROVIDERS) > 0:
    app.add_middleware(
        SessionMiddleware,
        secret_key=WEBUI_SECRET_KEY,
        session_cookie="oui-session",
        same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
        https_only=WEBUI_SESSION_COOKIE_SECURE,
    )
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
    """Start the OAuth flow by redirecting to ``provider``'s authorization page.

    The provider is told to return to this app's ``oauth_callback`` route.

    Raises:
        HTTPException: 404 if ``provider`` is not configured.
    """
    if provider not in OAUTH_PROVIDERS:
        raise HTTPException(404)
    callback_url = request.url_for("oauth_callback", provider=provider)
    client = oauth.create_client(provider)
    return await client.authorize_redirect(request, callback_url)
- # OAuth login logic is as follows:
- # 1. Attempt to find a user with matching subject ID, tied to the provider
- # 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
- # - This is considered insecure in general, as OAuth providers do not always verify email addresses
- # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
#    - Email addresses are considered unique, so registration fails if the email address is already taken
- @app.get("/oauth/{provider}/callback")
- async def oauth_callback(provider: str, request: Request, response: Response):
- if provider not in OAUTH_PROVIDERS:
- raise HTTPException(404)
- client = oauth.create_client(provider)
- try:
- token = await client.authorize_access_token(request)
- except Exception as e:
- log.warning(f"OAuth callback error: {e}")
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
- user_data: UserInfo = token["userinfo"]
- sub = user_data.get("sub")
- if not sub:
- log.warning(f"OAuth callback failed, sub is missing: {user_data}")
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
- provider_sub = f"{provider}@{sub}"
- email = user_data.get("email", "").lower()
- # We currently mandate that email addresses are provided
- if not email:
- log.warning(f"OAuth callback failed, email is missing: {user_data}")
- raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
- # Check if the user exists
- user = Users.get_user_by_oauth_sub(provider_sub)
- if not user:
- # If the user does not exist, check if merging is enabled
- if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
- # Check if the user exists by email
- user = Users.get_user_by_email(email)
- if user:
- # Update the user with the new oauth sub
- Users.update_user_oauth_sub_by_id(user.id, provider_sub)
- if not user:
- # If the user does not exist, check if signups are enabled
- if ENABLE_OAUTH_SIGNUP.value:
- # Check if an existing user with the same email already exists
- existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
- if existing_user:
- raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
- picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
- picture_url = user_data.get(picture_claim, "")
- if picture_url:
- # Download the profile image into a base64 string
- try:
- async with aiohttp.ClientSession() as session:
- async with session.get(picture_url) as resp:
- picture = await resp.read()
- base64_encoded_picture = base64.b64encode(picture).decode(
- "utf-8"
- )
- guessed_mime_type = mimetypes.guess_type(picture_url)[0]
- if guessed_mime_type is None:
- # assume JPG, browsers are tolerant enough of image formats
- guessed_mime_type = "image/jpeg"
- picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
- except Exception as e:
- log.error(f"Error downloading profile image '{picture_url}': {e}")
- picture_url = ""
- if not picture_url:
- picture_url = "/user.png"
- username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
- role = (
- "admin"
- if Users.get_num_users() == 0
- else webui_app.state.config.DEFAULT_USER_ROLE
- )
- user = Auths.insert_new_auth(
- email=email,
- password=get_password_hash(
- str(uuid.uuid4())
- ), # Random password, not used
- name=user_data.get(username_claim, "User"),
- profile_image_url=picture_url,
- role=role,
- oauth_sub=provider_sub,
- )
- if webui_app.state.config.WEBHOOK_URL:
- post_webhook(
- webui_app.state.config.WEBHOOK_URL,
- WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
- {
- "action": "signup",
- "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
- "user": user.model_dump_json(exclude_none=True),
- },
- )
- else:
- raise HTTPException(
- status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
- )
- jwt_token = create_token(
- data={"id": user.id},
- expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
- )
- # Set the cookie token
- response.set_cookie(
- key="token",
- value=jwt_token,
- httponly=True, # Ensures the cookie is not accessible via JavaScript
- )
- # Redirect back to the frontend with the JWT token
- redirect_url = f"{request.base_url}auth#token={jwt_token}"
- return RedirectResponse(url=redirect_url)
@app.get("/manifest.json")
async def get_manifest_json():
    """Return the web app manifest, using WEBUI_NAME for both name fields."""
    logo_icon = {"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}
    manifest = dict(
        name=WEBUI_NAME,
        short_name=WEBUI_NAME,
        start_url="/",
        display="standalone",
        background_color="#343541",
        orientation="portrait-primary",
        icons=[logo_icon],
    )
    return manifest
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    """Serve an OpenSearch description document for this WebUI instance.

    WEBUI_NAME and WEBUI_URL come from configuration and are XML-escaped
    before interpolation. A character such as ``&`` or ``<`` in the
    configured name would otherwise produce a malformed XML document.
    """
    from xml.sax.saxutils import escape

    name = escape(WEBUI_NAME)
    # The URL also lands inside a double-quoted attribute, so escape quotes too
    url = escape(WEBUI_URL, {'"': "&quot;"})
    # `{{searchTerms}}` renders as the literal OpenSearch `{searchTerms}` placeholder
    xml_content = f"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{name}</ShortName>
<Description>Search {name}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{url}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{url}/?q={{searchTerms}}"/>
<moz:SearchForm>{url}</moz:SearchForm>
</OpenSearchDescription>
"""
    return Response(content=xml_content, media_type="application/xml")
@app.get("/health")
async def healthcheck():
    """Liveness probe: unconditionally reports ``{"status": True}``."""
    payload = {"status": True}
    return payload
@app.get("/health/db")
async def healthcheck_with_db():
    """Health probe that also runs a trivial query against the database.

    No exception is caught here, so a failing query surfaces as an error
    rather than a healthy status.
    """
    probe_query = text("SELECT 1;")
    Session.execute(probe_query).all()
    return {"status": True}
# Static assets and cached files are mounted before the optional SPA mount at "/".
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

if not os.path.exists(FRONTEND_BUILD_DIR):
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )
else:
    # Register .js -> text/javascript in the mimetypes table before serving the build
    mimetypes.add_type("text/javascript", ".js")
    spa_files = SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True)
    app.mount("/", spa_files, name="spa-static-files")
|