main.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105
  1. from contextlib import asynccontextmanager
  2. from bs4 import BeautifulSoup
  3. import json
  4. import markdown
  5. import time
  6. import os
  7. import sys
  8. import logging
  9. import aiohttp
  10. import requests
  11. import mimetypes
  12. import shutil
  13. import os
  14. import asyncio
  15. from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
  16. from fastapi.staticfiles import StaticFiles
  17. from fastapi.responses import JSONResponse
  18. from fastapi import HTTPException
  19. from fastapi.middleware.wsgi import WSGIMiddleware
  20. from fastapi.middleware.cors import CORSMiddleware
  21. from starlette.exceptions import HTTPException as StarletteHTTPException
  22. from starlette.middleware.base import BaseHTTPMiddleware
  23. from starlette.responses import StreamingResponse, Response
  24. from apps.socket.main import app as socket_app
  25. from apps.ollama.main import (
  26. app as ollama_app,
  27. OpenAIChatCompletionForm,
  28. get_all_models as get_ollama_models,
  29. generate_openai_chat_completion as generate_ollama_chat_completion,
  30. )
  31. from apps.openai.main import (
  32. app as openai_app,
  33. get_all_models as get_openai_models,
  34. generate_chat_completion as generate_openai_chat_completion,
  35. )
  36. from apps.audio.main import app as audio_app
  37. from apps.images.main import app as images_app
  38. from apps.rag.main import app as rag_app
  39. from apps.webui.main import app as webui_app
  40. from pydantic import BaseModel
  41. from typing import List, Optional
  42. from apps.webui.models.models import Models, ModelModel
  43. from utils.utils import (
  44. get_admin_user,
  45. get_verified_user,
  46. get_current_user,
  47. get_http_authorization_cred,
  48. )
  49. from utils.task import title_generation_template
  50. from apps.rag.utils import rag_messages
  51. from config import (
  52. CONFIG_DATA,
  53. WEBUI_NAME,
  54. WEBUI_URL,
  55. WEBUI_AUTH,
  56. ENV,
  57. VERSION,
  58. CHANGELOG,
  59. FRONTEND_BUILD_DIR,
  60. CACHE_DIR,
  61. STATIC_DIR,
  62. ENABLE_OPENAI_API,
  63. ENABLE_OLLAMA_API,
  64. ENABLE_MODEL_FILTER,
  65. MODEL_FILTER_LIST,
  66. GLOBAL_LOG_LEVEL,
  67. SRC_LOG_LEVELS,
  68. WEBHOOK_URL,
  69. ENABLE_ADMIN_EXPORT,
  70. WEBUI_BUILD_HASH,
  71. TITLE_GENERATION_PROMPT_TEMPLATE,
  72. AppConfig,
  73. )
  74. from constants import ERROR_MESSAGES
# Root logging goes to stdout at the globally configured level; this module's
# own logger is further scoped to the "MAIN" per-source log level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  78. class SPAStaticFiles(StaticFiles):
  79. async def get_response(self, path: str, scope):
  80. try:
  81. return await super().get_response(path, scope)
  82. except (HTTPException, StarletteHTTPException) as ex:
  83. if ex.status_code == 404:
  84. return await super().get_response("index.html", scope)
  85. else:
  86. raise ex
# One-time startup banner (ASCII art + version/commit info).
# NOTE(review): the art's whitespace appears mangled in this copy of the file —
# confirm against the upstream source before reformatting.
print(
    rf"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # No startup/shutdown work yet; this is a placeholder lifecycle hook.
    yield
app = FastAPI(
    # Interactive API docs are exposed only in the dev environment.
    docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)

# Mutable runtime configuration, seeded from the static config module.
app.state.config = AppConfig()

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.WEBHOOK_URL = WEBHOOK_URL

app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE

# model id -> model record; populated lazily by get_all_models().
app.state.MODELS = {}

# CORS: allow every origin.
origins = ["*"]
# Custom middleware to add security headers (currently disabled; uncomment to enforce COOP/COEP)
  116. # class SecurityHeadersMiddleware(BaseHTTPMiddleware):
  117. # async def dispatch(self, request: Request, call_next):
  118. # response: Response = await call_next(request)
  119. # response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
  120. # response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
  121. # return response
  122. # app.add_middleware(SecurityHeadersMiddleware)
  123. class RAGMiddleware(BaseHTTPMiddleware):
  124. async def dispatch(self, request: Request, call_next):
  125. return_citations = False
  126. if request.method == "POST" and (
  127. "/ollama/api/chat" in request.url.path
  128. or "/chat/completions" in request.url.path
  129. ):
  130. log.debug(f"request.url.path: {request.url.path}")
  131. # Read the original request body
  132. body = await request.body()
  133. # Decode body to string
  134. body_str = body.decode("utf-8")
  135. # Parse string to JSON
  136. data = json.loads(body_str) if body_str else {}
  137. return_citations = data.get("citations", False)
  138. if "citations" in data:
  139. del data["citations"]
  140. # Example: Add a new key-value pair or modify existing ones
  141. # data["modified"] = True # Example modification
  142. if "docs" in data:
  143. data = {**data}
  144. data["messages"], citations = rag_messages(
  145. docs=data["docs"],
  146. messages=data["messages"],
  147. template=rag_app.state.config.RAG_TEMPLATE,
  148. embedding_function=rag_app.state.EMBEDDING_FUNCTION,
  149. k=rag_app.state.config.TOP_K,
  150. reranking_function=rag_app.state.sentence_transformer_rf,
  151. r=rag_app.state.config.RELEVANCE_THRESHOLD,
  152. hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  153. )
  154. del data["docs"]
  155. log.debug(
  156. f"data['messages']: {data['messages']}, citations: {citations}"
  157. )
  158. modified_body_bytes = json.dumps(data).encode("utf-8")
  159. # Replace the request body with the modified one
  160. request._body = modified_body_bytes
  161. # Set custom header to ensure content-length matches new body length
  162. request.headers.__dict__["_list"] = [
  163. (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
  164. *[
  165. (k, v)
  166. for k, v in request.headers.raw
  167. if k.lower() != b"content-length"
  168. ],
  169. ]
  170. response = await call_next(request)
  171. if return_citations:
  172. # Inject the citations into the response
  173. if isinstance(response, StreamingResponse):
  174. # If it's a streaming response, inject it as SSE event or NDJSON line
  175. content_type = response.headers.get("Content-Type")
  176. if "text/event-stream" in content_type:
  177. return StreamingResponse(
  178. self.openai_stream_wrapper(response.body_iterator, citations),
  179. )
  180. if "application/x-ndjson" in content_type:
  181. return StreamingResponse(
  182. self.ollama_stream_wrapper(response.body_iterator, citations),
  183. )
  184. return response
  185. async def _receive(self, body: bytes):
  186. return {"type": "http.request", "body": body, "more_body": False}
  187. async def openai_stream_wrapper(self, original_generator, citations):
  188. yield f"data: {json.dumps({'citations': citations})}\n\n"
  189. async for data in original_generator:
  190. yield data
  191. async def ollama_stream_wrapper(self, original_generator, citations):
  192. yield f"{json.dumps({'citations': citations})}\n"
  193. async for data in original_generator:
  194. yield data
  195. app.add_middleware(RAGMiddleware)
  196. def filter_pipeline(payload, user):
  197. user = {"id": user.id, "name": user.name, "role": user.role}
  198. model_id = payload["model"]
  199. filters = [
  200. model
  201. for model in app.state.MODELS.values()
  202. if "pipeline" in model
  203. and "type" in model["pipeline"]
  204. and model["pipeline"]["type"] == "filter"
  205. and (
  206. model["pipeline"]["pipelines"] == ["*"]
  207. or any(
  208. model_id == target_model_id
  209. for target_model_id in model["pipeline"]["pipelines"]
  210. )
  211. )
  212. ]
  213. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  214. model = app.state.MODELS[model_id]
  215. if "pipeline" in model:
  216. sorted_filters.append(model)
  217. for filter in sorted_filters:
  218. r = None
  219. try:
  220. urlIdx = filter["urlIdx"]
  221. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  222. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  223. if key != "":
  224. headers = {"Authorization": f"Bearer {key}"}
  225. r = requests.post(
  226. f"{url}/{filter['id']}/filter/inlet",
  227. headers=headers,
  228. json={
  229. "user": user,
  230. "body": payload,
  231. },
  232. )
  233. r.raise_for_status()
  234. payload = r.json()
  235. except Exception as e:
  236. # Handle connection error here
  237. print(f"Connection error: {e}")
  238. if r is not None:
  239. try:
  240. res = r.json()
  241. if "detail" in res:
  242. return JSONResponse(
  243. status_code=r.status_code,
  244. content=res,
  245. )
  246. except:
  247. pass
  248. else:
  249. pass
  250. if "pipeline" not in app.state.MODELS[model_id]:
  251. if "chat_id" in payload:
  252. del payload["chat_id"]
  253. if "title" in payload:
  254. del payload["title"]
  255. return payload
class PipelineMiddleware(BaseHTTPMiddleware):
    # Runs every chat-completion request body through filter_pipeline()
    # before it reaches the mounted ollama/openai sub-apps.
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            # Resolve the requesting user from the Authorization header.
            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )

            # NOTE(review): filter_pipeline may return a JSONResponse on
            # pipeline error; json.dumps below would then raise — confirm.
            data = filter_pipeline(data, user)

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
            request._body = modified_body_bytes
            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        # ASGI receive() replacement carrying the rewritten body.
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)
# CORS is wide open; with allow_credentials=True, Starlette echoes the
# request origin rather than sending a literal "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  297. @app.middleware("http")
  298. async def check_url(request: Request, call_next):
  299. if len(app.state.MODELS) == 0:
  300. await get_all_models()
  301. else:
  302. pass
  303. start_time = int(time.time())
  304. response = await call_next(request)
  305. process_time = int(time.time()) - start_time
  306. response.headers["X-Process-Time"] = str(process_time)
  307. return response
  308. @app.middleware("http")
  309. async def update_embedding_function(request: Request, call_next):
  310. response = await call_next(request)
  311. if "/embedding/update" in request.url.path:
  312. webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  313. return response
# Mount the sub-applications; order matters only in that more specific paths
# are registered before the catch-all SPA mount added at the bottom of the file.
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

app.mount("/api/v1", webui_app)

# Initial wiring; kept in sync afterwards by update_embedding_function().
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  322. async def get_all_models():
  323. openai_models = []
  324. ollama_models = []
  325. if app.state.config.ENABLE_OPENAI_API:
  326. openai_models = await get_openai_models()
  327. openai_models = openai_models["data"]
  328. if app.state.config.ENABLE_OLLAMA_API:
  329. ollama_models = await get_ollama_models()
  330. ollama_models = [
  331. {
  332. "id": model["model"],
  333. "name": model["name"],
  334. "object": "model",
  335. "created": int(time.time()),
  336. "owned_by": "ollama",
  337. "ollama": model,
  338. }
  339. for model in ollama_models["models"]
  340. ]
  341. models = openai_models + ollama_models
  342. custom_models = Models.get_all_models()
  343. for custom_model in custom_models:
  344. if custom_model.base_model_id == None:
  345. for model in models:
  346. if (
  347. custom_model.id == model["id"]
  348. or custom_model.id == model["id"].split(":")[0]
  349. ):
  350. model["name"] = custom_model.name
  351. model["info"] = custom_model.model_dump()
  352. else:
  353. owned_by = "openai"
  354. for model in models:
  355. if (
  356. custom_model.base_model_id == model["id"]
  357. or custom_model.base_model_id == model["id"].split(":")[0]
  358. ):
  359. owned_by = model["owned_by"]
  360. break
  361. models.append(
  362. {
  363. "id": custom_model.id,
  364. "name": custom_model.name,
  365. "object": "model",
  366. "created": custom_model.created_at,
  367. "owned_by": owned_by,
  368. "info": custom_model.model_dump(),
  369. "preset": True,
  370. }
  371. )
  372. app.state.MODELS = {model["id"]: model for model in models}
  373. webui_app.state.MODELS = app.state.MODELS
  374. return models
  375. @app.get("/api/models")
  376. async def get_models(user=Depends(get_verified_user)):
  377. models = await get_all_models()
  378. # Filter out filter pipelines
  379. models = [
  380. model
  381. for model in models
  382. if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
  383. ]
  384. if app.state.config.ENABLE_MODEL_FILTER:
  385. if user.role == "user":
  386. models = list(
  387. filter(
  388. lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
  389. models,
  390. )
  391. )
  392. return {"data": models}
  393. return {"data": models}
  394. @app.post("/api/title/completions")
  395. async def generate_title(form_data: dict, user=Depends(get_verified_user)):
  396. print("generate_title")
  397. model_id = form_data["model"]
  398. if model_id not in app.state.MODELS:
  399. raise HTTPException(
  400. status_code=status.HTTP_404_NOT_FOUND,
  401. detail="Model not found",
  402. )
  403. model = app.state.MODELS[model_id]
  404. template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
  405. content = title_generation_template(
  406. template, form_data["prompt"], user.model_dump()
  407. )
  408. payload = {
  409. "model": model_id,
  410. "messages": [{"role": "user", "content": content}],
  411. "stream": False,
  412. "max_tokens": 50,
  413. "chat_id": form_data.get("chat_id", None),
  414. "title": True,
  415. }
  416. print(payload)
  417. payload = filter_pipeline(payload, user)
  418. if model["owned_by"] == "ollama":
  419. return await generate_ollama_chat_completion(
  420. OpenAIChatCompletionForm(**payload), user=user
  421. )
  422. else:
  423. return await generate_openai_chat_completion(payload, user=user)
  424. @app.post("/api/chat/completions")
  425. async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
  426. model_id = form_data["model"]
  427. if model_id not in app.state.MODELS:
  428. raise HTTPException(
  429. status_code=status.HTTP_404_NOT_FOUND,
  430. detail="Model not found",
  431. )
  432. model = app.state.MODELS[model_id]
  433. print(model)
  434. if model["owned_by"] == "ollama":
  435. return await generate_ollama_chat_completion(
  436. OpenAIChatCompletionForm(**form_data), user=user
  437. )
  438. else:
  439. return await generate_openai_chat_completion(form_data, user=user)
  440. @app.post("/api/chat/completed")
  441. async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
  442. data = form_data
  443. model_id = data["model"]
  444. filters = [
  445. model
  446. for model in app.state.MODELS.values()
  447. if "pipeline" in model
  448. and "type" in model["pipeline"]
  449. and model["pipeline"]["type"] == "filter"
  450. and (
  451. model["pipeline"]["pipelines"] == ["*"]
  452. or any(
  453. model_id == target_model_id
  454. for target_model_id in model["pipeline"]["pipelines"]
  455. )
  456. )
  457. ]
  458. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  459. print(model_id)
  460. if model_id in app.state.MODELS:
  461. model = app.state.MODELS[model_id]
  462. if "pipeline" in model:
  463. sorted_filters = [model] + sorted_filters
  464. for filter in sorted_filters:
  465. r = None
  466. try:
  467. urlIdx = filter["urlIdx"]
  468. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  469. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  470. if key != "":
  471. headers = {"Authorization": f"Bearer {key}"}
  472. r = requests.post(
  473. f"{url}/{filter['id']}/filter/outlet",
  474. headers=headers,
  475. json={
  476. "user": {"id": user.id, "name": user.name, "role": user.role},
  477. "body": data,
  478. },
  479. )
  480. r.raise_for_status()
  481. data = r.json()
  482. except Exception as e:
  483. # Handle connection error here
  484. print(f"Connection error: {e}")
  485. if r is not None:
  486. try:
  487. res = r.json()
  488. if "detail" in res:
  489. return JSONResponse(
  490. status_code=r.status_code,
  491. content=res,
  492. )
  493. except:
  494. pass
  495. else:
  496. pass
  497. return data
  498. @app.get("/api/pipelines/list")
  499. async def get_pipelines_list(user=Depends(get_admin_user)):
  500. responses = await get_openai_models(raw=True)
  501. print(responses)
  502. urlIdxs = [
  503. idx
  504. for idx, response in enumerate(responses)
  505. if response != None and "pipelines" in response
  506. ]
  507. return {
  508. "data": [
  509. {
  510. "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  511. "idx": urlIdx,
  512. }
  513. for urlIdx in urlIdxs
  514. ]
  515. }
  516. @app.post("/api/pipelines/upload")
  517. async def upload_pipeline(
  518. urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
  519. ):
  520. print("upload_pipeline", urlIdx, file.filename)
  521. # Check if the uploaded file is a python file
  522. if not file.filename.endswith(".py"):
  523. raise HTTPException(
  524. status_code=status.HTTP_400_BAD_REQUEST,
  525. detail="Only Python (.py) files are allowed.",
  526. )
  527. upload_folder = f"{CACHE_DIR}/pipelines"
  528. os.makedirs(upload_folder, exist_ok=True)
  529. file_path = os.path.join(upload_folder, file.filename)
  530. try:
  531. # Save the uploaded file
  532. with open(file_path, "wb") as buffer:
  533. shutil.copyfileobj(file.file, buffer)
  534. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  535. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  536. headers = {"Authorization": f"Bearer {key}"}
  537. with open(file_path, "rb") as f:
  538. files = {"file": f}
  539. r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
  540. r.raise_for_status()
  541. data = r.json()
  542. return {**data}
  543. except Exception as e:
  544. # Handle connection error here
  545. print(f"Connection error: {e}")
  546. detail = "Pipeline not found"
  547. if r is not None:
  548. try:
  549. res = r.json()
  550. if "detail" in res:
  551. detail = res["detail"]
  552. except:
  553. pass
  554. raise HTTPException(
  555. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  556. detail=detail,
  557. )
  558. finally:
  559. # Ensure the file is deleted after the upload is completed or on failure
  560. if os.path.exists(file_path):
  561. os.remove(file_path)
class AddPipelineForm(BaseModel):
    # Request body for POST /api/pipelines/add.
    url: str  # source URL of the pipeline to install
    urlIdx: int  # index into OPENAI_API_BASE_URLS selecting the pipelines server
  565. @app.post("/api/pipelines/add")
  566. async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
  567. r = None
  568. try:
  569. urlIdx = form_data.urlIdx
  570. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  571. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  572. headers = {"Authorization": f"Bearer {key}"}
  573. r = requests.post(
  574. f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
  575. )
  576. r.raise_for_status()
  577. data = r.json()
  578. return {**data}
  579. except Exception as e:
  580. # Handle connection error here
  581. print(f"Connection error: {e}")
  582. detail = "Pipeline not found"
  583. if r is not None:
  584. try:
  585. res = r.json()
  586. if "detail" in res:
  587. detail = res["detail"]
  588. except:
  589. pass
  590. raise HTTPException(
  591. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  592. detail=detail,
  593. )
class DeletePipelineForm(BaseModel):
    # Request body for DELETE /api/pipelines/delete.
    id: str  # id of the pipeline to remove
    urlIdx: int  # index into OPENAI_API_BASE_URLS selecting the pipelines server
  597. @app.delete("/api/pipelines/delete")
  598. async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
  599. r = None
  600. try:
  601. urlIdx = form_data.urlIdx
  602. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  603. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  604. headers = {"Authorization": f"Bearer {key}"}
  605. r = requests.delete(
  606. f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
  607. )
  608. r.raise_for_status()
  609. data = r.json()
  610. return {**data}
  611. except Exception as e:
  612. # Handle connection error here
  613. print(f"Connection error: {e}")
  614. detail = "Pipeline not found"
  615. if r is not None:
  616. try:
  617. res = r.json()
  618. if "detail" in res:
  619. detail = res["detail"]
  620. except:
  621. pass
  622. raise HTTPException(
  623. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  624. detail=detail,
  625. )
  626. @app.get("/api/pipelines")
  627. async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
  628. r = None
  629. try:
  630. urlIdx
  631. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  632. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  633. headers = {"Authorization": f"Bearer {key}"}
  634. r = requests.get(f"{url}/pipelines", headers=headers)
  635. r.raise_for_status()
  636. data = r.json()
  637. return {**data}
  638. except Exception as e:
  639. # Handle connection error here
  640. print(f"Connection error: {e}")
  641. detail = "Pipeline not found"
  642. if r is not None:
  643. try:
  644. res = r.json()
  645. if "detail" in res:
  646. detail = res["detail"]
  647. except:
  648. pass
  649. raise HTTPException(
  650. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  651. detail=detail,
  652. )
  653. @app.get("/api/pipelines/{pipeline_id}/valves")
  654. async def get_pipeline_valves(
  655. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  656. ):
  657. models = await get_all_models()
  658. r = None
  659. try:
  660. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  661. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  662. headers = {"Authorization": f"Bearer {key}"}
  663. r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)
  664. r.raise_for_status()
  665. data = r.json()
  666. return {**data}
  667. except Exception as e:
  668. # Handle connection error here
  669. print(f"Connection error: {e}")
  670. detail = "Pipeline not found"
  671. if r is not None:
  672. try:
  673. res = r.json()
  674. if "detail" in res:
  675. detail = res["detail"]
  676. except:
  677. pass
  678. raise HTTPException(
  679. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  680. detail=detail,
  681. )
  682. @app.get("/api/pipelines/{pipeline_id}/valves/spec")
  683. async def get_pipeline_valves_spec(
  684. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  685. ):
  686. models = await get_all_models()
  687. r = None
  688. try:
  689. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  690. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  691. headers = {"Authorization": f"Bearer {key}"}
  692. r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)
  693. r.raise_for_status()
  694. data = r.json()
  695. return {**data}
  696. except Exception as e:
  697. # Handle connection error here
  698. print(f"Connection error: {e}")
  699. detail = "Pipeline not found"
  700. if r is not None:
  701. try:
  702. res = r.json()
  703. if "detail" in res:
  704. detail = res["detail"]
  705. except:
  706. pass
  707. raise HTTPException(
  708. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  709. detail=detail,
  710. )
  711. @app.post("/api/pipelines/{pipeline_id}/valves/update")
  712. async def update_pipeline_valves(
  713. urlIdx: Optional[int],
  714. pipeline_id: str,
  715. form_data: dict,
  716. user=Depends(get_admin_user),
  717. ):
  718. models = await get_all_models()
  719. r = None
  720. try:
  721. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  722. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  723. headers = {"Authorization": f"Bearer {key}"}
  724. r = requests.post(
  725. f"{url}/{pipeline_id}/valves/update",
  726. headers=headers,
  727. json={**form_data},
  728. )
  729. r.raise_for_status()
  730. data = r.json()
  731. return {**data}
  732. except Exception as e:
  733. # Handle connection error here
  734. print(f"Connection error: {e}")
  735. detail = "Pipeline not found"
  736. if r is not None:
  737. try:
  738. res = r.json()
  739. if "detail" in res:
  740. detail = res["detail"]
  741. except:
  742. pass
  743. raise HTTPException(
  744. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  745. detail=detail,
  746. )
  747. @app.get("/api/config")
  748. async def get_app_config():
  749. # Checking and Handling the Absence of 'ui' in CONFIG_DATA
  750. default_locale = "en-US"
  751. if "ui" in CONFIG_DATA:
  752. default_locale = CONFIG_DATA["ui"].get("default_locale", "en-US")
  753. # The Rest of the Function Now Uses the Variables Defined Above
  754. return {
  755. "status": True,
  756. "name": WEBUI_NAME,
  757. "version": VERSION,
  758. "default_locale": default_locale,
  759. "default_models": webui_app.state.config.DEFAULT_MODELS,
  760. "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  761. "features": {
  762. "auth": WEBUI_AUTH,
  763. "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
  764. "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
  765. "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
  766. "enable_image_generation": images_app.state.config.ENABLED,
  767. "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
  768. "enable_admin_export": ENABLE_ADMIN_EXPORT,
  769. },
  770. "audio": {
  771. "tts": {
  772. "engine": audio_app.state.config.TTS_ENGINE,
  773. "voice": audio_app.state.config.TTS_VOICE,
  774. },
  775. "stt": {
  776. "engine": audio_app.state.config.STT_ENGINE,
  777. },
  778. },
  779. }
  780. @app.get("/api/config/model/filter")
  781. async def get_model_filter_config(user=Depends(get_admin_user)):
  782. return {
  783. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  784. "models": app.state.config.MODEL_FILTER_LIST,
  785. }
class ModelFilterConfigForm(BaseModel):
    # Request body for POST /api/config/model/filter.
    enabled: bool  # whether the allow-list is enforced for non-admin users
    models: List[str]  # model ids non-admin users are allowed to see
  789. @app.post("/api/config/model/filter")
  790. async def update_model_filter_config(
  791. form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
  792. ):
  793. app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
  794. app.state.config.MODEL_FILTER_LIST = form_data.models
  795. return {
  796. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  797. "models": app.state.config.MODEL_FILTER_LIST,
  798. }
  799. @app.get("/api/webhook")
  800. async def get_webhook_url(user=Depends(get_admin_user)):
  801. return {
  802. "url": app.state.config.WEBHOOK_URL,
  803. }
class UrlForm(BaseModel):
    # Request body for POST /api/webhook.
    url: str  # the new webhook URL
  806. @app.post("/api/webhook")
  807. async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
  808. app.state.config.WEBHOOK_URL = form_data.url
  809. webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
  810. return {"url": app.state.config.WEBHOOK_URL}
  811. @app.get("/api/version")
  812. async def get_app_config():
  813. return {
  814. "version": VERSION,
  815. }
  816. @app.get("/api/changelog")
  817. async def get_app_changelog():
  818. return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    # Query GitHub for the newest release tag and report it alongside the
    # running version.  Tags look like "v0.1.2"; [1:] strips the leading "v".
    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]

                return {"current": VERSION, "latest": latest_version[1:]}
    except aiohttp.ClientError as e:
        # NOTE(review): every client error (network failure, 4xx/5xx) is
        # reported as rate-limiting — confirm this catch-all message is intended.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
        )
  835. @app.get("/manifest.json")
  836. async def get_manifest_json():
  837. return {
  838. "name": WEBUI_NAME,
  839. "short_name": WEBUI_NAME,
  840. "start_url": "/",
  841. "display": "standalone",
  842. "background_color": "#343541",
  843. "theme_color": "#343541",
  844. "orientation": "portrait-primary",
  845. "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
  846. }
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    # OpenSearch description document so browsers can register this UI as a
    # search provider; {searchTerms} is the literal placeholder the spec uses.
    xml_content = rf"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
</OpenSearchDescription>
"""
    return Response(content=xml_content, media_type="application/xml")
  860. @app.get("/health")
  861. async def healthcheck():
  862. return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

if os.path.exists(FRONTEND_BUILD_DIR):
    # Some platforms register .js as text/plain, which breaks ES module loading.
    mimetypes.add_type("text/javascript", ".js")
    # Catch-all SPA mount: unknown paths fall back to index.html.
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
else:
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )