# main.py
  1. import uuid
  2. from contextlib import asynccontextmanager
  3. from authlib.integrations.starlette_client import OAuth
  4. from authlib.oidc.core import UserInfo
  5. from bs4 import BeautifulSoup
  6. import json
  7. import markdown
  8. import time
  9. import os
  10. import sys
  11. import logging
  12. import aiohttp
  13. import requests
  14. import mimetypes
  15. import shutil
  16. import os
  17. import uuid
  18. import inspect
  19. import asyncio
  20. from fastapi.concurrency import run_in_threadpool
  21. from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
  22. from fastapi.staticfiles import StaticFiles
  23. from fastapi.responses import JSONResponse
  24. from fastapi import HTTPException
  25. from fastapi.middleware.wsgi import WSGIMiddleware
  26. from fastapi.middleware.cors import CORSMiddleware
  27. from starlette.exceptions import HTTPException as StarletteHTTPException
  28. from starlette.middleware.base import BaseHTTPMiddleware
  29. from starlette.middleware.sessions import SessionMiddleware
  30. from starlette.responses import StreamingResponse, Response, RedirectResponse
  31. from apps.socket.main import app as socket_app
  32. from apps.ollama.main import (
  33. app as ollama_app,
  34. OpenAIChatCompletionForm,
  35. get_all_models as get_ollama_models,
  36. generate_openai_chat_completion as generate_ollama_chat_completion,
  37. )
  38. from apps.openai.main import (
  39. app as openai_app,
  40. get_all_models as get_openai_models,
  41. generate_chat_completion as generate_openai_chat_completion,
  42. )
  43. from apps.audio.main import app as audio_app
  44. from apps.images.main import app as images_app
  45. from apps.rag.main import app as rag_app
  46. from apps.webui.main import app as webui_app, get_pipe_models
  47. from pydantic import BaseModel
  48. from typing import List, Optional, Iterator, Generator, Union
  49. from apps.webui.models.auths import Auths
  50. from apps.webui.models.models import Models, ModelModel
  51. from apps.webui.models.tools import Tools
  52. from apps.webui.models.functions import Functions
  53. from apps.webui.models.users import Users
  54. from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
  55. from apps.webui.utils import load_toolkit_module_by_id
  56. from utils.misc import parse_duration
  57. from utils.utils import (
  58. get_admin_user,
  59. get_verified_user,
  60. get_current_user,
  61. get_http_authorization_cred,
  62. get_password_hash,
  63. create_token,
  64. )
  65. from utils.task import (
  66. title_generation_template,
  67. search_query_generation_template,
  68. tools_function_calling_generation_template,
  69. )
  70. from utils.misc import (
  71. get_last_user_message,
  72. add_or_update_system_message,
  73. stream_message_template,
  74. )
  75. from apps.rag.utils import get_rag_context, rag_template
  76. from config import (
  77. CONFIG_DATA,
  78. WEBUI_NAME,
  79. WEBUI_URL,
  80. WEBUI_AUTH,
  81. ENV,
  82. VERSION,
  83. CHANGELOG,
  84. FRONTEND_BUILD_DIR,
  85. UPLOAD_DIR,
  86. CACHE_DIR,
  87. STATIC_DIR,
  88. ENABLE_OPENAI_API,
  89. ENABLE_OLLAMA_API,
  90. ENABLE_MODEL_FILTER,
  91. MODEL_FILTER_LIST,
  92. GLOBAL_LOG_LEVEL,
  93. SRC_LOG_LEVELS,
  94. WEBHOOK_URL,
  95. ENABLE_ADMIN_EXPORT,
  96. WEBUI_BUILD_HASH,
  97. TASK_MODEL,
  98. TASK_MODEL_EXTERNAL,
  99. TITLE_GENERATION_PROMPT_TEMPLATE,
  100. SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
  101. SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
  102. TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  103. OAUTH_PROVIDERS,
  104. ENABLE_OAUTH_SIGNUP,
  105. OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
  106. WEBUI_SECRET_KEY,
  107. WEBUI_SESSION_COOKIE_SAME_SITE,
  108. WEBUI_SESSION_COOKIE_SECURE,
  109. AppConfig,
  110. )
  111. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  112. from utils.webhook import post_webhook
# Route all log output to stdout at the globally configured level, then give
# this module its own logger pinned to the "MAIN" source log level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  116. class SPAStaticFiles(StaticFiles):
  117. async def get_response(self, path: str, scope):
  118. try:
  119. return await super().get_response(path, scope)
  120. except (HTTPException, StarletteHTTPException) as ex:
  121. if ex.status_code == 404:
  122. return await super().get_response("index.html", scope)
  123. else:
  124. raise ex
# Startup banner printed once at import time: version, optional build commit
# (suppressed for local "dev-build"), and project URL.
print(
    rf"""
  ___                    __        __   _     _   _ ___
 / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || |
 \___/| .__/ \___|_| |_|    \_/\_/ \___|_.__/ \___/|___|
      |_|


v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # No startup/shutdown work is needed yet; this placeholder keeps the
    # FastAPI lifespan hook available for future initialization/cleanup.
    yield
# Interactive API docs are exposed only in the dev environment.
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)

# Mirror startup configuration onto app.state so request handlers and
# middleware can read (and runtime admin endpoints can mutate) it.
app.state.config = AppConfig()

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.WEBHOOK_URL = WEBHOOK_URL

# Task models handle auxiliary generations (titles, search queries, tool calls).
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL

app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)

# Model registry, populated lazily on the first request (see check_url).
app.state.MODELS = {}

# CORS: allow all origins (credentials are enabled below — NOTE(review):
# "*" with credentials is permissive; confirm this is intentional).
origins = ["*"]
  164. ##################################
  165. #
  166. # ChatCompletion Middleware
  167. #
  168. ##################################
async def get_function_call_response(
    messages, files, tool_id, template, task_model_id, user
):
    """Ask the task model which tool function to call, then invoke it.

    Builds a function-calling prompt from the tool's specs plus recent chat
    history, runs it through the task model, parses the model's JSON answer
    ({"name": ..., "parameters": ...}), and executes the named function from
    the toolkit module.

    Returns a ``(result, citation, file_handler)`` tuple; on any failure
    returns ``(None, None, False)``.
    """
    tool = Tools.get_tool_by_id(tool_id)
    tools_specs = json.dumps(tool.specs, indent=2)
    content = tools_function_calling_generation_template(template, tools_specs)

    user_message = get_last_user_message(messages)
    # Last 4 messages, most recent first, quoted for the prompt.
    prompt = (
        "History:\n"
        + "\n".join(
            [
                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
                for message in messages[::-1][:4]
            ]
        )
        + f"\nQuery: {user_message}"
    )

    print(prompt)

    payload = {
        "model": task_model_id,
        "messages": [
            {"role": "system", "content": content},
            {"role": "user", "content": f"Query: {prompt}"},
        ],
        "stream": False,
    }

    # Run pipeline filter inlets over the payload before dispatch.
    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        raise e

    model = app.state.MODELS[task_model_id]

    response = None
    try:
        # Dispatch to the backend that owns the task model.
        if model["owned_by"] == "ollama":
            response = await generate_ollama_chat_completion(payload, user=user)
        else:
            response = await generate_openai_chat_completion(payload, user=user)

        content = None

        if hasattr(response, "body_iterator"):
            # Streaming response: drain it; only the last chunk's content is kept.
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]

        # Parse the function response
        if content is not None:
            print(f"content: {content}")
            result = json.loads(content)
            print(result)

            citation = None
            # Call the function
            if "name" in result:
                # Load (and cache) the toolkit module on first use.
                if tool_id in webui_app.state.TOOLS:
                    toolkit_module = webui_app.state.TOOLS[tool_id]
                else:
                    toolkit_module = load_toolkit_module_by_id(tool_id)
                    webui_app.state.TOOLS[tool_id] = toolkit_module

                file_handler = False
                # check if toolkit_module has file_handler self variable
                if hasattr(toolkit_module, "file_handler"):
                    file_handler = True
                    print("file_handler: ", file_handler)

                function = getattr(toolkit_module, result["name"])
                function_result = None
                try:
                    # Get the signature of the function
                    sig = inspect.signature(function)
                    params = result["parameters"]

                    # Inject reserved dunder parameters only when the function
                    # declares them in its signature.
                    if "__user__" in sig.parameters:
                        # Call the function with the '__user__' parameter included
                        params = {
                            **params,
                            "__user__": {
                                "id": user.id,
                                "email": user.email,
                                "name": user.name,
                                "role": user.role,
                            },
                        }

                    if "__messages__" in sig.parameters:
                        # Call the function with the '__messages__' parameter included
                        params = {
                            **params,
                            "__messages__": messages,
                        }

                    if "__files__" in sig.parameters:
                        # Call the function with the '__files__' parameter included
                        params = {
                            **params,
                            "__files__": files,
                        }

                    if "__model__" in sig.parameters:
                        # Call the function with the '__model__' parameter included
                        params = {
                            **params,
                            "__model__": model,
                        }

                    if "__id__" in sig.parameters:
                        # Call the function with the '__id__' parameter included
                        params = {
                            **params,
                            "__id__": tool_id,
                        }

                    # Tool functions may be sync or async.
                    if inspect.iscoroutinefunction(function):
                        function_result = await function(**params)
                    else:
                        function_result = function(**params)

                    if hasattr(toolkit_module, "citation") and toolkit_module.citation:
                        citation = {
                            "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
                            "document": [function_result],
                            "metadata": [{"source": result["name"]}],
                        }
                except Exception as e:
                    print(e)

                # Add the function result to the system prompt
                if function_result is not None:
                    return function_result, citation, file_handler
    except Exception as e:
        print(f"Error: {e}")

    return None, None, False
  293. class ChatCompletionMiddleware(BaseHTTPMiddleware):
  294. async def dispatch(self, request: Request, call_next):
  295. data_items = []
  296. show_citations = False
  297. citations = []
  298. if request.method == "POST" and any(
  299. endpoint in request.url.path
  300. for endpoint in ["/ollama/api/chat", "/chat/completions"]
  301. ):
  302. log.debug(f"request.url.path: {request.url.path}")
  303. # Read the original request body
  304. body = await request.body()
  305. body_str = body.decode("utf-8")
  306. data = json.loads(body_str) if body_str else {}
  307. user = get_current_user(
  308. request,
  309. get_http_authorization_cred(request.headers.get("Authorization")),
  310. )
  311. # Flag to skip RAG completions if file_handler is present in tools/functions
  312. skip_files = False
  313. if data.get("citations"):
  314. show_citations = True
  315. del data["citations"]
  316. model_id = data["model"]
  317. if model_id not in app.state.MODELS:
  318. raise HTTPException(
  319. status_code=status.HTTP_404_NOT_FOUND,
  320. detail="Model not found",
  321. )
  322. model = app.state.MODELS[model_id]
  323. # Check if the model has any filters
  324. if "info" in model and "meta" in model["info"]:
  325. for filter_id in model["info"]["meta"].get("filterIds", []):
  326. filter = Functions.get_function_by_id(filter_id)
  327. if filter:
  328. if filter_id in webui_app.state.FUNCTIONS:
  329. function_module = webui_app.state.FUNCTIONS[filter_id]
  330. else:
  331. function_module, function_type = load_function_module_by_id(
  332. filter_id
  333. )
  334. webui_app.state.FUNCTIONS[filter_id] = function_module
  335. # Check if the function has a file_handler variable
  336. if hasattr(function_module, "file_handler"):
  337. skip_files = function_module.file_handler
  338. try:
  339. if hasattr(function_module, "inlet"):
  340. inlet = function_module.inlet
  341. if inspect.iscoroutinefunction(inlet):
  342. data = await inlet(
  343. data,
  344. {
  345. "id": user.id,
  346. "email": user.email,
  347. "name": user.name,
  348. "role": user.role,
  349. },
  350. )
  351. else:
  352. data = inlet(
  353. data,
  354. {
  355. "id": user.id,
  356. "email": user.email,
  357. "name": user.name,
  358. "role": user.role,
  359. },
  360. )
  361. except Exception as e:
  362. print(f"Error: {e}")
  363. return JSONResponse(
  364. status_code=status.HTTP_400_BAD_REQUEST,
  365. content={"detail": str(e)},
  366. )
  367. # Set the task model
  368. task_model_id = data["model"]
  369. # Check if the user has a custom task model and use that model
  370. if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
  371. if (
  372. app.state.config.TASK_MODEL
  373. and app.state.config.TASK_MODEL in app.state.MODELS
  374. ):
  375. task_model_id = app.state.config.TASK_MODEL
  376. else:
  377. if (
  378. app.state.config.TASK_MODEL_EXTERNAL
  379. and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
  380. ):
  381. task_model_id = app.state.config.TASK_MODEL_EXTERNAL
  382. prompt = get_last_user_message(data["messages"])
  383. context = ""
  384. # If tool_ids field is present, call the functions
  385. if "tool_ids" in data:
  386. print(data["tool_ids"])
  387. for tool_id in data["tool_ids"]:
  388. print(tool_id)
  389. try:
  390. response, citation, file_handler = (
  391. await get_function_call_response(
  392. messages=data["messages"],
  393. files=data.get("files", []),
  394. tool_id=tool_id,
  395. template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  396. task_model_id=task_model_id,
  397. user=user,
  398. )
  399. )
  400. print(file_handler)
  401. if isinstance(response, str):
  402. context += ("\n" if context != "" else "") + response
  403. if citation:
  404. citations.append(citation)
  405. show_citations = True
  406. if file_handler:
  407. skip_files = True
  408. except Exception as e:
  409. print(f"Error: {e}")
  410. del data["tool_ids"]
  411. print(f"tool_context: {context}")
  412. # If files field is present, generate RAG completions
  413. # If skip_files is True, skip the RAG completions
  414. if "files" in data:
  415. if not skip_files:
  416. data = {**data}
  417. rag_context, rag_citations = get_rag_context(
  418. files=data["files"],
  419. messages=data["messages"],
  420. embedding_function=rag_app.state.EMBEDDING_FUNCTION,
  421. k=rag_app.state.config.TOP_K,
  422. reranking_function=rag_app.state.sentence_transformer_rf,
  423. r=rag_app.state.config.RELEVANCE_THRESHOLD,
  424. hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  425. )
  426. if rag_context:
  427. context += ("\n" if context != "" else "") + rag_context
  428. log.debug(f"rag_context: {rag_context}, citations: {citations}")
  429. if rag_citations:
  430. citations.extend(rag_citations)
  431. del data["files"]
  432. if show_citations and len(citations) > 0:
  433. data_items.append({"citations": citations})
  434. if context != "":
  435. system_prompt = rag_template(
  436. rag_app.state.config.RAG_TEMPLATE, context, prompt
  437. )
  438. print(system_prompt)
  439. data["messages"] = add_or_update_system_message(
  440. system_prompt, data["messages"]
  441. )
  442. modified_body_bytes = json.dumps(data).encode("utf-8")
  443. # Replace the request body with the modified one
  444. request._body = modified_body_bytes
  445. # Set custom header to ensure content-length matches new body length
  446. request.headers.__dict__["_list"] = [
  447. (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
  448. *[
  449. (k, v)
  450. for k, v in request.headers.raw
  451. if k.lower() != b"content-length"
  452. ],
  453. ]
  454. response = await call_next(request)
  455. if isinstance(response, StreamingResponse):
  456. # If it's a streaming response, inject it as SSE event or NDJSON line
  457. content_type = response.headers.get("Content-Type")
  458. if "text/event-stream" in content_type:
  459. return StreamingResponse(
  460. self.openai_stream_wrapper(response.body_iterator, data_items),
  461. )
  462. if "application/x-ndjson" in content_type:
  463. return StreamingResponse(
  464. self.ollama_stream_wrapper(response.body_iterator, data_items),
  465. )
  466. else:
  467. return response
  468. # If it's not a chat completion request, just pass it through
  469. response = await call_next(request)
  470. return response
  471. async def _receive(self, body: bytes):
  472. return {"type": "http.request", "body": body, "more_body": False}
  473. async def openai_stream_wrapper(self, original_generator, data_items):
  474. for item in data_items:
  475. yield f"data: {json.dumps(item)}\n\n"
  476. async for data in original_generator:
  477. yield data
  478. async def ollama_stream_wrapper(self, original_generator, data_items):
  479. for item in data_items:
  480. yield f"{json.dumps(item)}\n"
  481. async for data in original_generator:
  482. yield data
# Registered last among body-rewriting middleware so it runs first on requests.
app.add_middleware(ChatCompletionMiddleware)


##################################
#
# Pipeline Middleware
#
##################################
  489. def filter_pipeline(payload, user):
  490. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  491. model_id = payload["model"]
  492. filters = [
  493. model
  494. for model in app.state.MODELS.values()
  495. if "pipeline" in model
  496. and "type" in model["pipeline"]
  497. and model["pipeline"]["type"] == "filter"
  498. and (
  499. model["pipeline"]["pipelines"] == ["*"]
  500. or any(
  501. model_id == target_model_id
  502. for target_model_id in model["pipeline"]["pipelines"]
  503. )
  504. )
  505. ]
  506. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  507. model = app.state.MODELS[model_id]
  508. if "pipeline" in model:
  509. sorted_filters.append(model)
  510. for filter in sorted_filters:
  511. r = None
  512. try:
  513. urlIdx = filter["urlIdx"]
  514. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  515. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  516. if key != "":
  517. headers = {"Authorization": f"Bearer {key}"}
  518. r = requests.post(
  519. f"{url}/{filter['id']}/filter/inlet",
  520. headers=headers,
  521. json={
  522. "user": user,
  523. "body": payload,
  524. },
  525. )
  526. r.raise_for_status()
  527. payload = r.json()
  528. except Exception as e:
  529. # Handle connection error here
  530. print(f"Connection error: {e}")
  531. if r is not None:
  532. try:
  533. res = r.json()
  534. except:
  535. pass
  536. if "detail" in res:
  537. raise Exception(r.status_code, res["detail"])
  538. else:
  539. pass
  540. if "pipeline" not in app.state.MODELS[model_id]:
  541. if "chat_id" in payload:
  542. del payload["chat_id"]
  543. if "title" in payload:
  544. del payload["title"]
  545. if "task" in payload:
  546. del payload["task"]
  547. return payload
class PipelineMiddleware(BaseHTTPMiddleware):
    # Rewrites chat-completion request bodies through pipeline filter inlets
    # (filter_pipeline) before they reach the mounted backend apps.

    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
                request,
                get_http_authorization_cred(request.headers.get("Authorization")),
            )

            try:
                data = filter_pipeline(data, user)
            except Exception as e:
                # filter_pipeline raises Exception(status_code, detail) on
                # filter-reported errors; surface it as a JSON error response.
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
            request._body = modified_body_bytes
            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        # ASGI receive shim used when replaying a rewritten request body.
        return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(PipelineMiddleware)

# Allow cross-origin requests from any origin (origins == ["*"]).
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  596. @app.middleware("http")
  597. async def check_url(request: Request, call_next):
  598. if len(app.state.MODELS) == 0:
  599. await get_all_models()
  600. else:
  601. pass
  602. start_time = int(time.time())
  603. response = await call_next(request)
  604. process_time = int(time.time()) - start_time
  605. response.headers["X-Process-Time"] = str(process_time)
  606. return response
  607. @app.middleware("http")
  608. async def update_embedding_function(request: Request, call_next):
  609. response = await call_next(request)
  610. if "/embedding/update" in request.url.path:
  611. webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  612. return response
# Mount the sub-applications under their URL prefixes.
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

app.mount("/api/v1", webui_app)

# Share the RAG embedding function with the WebUI app (kept in sync by the
# update_embedding_function middleware).
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  621. async def get_all_models():
  622. pipe_models = []
  623. openai_models = []
  624. ollama_models = []
  625. pipe_models = await get_pipe_models()
  626. if app.state.config.ENABLE_OPENAI_API:
  627. openai_models = await get_openai_models()
  628. openai_models = openai_models["data"]
  629. if app.state.config.ENABLE_OLLAMA_API:
  630. ollama_models = await get_ollama_models()
  631. ollama_models = [
  632. {
  633. "id": model["model"],
  634. "name": model["name"],
  635. "object": "model",
  636. "created": int(time.time()),
  637. "owned_by": "ollama",
  638. "ollama": model,
  639. }
  640. for model in ollama_models["models"]
  641. ]
  642. models = pipe_models + openai_models + ollama_models
  643. custom_models = Models.get_all_models()
  644. for custom_model in custom_models:
  645. if custom_model.base_model_id == None:
  646. for model in models:
  647. if (
  648. custom_model.id == model["id"]
  649. or custom_model.id == model["id"].split(":")[0]
  650. ):
  651. model["name"] = custom_model.name
  652. model["info"] = custom_model.model_dump()
  653. else:
  654. owned_by = "openai"
  655. for model in models:
  656. if (
  657. custom_model.base_model_id == model["id"]
  658. or custom_model.base_model_id == model["id"].split(":")[0]
  659. ):
  660. owned_by = model["owned_by"]
  661. break
  662. models.append(
  663. {
  664. "id": custom_model.id,
  665. "name": custom_model.name,
  666. "object": "model",
  667. "created": custom_model.created_at,
  668. "owned_by": owned_by,
  669. "info": custom_model.model_dump(),
  670. "preset": True,
  671. }
  672. )
  673. app.state.MODELS = {model["id"]: model for model in models}
  674. webui_app.state.MODELS = app.state.MODELS
  675. return models
  676. @app.get("/api/models")
  677. async def get_models(user=Depends(get_verified_user)):
  678. models = await get_all_models()
  679. # Filter out filter pipelines
  680. models = [
  681. model
  682. for model in models
  683. if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
  684. ]
  685. if app.state.config.ENABLE_MODEL_FILTER:
  686. if user.role == "user":
  687. models = list(
  688. filter(
  689. lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
  690. models,
  691. )
  692. )
  693. return {"data": models}
  694. return {"data": models}
  695. @app.post("/api/chat/completions")
  696. async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
  697. model_id = form_data["model"]
  698. if model_id not in app.state.MODELS:
  699. raise HTTPException(
  700. status_code=status.HTTP_404_NOT_FOUND,
  701. detail="Model not found",
  702. )
  703. model = app.state.MODELS[model_id]
  704. print(model)
  705. pipe = model.get("pipe")
  706. if pipe:
  707. form_data["user"] = {
  708. "id": user.id,
  709. "email": user.email,
  710. "name": user.name,
  711. "role": user.role,
  712. }
  713. async def job():
  714. pipe_id = form_data["model"]
  715. if "." in pipe_id:
  716. pipe_id, sub_pipe_id = pipe_id.split(".", 1)
  717. print(pipe_id)
  718. pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
  719. if form_data["stream"]:
  720. async def stream_content():
  721. if inspect.iscoroutinefunction(pipe):
  722. res = await pipe(body=form_data)
  723. else:
  724. res = pipe(body=form_data)
  725. if isinstance(res, str):
  726. message = stream_message_template(form_data["model"], res)
  727. yield f"data: {json.dumps(message)}\n\n"
  728. if isinstance(res, Iterator):
  729. for line in res:
  730. if isinstance(line, BaseModel):
  731. line = line.model_dump_json()
  732. line = f"data: {line}"
  733. try:
  734. line = line.decode("utf-8")
  735. except:
  736. pass
  737. if line.startswith("data:"):
  738. yield f"{line}\n\n"
  739. else:
  740. line = stream_message_template(form_data["model"], line)
  741. yield f"data: {json.dumps(line)}\n\n"
  742. if isinstance(res, str) or isinstance(res, Generator):
  743. finish_message = {
  744. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  745. "object": "chat.completion.chunk",
  746. "created": int(time.time()),
  747. "model": form_data["model"],
  748. "choices": [
  749. {
  750. "index": 0,
  751. "delta": {},
  752. "logprobs": None,
  753. "finish_reason": "stop",
  754. }
  755. ],
  756. }
  757. yield f"data: {json.dumps(finish_message)}\n\n"
  758. yield f"data: [DONE]"
  759. return StreamingResponse(
  760. stream_content(), media_type="text/event-stream"
  761. )
  762. else:
  763. if inspect.iscoroutinefunction(pipe):
  764. res = await pipe(body=form_data)
  765. else:
  766. res = pipe(body=form_data)
  767. if isinstance(res, dict):
  768. return res
  769. elif isinstance(res, BaseModel):
  770. return res.model_dump()
  771. else:
  772. message = ""
  773. if isinstance(res, str):
  774. message = res
  775. if isinstance(res, Generator):
  776. for stream in res:
  777. message = f"{message}{stream}"
  778. return {
  779. "id": f"{form_data['model']}-{str(uuid.uuid4())}",
  780. "object": "chat.completion",
  781. "created": int(time.time()),
  782. "model": form_data["model"],
  783. "choices": [
  784. {
  785. "index": 0,
  786. "message": {
  787. "role": "assistant",
  788. "content": message,
  789. },
  790. "logprobs": None,
  791. "finish_reason": "stop",
  792. }
  793. ],
  794. }
  795. return await job()
  796. if model["owned_by"] == "ollama":
  797. return await generate_ollama_chat_completion(form_data, user=user)
  798. else:
  799. return await generate_openai_chat_completion(form_data, user=user)
  800. @app.post("/api/chat/completed")
  801. async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
  802. data = form_data
  803. model_id = data["model"]
  804. if model_id not in app.state.MODELS:
  805. raise HTTPException(
  806. status_code=status.HTTP_404_NOT_FOUND,
  807. detail="Model not found",
  808. )
  809. model = app.state.MODELS[model_id]
  810. filters = [
  811. model
  812. for model in app.state.MODELS.values()
  813. if "pipeline" in model
  814. and "type" in model["pipeline"]
  815. and model["pipeline"]["type"] == "filter"
  816. and (
  817. model["pipeline"]["pipelines"] == ["*"]
  818. or any(
  819. model_id == target_model_id
  820. for target_model_id in model["pipeline"]["pipelines"]
  821. )
  822. )
  823. ]
  824. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  825. if "pipeline" in model:
  826. sorted_filters = [model] + sorted_filters
  827. for filter in sorted_filters:
  828. r = None
  829. try:
  830. urlIdx = filter["urlIdx"]
  831. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  832. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  833. if key != "":
  834. headers = {"Authorization": f"Bearer {key}"}
  835. r = requests.post(
  836. f"{url}/{filter['id']}/filter/outlet",
  837. headers=headers,
  838. json={
  839. "user": {"id": user.id, "name": user.name, "role": user.role},
  840. "body": data,
  841. },
  842. )
  843. r.raise_for_status()
  844. data = r.json()
  845. except Exception as e:
  846. # Handle connection error here
  847. print(f"Connection error: {e}")
  848. if r is not None:
  849. try:
  850. res = r.json()
  851. if "detail" in res:
  852. return JSONResponse(
  853. status_code=r.status_code,
  854. content=res,
  855. )
  856. except:
  857. pass
  858. else:
  859. pass
  860. # Check if the model has any filters
  861. if "info" in model and "meta" in model["info"]:
  862. for filter_id in model["info"]["meta"].get("filterIds", []):
  863. filter = Functions.get_function_by_id(filter_id)
  864. if filter:
  865. if filter_id in webui_app.state.FUNCTIONS:
  866. function_module = webui_app.state.FUNCTIONS[filter_id]
  867. else:
  868. function_module, function_type = load_function_module_by_id(
  869. filter_id
  870. )
  871. webui_app.state.FUNCTIONS[filter_id] = function_module
  872. try:
  873. if hasattr(function_module, "outlet"):
  874. outlet = function_module.outlet
  875. if inspect.iscoroutinefunction(outlet):
  876. data = await outlet(
  877. data,
  878. {
  879. "id": user.id,
  880. "email": user.email,
  881. "name": user.name,
  882. "role": user.role,
  883. },
  884. )
  885. else:
  886. data = outlet(
  887. data,
  888. {
  889. "id": user.id,
  890. "email": user.email,
  891. "name": user.name,
  892. "role": user.role,
  893. },
  894. )
  895. except Exception as e:
  896. print(f"Error: {e}")
  897. return JSONResponse(
  898. status_code=status.HTTP_400_BAD_REQUEST,
  899. content={"detail": str(e)},
  900. )
  901. return data
  902. ##################################
  903. #
  904. # Task Endpoints
  905. #
  906. ##################################
  907. # TODO: Refactor task API endpoints below into a separate file
  908. @app.get("/api/task/config")
  909. async def get_task_config(user=Depends(get_verified_user)):
  910. return {
  911. "TASK_MODEL": app.state.config.TASK_MODEL,
  912. "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
  913. "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  914. "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
  915. "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
  916. "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  917. }
class TaskConfigForm(BaseModel):
    """Request body for /api/task/config/update.

    NOTE(review): the Optional fields carry no default value, so depending
    on the pydantic version in use they may still be required (nullable but
    not omittable) — confirm against the frontend client.
    """

    TASK_MODEL: Optional[str]  # task model id for ollama-owned models
    TASK_MODEL_EXTERNAL: Optional[str]  # task model id for non-ollama models
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int  # min prompt length for query generation
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
  925. @app.post("/api/task/config/update")
  926. async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
  927. app.state.config.TASK_MODEL = form_data.TASK_MODEL
  928. app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
  929. app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
  930. form_data.TITLE_GENERATION_PROMPT_TEMPLATE
  931. )
  932. app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
  933. form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
  934. )
  935. app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
  936. form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
  937. )
  938. app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
  939. form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
  940. )
  941. return {
  942. "TASK_MODEL": app.state.config.TASK_MODEL,
  943. "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
  944. "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  945. "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
  946. "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
  947. "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
  948. }
  949. @app.post("/api/task/title/completions")
  950. async def generate_title(form_data: dict, user=Depends(get_verified_user)):
  951. print("generate_title")
  952. model_id = form_data["model"]
  953. if model_id not in app.state.MODELS:
  954. raise HTTPException(
  955. status_code=status.HTTP_404_NOT_FOUND,
  956. detail="Model not found",
  957. )
  958. # Check if the user has a custom task model
  959. # If the user has a custom task model, use that model
  960. if app.state.MODELS[model_id]["owned_by"] == "ollama":
  961. if app.state.config.TASK_MODEL:
  962. task_model_id = app.state.config.TASK_MODEL
  963. if task_model_id in app.state.MODELS:
  964. model_id = task_model_id
  965. else:
  966. if app.state.config.TASK_MODEL_EXTERNAL:
  967. task_model_id = app.state.config.TASK_MODEL_EXTERNAL
  968. if task_model_id in app.state.MODELS:
  969. model_id = task_model_id
  970. print(model_id)
  971. model = app.state.MODELS[model_id]
  972. template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
  973. content = title_generation_template(
  974. template,
  975. form_data["prompt"],
  976. {
  977. "name": user.name,
  978. "location": user.info.get("location") if user.info else None,
  979. },
  980. )
  981. payload = {
  982. "model": model_id,
  983. "messages": [{"role": "user", "content": content}],
  984. "stream": False,
  985. "max_tokens": 50,
  986. "chat_id": form_data.get("chat_id", None),
  987. "title": True,
  988. }
  989. log.debug(payload)
  990. try:
  991. payload = filter_pipeline(payload, user)
  992. except Exception as e:
  993. return JSONResponse(
  994. status_code=e.args[0],
  995. content={"detail": e.args[1]},
  996. )
  997. if model["owned_by"] == "ollama":
  998. return await generate_ollama_chat_completion(payload, user=user)
  999. else:
  1000. return await generate_openai_chat_completion(payload, user=user)
  1001. @app.post("/api/task/query/completions")
  1002. async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
  1003. print("generate_search_query")
  1004. if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
  1005. raise HTTPException(
  1006. status_code=status.HTTP_400_BAD_REQUEST,
  1007. detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
  1008. )
  1009. model_id = form_data["model"]
  1010. if model_id not in app.state.MODELS:
  1011. raise HTTPException(
  1012. status_code=status.HTTP_404_NOT_FOUND,
  1013. detail="Model not found",
  1014. )
  1015. # Check if the user has a custom task model
  1016. # If the user has a custom task model, use that model
  1017. if app.state.MODELS[model_id]["owned_by"] == "ollama":
  1018. if app.state.config.TASK_MODEL:
  1019. task_model_id = app.state.config.TASK_MODEL
  1020. if task_model_id in app.state.MODELS:
  1021. model_id = task_model_id
  1022. else:
  1023. if app.state.config.TASK_MODEL_EXTERNAL:
  1024. task_model_id = app.state.config.TASK_MODEL_EXTERNAL
  1025. if task_model_id in app.state.MODELS:
  1026. model_id = task_model_id
  1027. print(model_id)
  1028. model = app.state.MODELS[model_id]
  1029. template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
  1030. content = search_query_generation_template(
  1031. template, form_data["prompt"], {"name": user.name}
  1032. )
  1033. payload = {
  1034. "model": model_id,
  1035. "messages": [{"role": "user", "content": content}],
  1036. "stream": False,
  1037. "max_tokens": 30,
  1038. "task": True,
  1039. }
  1040. print(payload)
  1041. try:
  1042. payload = filter_pipeline(payload, user)
  1043. except Exception as e:
  1044. return JSONResponse(
  1045. status_code=e.args[0],
  1046. content={"detail": e.args[1]},
  1047. )
  1048. if model["owned_by"] == "ollama":
  1049. return await generate_ollama_chat_completion(payload, user=user)
  1050. else:
  1051. return await generate_openai_chat_completion(payload, user=user)
  1052. @app.post("/api/task/emoji/completions")
  1053. async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
  1054. print("generate_emoji")
  1055. model_id = form_data["model"]
  1056. if model_id not in app.state.MODELS:
  1057. raise HTTPException(
  1058. status_code=status.HTTP_404_NOT_FOUND,
  1059. detail="Model not found",
  1060. )
  1061. # Check if the user has a custom task model
  1062. # If the user has a custom task model, use that model
  1063. if app.state.MODELS[model_id]["owned_by"] == "ollama":
  1064. if app.state.config.TASK_MODEL:
  1065. task_model_id = app.state.config.TASK_MODEL
  1066. if task_model_id in app.state.MODELS:
  1067. model_id = task_model_id
  1068. else:
  1069. if app.state.config.TASK_MODEL_EXTERNAL:
  1070. task_model_id = app.state.config.TASK_MODEL_EXTERNAL
  1071. if task_model_id in app.state.MODELS:
  1072. model_id = task_model_id
  1073. print(model_id)
  1074. model = app.state.MODELS[model_id]
  1075. template = '''
  1076. Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
  1077. Message: """{{prompt}}"""
  1078. '''
  1079. content = title_generation_template(
  1080. template,
  1081. form_data["prompt"],
  1082. {
  1083. "name": user.name,
  1084. "location": user.info.get("location") if user.info else None,
  1085. },
  1086. )
  1087. payload = {
  1088. "model": model_id,
  1089. "messages": [{"role": "user", "content": content}],
  1090. "stream": False,
  1091. "max_tokens": 4,
  1092. "chat_id": form_data.get("chat_id", None),
  1093. "task": True,
  1094. }
  1095. log.debug(payload)
  1096. try:
  1097. payload = filter_pipeline(payload, user)
  1098. except Exception as e:
  1099. return JSONResponse(
  1100. status_code=e.args[0],
  1101. content={"detail": e.args[1]},
  1102. )
  1103. if model["owned_by"] == "ollama":
  1104. return await generate_ollama_chat_completion(payload, user=user)
  1105. else:
  1106. return await generate_openai_chat_completion(payload, user=user)
  1107. @app.post("/api/task/tools/completions")
  1108. async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
  1109. print("get_tools_function_calling")
  1110. model_id = form_data["model"]
  1111. if model_id not in app.state.MODELS:
  1112. raise HTTPException(
  1113. status_code=status.HTTP_404_NOT_FOUND,
  1114. detail="Model not found",
  1115. )
  1116. # Check if the user has a custom task model
  1117. # If the user has a custom task model, use that model
  1118. if app.state.MODELS[model_id]["owned_by"] == "ollama":
  1119. if app.state.config.TASK_MODEL:
  1120. task_model_id = app.state.config.TASK_MODEL
  1121. if task_model_id in app.state.MODELS:
  1122. model_id = task_model_id
  1123. else:
  1124. if app.state.config.TASK_MODEL_EXTERNAL:
  1125. task_model_id = app.state.config.TASK_MODEL_EXTERNAL
  1126. if task_model_id in app.state.MODELS:
  1127. model_id = task_model_id
  1128. print(model_id)
  1129. template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
  1130. try:
  1131. context, citation, file_handler = await get_function_call_response(
  1132. form_data["messages"],
  1133. form_data.get("files", []),
  1134. form_data["tool_id"],
  1135. template,
  1136. model_id,
  1137. user,
  1138. )
  1139. return context
  1140. except Exception as e:
  1141. return JSONResponse(
  1142. status_code=e.args[0],
  1143. content={"detail": e.args[1]},
  1144. )
  1145. ##################################
  1146. #
  1147. # Pipelines Endpoints
  1148. #
  1149. ##################################
  1150. # TODO: Refactor pipelines API endpoints below into a separate file
  1151. @app.get("/api/pipelines/list")
  1152. async def get_pipelines_list(user=Depends(get_admin_user)):
  1153. responses = await get_openai_models(raw=True)
  1154. print(responses)
  1155. urlIdxs = [
  1156. idx
  1157. for idx, response in enumerate(responses)
  1158. if response != None and "pipelines" in response
  1159. ]
  1160. return {
  1161. "data": [
  1162. {
  1163. "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  1164. "idx": urlIdx,
  1165. }
  1166. for urlIdx in urlIdxs
  1167. ]
  1168. }
  1169. @app.post("/api/pipelines/upload")
  1170. async def upload_pipeline(
  1171. urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
  1172. ):
  1173. print("upload_pipeline", urlIdx, file.filename)
  1174. # Check if the uploaded file is a python file
  1175. if not file.filename.endswith(".py"):
  1176. raise HTTPException(
  1177. status_code=status.HTTP_400_BAD_REQUEST,
  1178. detail="Only Python (.py) files are allowed.",
  1179. )
  1180. upload_folder = f"{CACHE_DIR}/pipelines"
  1181. os.makedirs(upload_folder, exist_ok=True)
  1182. file_path = os.path.join(upload_folder, file.filename)
  1183. try:
  1184. # Save the uploaded file
  1185. with open(file_path, "wb") as buffer:
  1186. shutil.copyfileobj(file.file, buffer)
  1187. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1188. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1189. headers = {"Authorization": f"Bearer {key}"}
  1190. with open(file_path, "rb") as f:
  1191. files = {"file": f}
  1192. r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
  1193. r.raise_for_status()
  1194. data = r.json()
  1195. return {**data}
  1196. except Exception as e:
  1197. # Handle connection error here
  1198. print(f"Connection error: {e}")
  1199. detail = "Pipeline not found"
  1200. if r is not None:
  1201. try:
  1202. res = r.json()
  1203. if "detail" in res:
  1204. detail = res["detail"]
  1205. except:
  1206. pass
  1207. raise HTTPException(
  1208. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1209. detail=detail,
  1210. )
  1211. finally:
  1212. # Ensure the file is deleted after the upload is completed or on failure
  1213. if os.path.exists(file_path):
  1214. os.remove(file_path)
class AddPipelineForm(BaseModel):
    """Request body for /api/pipelines/add."""

    url: str  # source URL of the pipeline to install
    urlIdx: int  # index into OPENAI_API_BASE_URLS selecting the pipelines server
  1218. @app.post("/api/pipelines/add")
  1219. async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
  1220. r = None
  1221. try:
  1222. urlIdx = form_data.urlIdx
  1223. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1224. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1225. headers = {"Authorization": f"Bearer {key}"}
  1226. r = requests.post(
  1227. f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
  1228. )
  1229. r.raise_for_status()
  1230. data = r.json()
  1231. return {**data}
  1232. except Exception as e:
  1233. # Handle connection error here
  1234. print(f"Connection error: {e}")
  1235. detail = "Pipeline not found"
  1236. if r is not None:
  1237. try:
  1238. res = r.json()
  1239. if "detail" in res:
  1240. detail = res["detail"]
  1241. except:
  1242. pass
  1243. raise HTTPException(
  1244. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1245. detail=detail,
  1246. )
class DeletePipelineForm(BaseModel):
    """Request body for /api/pipelines/delete."""

    id: str  # id of the pipeline to remove
    urlIdx: int  # index into OPENAI_API_BASE_URLS selecting the pipelines server
  1250. @app.delete("/api/pipelines/delete")
  1251. async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
  1252. r = None
  1253. try:
  1254. urlIdx = form_data.urlIdx
  1255. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1256. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1257. headers = {"Authorization": f"Bearer {key}"}
  1258. r = requests.delete(
  1259. f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
  1260. )
  1261. r.raise_for_status()
  1262. data = r.json()
  1263. return {**data}
  1264. except Exception as e:
  1265. # Handle connection error here
  1266. print(f"Connection error: {e}")
  1267. detail = "Pipeline not found"
  1268. if r is not None:
  1269. try:
  1270. res = r.json()
  1271. if "detail" in res:
  1272. detail = res["detail"]
  1273. except:
  1274. pass
  1275. raise HTTPException(
  1276. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1277. detail=detail,
  1278. )
  1279. @app.get("/api/pipelines")
  1280. async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
  1281. r = None
  1282. try:
  1283. urlIdx
  1284. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1285. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1286. headers = {"Authorization": f"Bearer {key}"}
  1287. r = requests.get(f"{url}/pipelines", headers=headers)
  1288. r.raise_for_status()
  1289. data = r.json()
  1290. return {**data}
  1291. except Exception as e:
  1292. # Handle connection error here
  1293. print(f"Connection error: {e}")
  1294. detail = "Pipeline not found"
  1295. if r is not None:
  1296. try:
  1297. res = r.json()
  1298. if "detail" in res:
  1299. detail = res["detail"]
  1300. except:
  1301. pass
  1302. raise HTTPException(
  1303. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1304. detail=detail,
  1305. )
  1306. @app.get("/api/pipelines/{pipeline_id}/valves")
  1307. async def get_pipeline_valves(
  1308. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  1309. ):
  1310. models = await get_all_models()
  1311. r = None
  1312. try:
  1313. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1314. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1315. headers = {"Authorization": f"Bearer {key}"}
  1316. r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)
  1317. r.raise_for_status()
  1318. data = r.json()
  1319. return {**data}
  1320. except Exception as e:
  1321. # Handle connection error here
  1322. print(f"Connection error: {e}")
  1323. detail = "Pipeline not found"
  1324. if r is not None:
  1325. try:
  1326. res = r.json()
  1327. if "detail" in res:
  1328. detail = res["detail"]
  1329. except:
  1330. pass
  1331. raise HTTPException(
  1332. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1333. detail=detail,
  1334. )
  1335. @app.get("/api/pipelines/{pipeline_id}/valves/spec")
  1336. async def get_pipeline_valves_spec(
  1337. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  1338. ):
  1339. models = await get_all_models()
  1340. r = None
  1341. try:
  1342. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1343. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1344. headers = {"Authorization": f"Bearer {key}"}
  1345. r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)
  1346. r.raise_for_status()
  1347. data = r.json()
  1348. return {**data}
  1349. except Exception as e:
  1350. # Handle connection error here
  1351. print(f"Connection error: {e}")
  1352. detail = "Pipeline not found"
  1353. if r is not None:
  1354. try:
  1355. res = r.json()
  1356. if "detail" in res:
  1357. detail = res["detail"]
  1358. except:
  1359. pass
  1360. raise HTTPException(
  1361. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1362. detail=detail,
  1363. )
  1364. @app.post("/api/pipelines/{pipeline_id}/valves/update")
  1365. async def update_pipeline_valves(
  1366. urlIdx: Optional[int],
  1367. pipeline_id: str,
  1368. form_data: dict,
  1369. user=Depends(get_admin_user),
  1370. ):
  1371. models = await get_all_models()
  1372. r = None
  1373. try:
  1374. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  1375. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  1376. headers = {"Authorization": f"Bearer {key}"}
  1377. r = requests.post(
  1378. f"{url}/{pipeline_id}/valves/update",
  1379. headers=headers,
  1380. json={**form_data},
  1381. )
  1382. r.raise_for_status()
  1383. data = r.json()
  1384. return {**data}
  1385. except Exception as e:
  1386. # Handle connection error here
  1387. print(f"Connection error: {e}")
  1388. detail = "Pipeline not found"
  1389. if r is not None:
  1390. try:
  1391. res = r.json()
  1392. if "detail" in res:
  1393. detail = res["detail"]
  1394. except:
  1395. pass
  1396. raise HTTPException(
  1397. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  1398. detail=detail,
  1399. )
  1400. ##################################
  1401. #
  1402. # Config Endpoints
  1403. #
  1404. ##################################
  1405. @app.get("/api/config")
  1406. async def get_app_config():
  1407. # Checking and Handling the Absence of 'ui' in CONFIG_DATA
  1408. default_locale = "en-US"
  1409. if "ui" in CONFIG_DATA:
  1410. default_locale = CONFIG_DATA["ui"].get("default_locale", "en-US")
  1411. # The Rest of the Function Now Uses the Variables Defined Above
  1412. return {
  1413. "status": True,
  1414. "name": WEBUI_NAME,
  1415. "version": VERSION,
  1416. "default_locale": default_locale,
  1417. "default_models": webui_app.state.config.DEFAULT_MODELS,
  1418. "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  1419. "features": {
  1420. "auth": WEBUI_AUTH,
  1421. "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
  1422. "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
  1423. "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
  1424. "enable_image_generation": images_app.state.config.ENABLED,
  1425. "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
  1426. "enable_admin_export": ENABLE_ADMIN_EXPORT,
  1427. },
  1428. "audio": {
  1429. "tts": {
  1430. "engine": audio_app.state.config.TTS_ENGINE,
  1431. "voice": audio_app.state.config.TTS_VOICE,
  1432. },
  1433. "stt": {
  1434. "engine": audio_app.state.config.STT_ENGINE,
  1435. },
  1436. },
  1437. "oauth": {
  1438. "providers": {
  1439. name: config.get("name", name)
  1440. for name, config in OAUTH_PROVIDERS.items()
  1441. }
  1442. },
  1443. }
  1444. @app.get("/api/config/model/filter")
  1445. async def get_model_filter_config(user=Depends(get_admin_user)):
  1446. return {
  1447. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  1448. "models": app.state.config.MODEL_FILTER_LIST,
  1449. }
class ModelFilterConfigForm(BaseModel):
    """Request body for /api/config/model/filter."""

    enabled: bool  # whether the whitelist is enforced
    models: List[str]  # allowed model ids
  1453. @app.post("/api/config/model/filter")
  1454. async def update_model_filter_config(
  1455. form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
  1456. ):
  1457. app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
  1458. app.state.config.MODEL_FILTER_LIST = form_data.models
  1459. return {
  1460. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  1461. "models": app.state.config.MODEL_FILTER_LIST,
  1462. }
  1463. # TODO: webhook endpoint should be under config endpoints
  1464. @app.get("/api/webhook")
  1465. async def get_webhook_url(user=Depends(get_admin_user)):
  1466. return {
  1467. "url": app.state.config.WEBHOOK_URL,
  1468. }
class UrlForm(BaseModel):
    """Request body carrying a single URL (used by /api/webhook)."""

    url: str
  1471. @app.post("/api/webhook")
  1472. async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
  1473. app.state.config.WEBHOOK_URL = form_data.url
  1474. webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
  1475. return {"url": app.state.config.WEBHOOK_URL}
  1476. @app.get("/api/version")
  1477. async def get_app_config():
  1478. return {
  1479. "version": VERSION,
  1480. }
  1481. @app.get("/api/changelog")
  1482. async def get_app_changelog():
  1483. return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
  1484. @app.get("/api/version/updates")
  1485. async def get_app_latest_release_version():
  1486. try:
  1487. async with aiohttp.ClientSession(trust_env=True) as session:
  1488. async with session.get(
  1489. "https://api.github.com/repos/open-webui/open-webui/releases/latest"
  1490. ) as response:
  1491. response.raise_for_status()
  1492. data = await response.json()
  1493. latest_version = data["tag_name"]
  1494. return {"current": VERSION, "latest": latest_version[1:]}
  1495. except aiohttp.ClientError as e:
  1496. raise HTTPException(
  1497. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  1498. detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
  1499. )
  1500. ############################
  1501. # OAuth Login & Callback
  1502. ############################
# Shared authlib OAuth registry; one client is registered per configured
# provider using its OIDC discovery document (server_metadata_url).
oauth = OAuth()

for provider_name, provider_config in OAUTH_PROVIDERS.items():
    oauth.register(
        name=provider_name,
        client_id=provider_config["client_id"],
        client_secret=provider_config["client_secret"],
        server_metadata_url=provider_config["server_metadata_url"],
        client_kwargs={
            "scope": provider_config["scope"],
        },
    )
# SessionMiddleware is used by authlib for oauth
# Only installed when at least one OAuth provider is configured.
if len(OAUTH_PROVIDERS) > 0:
    app.add_middleware(
        SessionMiddleware,
        secret_key=WEBUI_SECRET_KEY,
        session_cookie="oui-session",
        same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
        https_only=WEBUI_SESSION_COOKIE_SECURE,
    )
  1523. @app.get("/oauth/{provider}/login")
  1524. async def oauth_login(provider: str, request: Request):
  1525. if provider not in OAUTH_PROVIDERS:
  1526. raise HTTPException(404)
  1527. redirect_uri = request.url_for("oauth_callback", provider=provider)
  1528. return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
  1529. @app.get("/oauth/{provider}/callback")
  1530. async def oauth_callback(provider: str, request: Request):
  1531. if provider not in OAUTH_PROVIDERS:
  1532. raise HTTPException(404)
  1533. client = oauth.create_client(provider)
  1534. try:
  1535. token = await client.authorize_access_token(request)
  1536. except Exception as e:
  1537. log.error(f"OAuth callback error: {e}")
  1538. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  1539. user_data: UserInfo = token["userinfo"]
  1540. sub = user_data.get("sub")
  1541. if not sub:
  1542. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  1543. provider_sub = f"{provider}@{sub}"
  1544. # Check if the user exists
  1545. user = Users.get_user_by_oauth_sub(provider_sub)
  1546. if not user:
  1547. # If the user does not exist, check if merging is enabled
  1548. if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
  1549. # Check if the user exists by email
  1550. email = user_data.get("email", "").lower()
  1551. if not email:
  1552. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  1553. user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
  1554. if user:
  1555. # Update the user with the new oauth sub
  1556. Users.update_user_oauth_sub_by_id(user.id, provider_sub)
  1557. if not user:
  1558. # If the user does not exist, check if signups are enabled
  1559. if ENABLE_OAUTH_SIGNUP.value:
  1560. user = Auths.insert_new_auth(
  1561. email=user_data.get("email", "").lower(),
  1562. password=get_password_hash(
  1563. str(uuid.uuid4())
  1564. ), # Random password, not used
  1565. name=user_data.get("name", "User"),
  1566. profile_image_url=user_data.get("picture", "/user.png"),
  1567. role=webui_app.state.config.DEFAULT_USER_ROLE,
  1568. oauth_sub=provider_sub,
  1569. )
  1570. if webui_app.state.config.WEBHOOK_URL:
  1571. post_webhook(
  1572. webui_app.state.config.WEBHOOK_URL,
  1573. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  1574. {
  1575. "action": "signup",
  1576. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  1577. "user": user.model_dump_json(exclude_none=True),
  1578. },
  1579. )
  1580. else:
  1581. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  1582. jwt_token = create_token(
  1583. data={"id": user.id},
  1584. expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
  1585. )
  1586. # Redirect back to the frontend with the JWT token
  1587. redirect_url = f"{request.base_url}auth#token={jwt_token}"
  1588. return RedirectResponse(url=redirect_url)
  1589. @app.get("/manifest.json")
  1590. async def get_manifest_json():
  1591. return {
  1592. "name": WEBUI_NAME,
  1593. "short_name": WEBUI_NAME,
  1594. "start_url": "/",
  1595. "display": "standalone",
  1596. "background_color": "#343541",
  1597. "theme_color": "#343541",
  1598. "orientation": "portrait-primary",
  1599. "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
  1600. }
  1601. @app.get("/opensearch.xml")
  1602. async def get_opensearch_xml():
  1603. xml_content = rf"""
  1604. <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
  1605. <ShortName>{WEBUI_NAME}</ShortName>
  1606. <Description>Search {WEBUI_NAME}</Description>
  1607. <InputEncoding>UTF-8</InputEncoding>
  1608. <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
  1609. <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
  1610. <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
  1611. </OpenSearchDescription>
  1612. """
  1613. return Response(content=xml_content, media_type="application/xml")
  1614. @app.get("/health")
  1615. async def healthcheck():
  1616. return {"status": True}
# Serve static assets and the cache directory directly from disk.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

if os.path.exists(FRONTEND_BUILD_DIR):
    # Register ".js" as text/javascript so module scripts are served with a
    # JS MIME type even when the host's mimetype registry maps it differently.
    mimetypes.add_type("text/javascript", ".js")
    # Mounted at "/" last, so the API routes registered above take precedence.
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
else:
    # No frontend build available; the server still exposes the API routes.
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )