main.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096
  1. import uuid
  2. from contextlib import asynccontextmanager
  3. from authlib.integrations.starlette_client import OAuth
  4. from authlib.oidc.core import UserInfo
  5. from bs4 import BeautifulSoup
  6. import json
  7. import markdown
  8. import time
  9. import os
  10. import sys
  11. import logging
  12. import aiohttp
  13. import requests
  14. import mimetypes
  15. from fastapi import FastAPI, Request, Depends, status
  16. from fastapi.staticfiles import StaticFiles
  17. from fastapi.responses import JSONResponse
  18. from fastapi import HTTPException
  19. from fastapi.middleware.wsgi import WSGIMiddleware
  20. from fastapi.middleware.cors import CORSMiddleware
  21. from starlette.exceptions import HTTPException as StarletteHTTPException
  22. from starlette.middleware.base import BaseHTTPMiddleware
  23. from starlette.middleware.sessions import SessionMiddleware
  24. from starlette.responses import StreamingResponse, Response, RedirectResponse
  25. from apps.socket.main import app as socket_app
  26. from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
  27. from apps.openai.main import app as openai_app, get_all_models as get_openai_models
  28. from apps.audio.main import app as audio_app
  29. from apps.images.main import app as images_app
  30. from apps.rag.main import app as rag_app
  31. from apps.webui.main import app as webui_app
  32. import asyncio
  33. from pydantic import BaseModel
  34. from typing import List, Optional
  35. from apps.webui.models.auths import Auths
  36. from apps.webui.models.models import Models
  37. from apps.webui.models.users import Users
  38. from utils.misc import parse_duration
  39. from utils.utils import (
  40. get_admin_user,
  41. get_verified_user,
  42. get_current_user,
  43. get_http_authorization_cred,
  44. get_password_hash,
  45. create_token,
  46. )
  47. from apps.rag.utils import rag_messages
  48. from config import (
  49. CONFIG_DATA,
  50. WEBUI_NAME,
  51. WEBUI_URL,
  52. WEBUI_AUTH,
  53. ENV,
  54. VERSION,
  55. CHANGELOG,
  56. FRONTEND_BUILD_DIR,
  57. CACHE_DIR,
  58. STATIC_DIR,
  59. ENABLE_OPENAI_API,
  60. ENABLE_OLLAMA_API,
  61. ENABLE_MODEL_FILTER,
  62. MODEL_FILTER_LIST,
  63. GLOBAL_LOG_LEVEL,
  64. SRC_LOG_LEVELS,
  65. WEBHOOK_URL,
  66. ENABLE_ADMIN_EXPORT,
  67. AppConfig,
  68. WEBUI_BUILD_HASH,
  69. OAUTH_PROVIDERS,
  70. ENABLE_OAUTH_SIGNUP,
  71. OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
  72. WEBUI_SECRET_KEY,
  73. WEBUI_SESSION_COOKIE_SAME_SITE,
  74. WEBUI_SESSION_COOKIE_SECURE,
  75. )
  76. from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
  77. from utils.webhook import post_webhook
# Send log records to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
# Module logger; its level is taken from the "MAIN" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  81. class SPAStaticFiles(StaticFiles):
  82. async def get_response(self, path: str, scope):
  83. try:
  84. return await super().get_response(path, scope)
  85. except (HTTPException, StarletteHTTPException) as ex:
  86. if ex.status_code == 404:
  87. return await super().get_response("index.html", scope)
  88. else:
  89. raise ex
# Startup banner, printed to stdout when this module is imported. The
# "Commit:" line is left empty when WEBUI_BUILD_HASH is "dev-build".
# NOTE(review): the spacing of the ASCII art looks collapsed in this copy of
# the file -- compare against the canonical source before shipping.
print(
    rf"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context passed to ``FastAPI(lifespan=...)``.

    Currently a placeholder: it yields immediately and performs no work
    before or after the yield.
    """
    yield
# Root application. The interactive docs are only exposed when ENV == "dev";
# ReDoc is always disabled.
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)

# Settings object for this app, seeded from the values imported from `config`.
app.state.config = AppConfig()
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL

# Model registry keyed by model id; replaced wholesale by get_all_models().
app.state.MODELS = {}

# Origins handed to CORSMiddleware further down ("*" matches any origin).
origins = ["*"]

# Custom middleware to add security headers (currently disabled).
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
#     async def dispatch(self, request: Request, call_next):
#         response: Response = await call_next(request)
#         response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
#         response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
#         return response
# app.add_middleware(SecurityHeadersMiddleware)
  125. class RAGMiddleware(BaseHTTPMiddleware):
  126. async def dispatch(self, request: Request, call_next):
  127. return_citations = False
  128. if request.method == "POST" and (
  129. "/ollama/api/chat" in request.url.path
  130. or "/chat/completions" in request.url.path
  131. ):
  132. log.debug(f"request.url.path: {request.url.path}")
  133. # Read the original request body
  134. body = await request.body()
  135. # Decode body to string
  136. body_str = body.decode("utf-8")
  137. # Parse string to JSON
  138. data = json.loads(body_str) if body_str else {}
  139. return_citations = data.get("citations", False)
  140. if "citations" in data:
  141. del data["citations"]
  142. # Example: Add a new key-value pair or modify existing ones
  143. # data["modified"] = True # Example modification
  144. if "docs" in data:
  145. data = {**data}
  146. data["messages"], citations = rag_messages(
  147. docs=data["docs"],
  148. messages=data["messages"],
  149. template=rag_app.state.config.RAG_TEMPLATE,
  150. embedding_function=rag_app.state.EMBEDDING_FUNCTION,
  151. k=rag_app.state.config.TOP_K,
  152. reranking_function=rag_app.state.sentence_transformer_rf,
  153. r=rag_app.state.config.RELEVANCE_THRESHOLD,
  154. hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  155. )
  156. del data["docs"]
  157. log.debug(
  158. f"data['messages']: {data['messages']}, citations: {citations}"
  159. )
  160. modified_body_bytes = json.dumps(data).encode("utf-8")
  161. # Replace the request body with the modified one
  162. request._body = modified_body_bytes
  163. # Set custom header to ensure content-length matches new body length
  164. request.headers.__dict__["_list"] = [
  165. (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
  166. *[
  167. (k, v)
  168. for k, v in request.headers.raw
  169. if k.lower() != b"content-length"
  170. ],
  171. ]
  172. response = await call_next(request)
  173. if return_citations:
  174. # Inject the citations into the response
  175. if isinstance(response, StreamingResponse):
  176. # If it's a streaming response, inject it as SSE event or NDJSON line
  177. content_type = response.headers.get("Content-Type")
  178. if "text/event-stream" in content_type:
  179. return StreamingResponse(
  180. self.openai_stream_wrapper(response.body_iterator, citations),
  181. )
  182. if "application/x-ndjson" in content_type:
  183. return StreamingResponse(
  184. self.ollama_stream_wrapper(response.body_iterator, citations),
  185. )
  186. return response
  187. async def _receive(self, body: bytes):
  188. return {"type": "http.request", "body": body, "more_body": False}
  189. async def openai_stream_wrapper(self, original_generator, citations):
  190. yield f"data: {json.dumps({'citations': citations})}\n\n"
  191. async for data in original_generator:
  192. yield data
  193. async def ollama_stream_wrapper(self, original_generator, citations):
  194. yield f"{json.dumps({'citations': citations})}\n"
  195. async for data in original_generator:
  196. yield data
  197. app.add_middleware(RAGMiddleware)
  198. class PipelineMiddleware(BaseHTTPMiddleware):
  199. async def dispatch(self, request: Request, call_next):
  200. if request.method == "POST" and (
  201. "/ollama/api/chat" in request.url.path
  202. or "/chat/completions" in request.url.path
  203. ):
  204. log.debug(f"request.url.path: {request.url.path}")
  205. # Read the original request body
  206. body = await request.body()
  207. # Decode body to string
  208. body_str = body.decode("utf-8")
  209. # Parse string to JSON
  210. data = json.loads(body_str) if body_str else {}
  211. model_id = data["model"]
  212. filters = [
  213. model
  214. for model in app.state.MODELS.values()
  215. if "pipeline" in model
  216. and "type" in model["pipeline"]
  217. and model["pipeline"]["type"] == "filter"
  218. and (
  219. model["pipeline"]["pipelines"] == ["*"]
  220. or any(
  221. model_id == target_model_id
  222. for target_model_id in model["pipeline"]["pipelines"]
  223. )
  224. )
  225. ]
  226. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  227. user = None
  228. if len(sorted_filters) > 0:
  229. try:
  230. user = get_current_user(
  231. get_http_authorization_cred(
  232. request.headers.get("Authorization")
  233. )
  234. )
  235. user = {"id": user.id, "name": user.name, "role": user.role}
  236. except:
  237. pass
  238. model = app.state.MODELS[model_id]
  239. if "pipeline" in model:
  240. sorted_filters.append(model)
  241. for filter in sorted_filters:
  242. r = None
  243. try:
  244. urlIdx = filter["urlIdx"]
  245. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  246. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  247. if key != "":
  248. headers = {"Authorization": f"Bearer {key}"}
  249. r = requests.post(
  250. f"{url}/{filter['id']}/filter/inlet",
  251. headers=headers,
  252. json={
  253. "user": user,
  254. "body": data,
  255. },
  256. )
  257. r.raise_for_status()
  258. data = r.json()
  259. except Exception as e:
  260. # Handle connection error here
  261. print(f"Connection error: {e}")
  262. if r is not None:
  263. try:
  264. res = r.json()
  265. if "detail" in res:
  266. return JSONResponse(
  267. status_code=r.status_code,
  268. content=res,
  269. )
  270. except:
  271. pass
  272. else:
  273. pass
  274. if "pipeline" not in app.state.MODELS[model_id]:
  275. if "chat_id" in data:
  276. del data["chat_id"]
  277. if "title" in data:
  278. del data["title"]
  279. modified_body_bytes = json.dumps(data).encode("utf-8")
  280. # Replace the request body with the modified one
  281. request._body = modified_body_bytes
  282. # Set custom header to ensure content-length matches new body length
  283. request.headers.__dict__["_list"] = [
  284. (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
  285. *[
  286. (k, v)
  287. for k, v in request.headers.raw
  288. if k.lower() != b"content-length"
  289. ],
  290. ]
  291. response = await call_next(request)
  292. return response
  293. async def _receive(self, body: bytes):
  294. return {"type": "http.request", "body": body, "more_body": False}
  295. app.add_middleware(PipelineMiddleware)
# CORS: every origin in `origins`, with credentials, and all methods and
# headers allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  303. @app.middleware("http")
  304. async def check_url(request: Request, call_next):
  305. if len(app.state.MODELS) == 0:
  306. await get_all_models()
  307. else:
  308. pass
  309. start_time = int(time.time())
  310. response = await call_next(request)
  311. process_time = int(time.time()) - start_time
  312. response.headers["X-Process-Time"] = str(process_time)
  313. return response
  314. @app.middleware("http")
  315. async def update_embedding_function(request: Request, call_next):
  316. response = await call_next(request)
  317. if "/embedding/update" in request.url.path:
  318. webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  319. return response
# Mount each sub-application under its own URL prefix.
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
app.mount("/api/v1", webui_app)

# Give webui_app the same embedding function object that rag_app holds.
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  328. async def get_all_models():
  329. openai_models = []
  330. ollama_models = []
  331. if app.state.config.ENABLE_OPENAI_API:
  332. openai_models = await get_openai_models()
  333. openai_models = openai_models["data"]
  334. if app.state.config.ENABLE_OLLAMA_API:
  335. ollama_models = await get_ollama_models()
  336. ollama_models = [
  337. {
  338. "id": model["model"],
  339. "name": model["name"],
  340. "object": "model",
  341. "created": int(time.time()),
  342. "owned_by": "ollama",
  343. "ollama": model,
  344. }
  345. for model in ollama_models["models"]
  346. ]
  347. models = openai_models + ollama_models
  348. custom_models = Models.get_all_models()
  349. for custom_model in custom_models:
  350. if custom_model.base_model_id == None:
  351. for model in models:
  352. if (
  353. custom_model.id == model["id"]
  354. or custom_model.id == model["id"].split(":")[0]
  355. ):
  356. model["name"] = custom_model.name
  357. model["info"] = custom_model.model_dump()
  358. else:
  359. owned_by = "openai"
  360. for model in models:
  361. if (
  362. custom_model.base_model_id == model["id"]
  363. or custom_model.base_model_id == model["id"].split(":")[0]
  364. ):
  365. owned_by = model["owned_by"]
  366. break
  367. models.append(
  368. {
  369. "id": custom_model.id,
  370. "name": custom_model.name,
  371. "object": "model",
  372. "created": custom_model.created_at,
  373. "owned_by": owned_by,
  374. "info": custom_model.model_dump(),
  375. "preset": True,
  376. }
  377. )
  378. app.state.MODELS = {model["id"]: model for model in models}
  379. webui_app.state.MODELS = app.state.MODELS
  380. return models
  381. @app.get("/api/models")
  382. async def get_models(user=Depends(get_verified_user)):
  383. models = await get_all_models()
  384. # Filter out filter pipelines
  385. models = [
  386. model
  387. for model in models
  388. if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
  389. ]
  390. if app.state.config.ENABLE_MODEL_FILTER:
  391. if user.role == "user":
  392. models = list(
  393. filter(
  394. lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
  395. models,
  396. )
  397. )
  398. return {"data": models}
  399. return {"data": models}
  400. @app.post("/api/chat/completed")
  401. async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
  402. data = form_data
  403. model_id = data["model"]
  404. filters = [
  405. model
  406. for model in app.state.MODELS.values()
  407. if "pipeline" in model
  408. and "type" in model["pipeline"]
  409. and model["pipeline"]["type"] == "filter"
  410. and (
  411. model["pipeline"]["pipelines"] == ["*"]
  412. or any(
  413. model_id == target_model_id
  414. for target_model_id in model["pipeline"]["pipelines"]
  415. )
  416. )
  417. ]
  418. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  419. print(model_id)
  420. if model_id in app.state.MODELS:
  421. model = app.state.MODELS[model_id]
  422. if "pipeline" in model:
  423. sorted_filters = [model] + sorted_filters
  424. for filter in sorted_filters:
  425. r = None
  426. try:
  427. urlIdx = filter["urlIdx"]
  428. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  429. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  430. if key != "":
  431. headers = {"Authorization": f"Bearer {key}"}
  432. r = requests.post(
  433. f"{url}/{filter['id']}/filter/outlet",
  434. headers=headers,
  435. json={
  436. "user": {"id": user.id, "name": user.name, "role": user.role},
  437. "body": data,
  438. },
  439. )
  440. r.raise_for_status()
  441. data = r.json()
  442. except Exception as e:
  443. # Handle connection error here
  444. print(f"Connection error: {e}")
  445. if r is not None:
  446. try:
  447. res = r.json()
  448. if "detail" in res:
  449. return JSONResponse(
  450. status_code=r.status_code,
  451. content=res,
  452. )
  453. except:
  454. pass
  455. else:
  456. pass
  457. return data
  458. @app.get("/api/pipelines/list")
  459. async def get_pipelines_list(user=Depends(get_admin_user)):
  460. responses = await get_openai_models(raw=True)
  461. print(responses)
  462. urlIdxs = [
  463. idx
  464. for idx, response in enumerate(responses)
  465. if response != None and "pipelines" in response
  466. ]
  467. return {
  468. "data": [
  469. {
  470. "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  471. "idx": urlIdx,
  472. }
  473. for urlIdx in urlIdxs
  474. ]
  475. }
  476. class AddPipelineForm(BaseModel):
  477. url: str
  478. urlIdx: int
  479. @app.post("/api/pipelines/add")
  480. async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
  481. r = None
  482. try:
  483. urlIdx = form_data.urlIdx
  484. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  485. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  486. headers = {"Authorization": f"Bearer {key}"}
  487. r = requests.post(
  488. f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
  489. )
  490. r.raise_for_status()
  491. data = r.json()
  492. return {**data}
  493. except Exception as e:
  494. # Handle connection error here
  495. print(f"Connection error: {e}")
  496. detail = "Pipeline not found"
  497. if r is not None:
  498. try:
  499. res = r.json()
  500. if "detail" in res:
  501. detail = res["detail"]
  502. except:
  503. pass
  504. raise HTTPException(
  505. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  506. detail=detail,
  507. )
  508. class DeletePipelineForm(BaseModel):
  509. id: str
  510. urlIdx: int
  511. @app.delete("/api/pipelines/delete")
  512. async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
  513. r = None
  514. try:
  515. urlIdx = form_data.urlIdx
  516. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  517. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  518. headers = {"Authorization": f"Bearer {key}"}
  519. r = requests.delete(
  520. f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
  521. )
  522. r.raise_for_status()
  523. data = r.json()
  524. return {**data}
  525. except Exception as e:
  526. # Handle connection error here
  527. print(f"Connection error: {e}")
  528. detail = "Pipeline not found"
  529. if r is not None:
  530. try:
  531. res = r.json()
  532. if "detail" in res:
  533. detail = res["detail"]
  534. except:
  535. pass
  536. raise HTTPException(
  537. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  538. detail=detail,
  539. )
  540. @app.get("/api/pipelines")
  541. async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
  542. r = None
  543. try:
  544. urlIdx
  545. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  546. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  547. headers = {"Authorization": f"Bearer {key}"}
  548. r = requests.get(f"{url}/pipelines", headers=headers)
  549. r.raise_for_status()
  550. data = r.json()
  551. return {**data}
  552. except Exception as e:
  553. # Handle connection error here
  554. print(f"Connection error: {e}")
  555. detail = "Pipeline not found"
  556. if r is not None:
  557. try:
  558. res = r.json()
  559. if "detail" in res:
  560. detail = res["detail"]
  561. except:
  562. pass
  563. raise HTTPException(
  564. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  565. detail=detail,
  566. )
  567. @app.get("/api/pipelines/{pipeline_id}/valves")
  568. async def get_pipeline_valves(
  569. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  570. ):
  571. models = await get_all_models()
  572. r = None
  573. try:
  574. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  575. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  576. headers = {"Authorization": f"Bearer {key}"}
  577. r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)
  578. r.raise_for_status()
  579. data = r.json()
  580. return {**data}
  581. except Exception as e:
  582. # Handle connection error here
  583. print(f"Connection error: {e}")
  584. detail = "Pipeline not found"
  585. if r is not None:
  586. try:
  587. res = r.json()
  588. if "detail" in res:
  589. detail = res["detail"]
  590. except:
  591. pass
  592. raise HTTPException(
  593. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  594. detail=detail,
  595. )
  596. @app.get("/api/pipelines/{pipeline_id}/valves/spec")
  597. async def get_pipeline_valves_spec(
  598. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  599. ):
  600. models = await get_all_models()
  601. r = None
  602. try:
  603. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  604. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  605. headers = {"Authorization": f"Bearer {key}"}
  606. r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)
  607. r.raise_for_status()
  608. data = r.json()
  609. return {**data}
  610. except Exception as e:
  611. # Handle connection error here
  612. print(f"Connection error: {e}")
  613. detail = "Pipeline not found"
  614. if r is not None:
  615. try:
  616. res = r.json()
  617. if "detail" in res:
  618. detail = res["detail"]
  619. except:
  620. pass
  621. raise HTTPException(
  622. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  623. detail=detail,
  624. )
  625. @app.post("/api/pipelines/{pipeline_id}/valves/update")
  626. async def update_pipeline_valves(
  627. urlIdx: Optional[int],
  628. pipeline_id: str,
  629. form_data: dict,
  630. user=Depends(get_admin_user),
  631. ):
  632. models = await get_all_models()
  633. r = None
  634. try:
  635. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  636. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  637. headers = {"Authorization": f"Bearer {key}"}
  638. r = requests.post(
  639. f"{url}/{pipeline_id}/valves/update",
  640. headers=headers,
  641. json={**form_data},
  642. )
  643. r.raise_for_status()
  644. data = r.json()
  645. return {**data}
  646. except Exception as e:
  647. # Handle connection error here
  648. print(f"Connection error: {e}")
  649. detail = "Pipeline not found"
  650. if r is not None:
  651. try:
  652. res = r.json()
  653. if "detail" in res:
  654. detail = res["detail"]
  655. except:
  656. pass
  657. raise HTTPException(
  658. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  659. detail=detail,
  660. )
  661. @app.get("/api/config")
  662. async def get_app_config():
  663. # Checking and Handling the Absence of 'ui' in CONFIG_DATA
  664. default_locale = "en-US"
  665. if "ui" in CONFIG_DATA:
  666. default_locale = CONFIG_DATA["ui"].get("default_locale", "en-US")
  667. # The Rest of the Function Now Uses the Variables Defined Above
  668. return {
  669. "status": True,
  670. "name": WEBUI_NAME,
  671. "version": VERSION,
  672. "default_locale": default_locale,
  673. "default_models": webui_app.state.config.DEFAULT_MODELS,
  674. "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  675. "features": {
  676. "auth": WEBUI_AUTH,
  677. "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
  678. "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
  679. "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
  680. "enable_image_generation": images_app.state.config.ENABLED,
  681. "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
  682. "enable_admin_export": ENABLE_ADMIN_EXPORT,
  683. },
  684. "oauth": {
  685. "providers": {
  686. name: config.get("name", name)
  687. for name, config in OAUTH_PROVIDERS.items()
  688. }
  689. },
  690. }
  691. @app.get("/api/config/model/filter")
  692. async def get_model_filter_config(user=Depends(get_admin_user)):
  693. return {
  694. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  695. "models": app.state.config.MODEL_FILTER_LIST,
  696. }
class ModelFilterConfigForm(BaseModel):
    """Request body for POST /api/config/model/filter."""

    # Stored as app.state.config.ENABLE_MODEL_FILTER.
    enabled: bool
    # Stored as app.state.config.MODEL_FILTER_LIST.
    models: List[str]
  700. @app.post("/api/config/model/filter")
  701. async def update_model_filter_config(
  702. form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
  703. ):
  704. app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
  705. app.state.config.MODEL_FILTER_LIST = form_data.models
  706. return {
  707. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  708. "models": app.state.config.MODEL_FILTER_LIST,
  709. }
  710. @app.get("/api/webhook")
  711. async def get_webhook_url(user=Depends(get_admin_user)):
  712. return {
  713. "url": app.state.config.WEBHOOK_URL,
  714. }
class UrlForm(BaseModel):
    """Request body carrying a single URL (used by POST /api/webhook)."""

    url: str
  717. @app.post("/api/webhook")
  718. async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
  719. app.state.config.WEBHOOK_URL = form_data.url
  720. webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
  721. return {"url": app.state.config.WEBHOOK_URL}
# NOTE(review): this reuses the function name of the /api/config handler
# defined earlier in this file. Both routes are registered by their
# decorators, but the module-level name ends up bound to this function.
@app.get("/api/version")
async def get_app_config():
    """Return the running application version as ``{"version": VERSION}``."""
    return {
        "version": VERSION,
    }
  727. @app.get("/api/changelog")
  728. async def get_app_changelog():
  729. return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
  730. @app.get("/api/version/updates")
  731. async def get_app_latest_release_version():
  732. try:
  733. async with aiohttp.ClientSession() as session:
  734. async with session.get(
  735. "https://api.github.com/repos/open-webui/open-webui/releases/latest"
  736. ) as response:
  737. response.raise_for_status()
  738. data = await response.json()
  739. latest_version = data["tag_name"]
  740. return {"current": VERSION, "latest": latest_version[1:]}
  741. except aiohttp.ClientError as e:
  742. raise HTTPException(
  743. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  744. detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
  745. )
############################
# OAuth Login & Callback
############################

# One Authlib client is registered per entry in OAUTH_PROVIDERS, keyed by the
# provider name used in the /oauth/{provider}/... routes.
oauth = OAuth()

for provider_name, provider_config in OAUTH_PROVIDERS.items():
    oauth.register(
        name=provider_name,
        client_id=provider_config["client_id"],
        client_secret=provider_config["client_secret"],
        server_metadata_url=provider_config["server_metadata_url"],
        client_kwargs={
            "scope": provider_config["scope"],
        },
    )

# SessionMiddleware is used by authlib for oauth; it is only installed when at
# least one provider is configured.
if len(OAUTH_PROVIDERS) > 0:
    app.add_middleware(
        SessionMiddleware,
        secret_key=WEBUI_SECRET_KEY,
        session_cookie="oui-session",
        same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
        https_only=WEBUI_SESSION_COOKIE_SECURE,
    )
  769. @app.get("/oauth/{provider}/login")
  770. async def oauth_login(provider: str, request: Request):
  771. if provider not in OAUTH_PROVIDERS:
  772. raise HTTPException(404)
  773. redirect_uri = request.url_for("oauth_callback", provider=provider)
  774. return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
  775. @app.get("/oauth/{provider}/callback")
  776. async def oauth_callback(provider: str, request: Request):
  777. if provider not in OAUTH_PROVIDERS:
  778. raise HTTPException(404)
  779. client = oauth.create_client(provider)
  780. try:
  781. token = await client.authorize_access_token(request)
  782. except Exception as e:
  783. log.error(f"OAuth callback error: {e}")
  784. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  785. user_data: UserInfo = token["userinfo"]
  786. sub = user_data.get("sub")
  787. if not sub:
  788. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  789. provider_sub = f"{provider}@{sub}"
  790. # Check if the user exists
  791. user = Users.get_user_by_oauth_sub(provider_sub)
  792. if not user:
  793. # If the user does not exist, check if merging is enabled
  794. if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
  795. # Check if the user exists by email
  796. email = user_data.get("email", "").lower()
  797. if not email:
  798. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  799. user = Users.get_user_by_email(user_data.get("email", "").lower(), True)
  800. if user:
  801. # Update the user with the new oauth sub
  802. Users.update_user_oauth_sub_by_id(user.id, provider_sub)
  803. if not user:
  804. # If the user does not exist, check if signups are enabled
  805. if ENABLE_OAUTH_SIGNUP.value:
  806. user = Auths.insert_new_auth(
  807. email=user_data.get("email", "").lower(),
  808. password=get_password_hash(
  809. str(uuid.uuid4())
  810. ), # Random password, not used
  811. name=user_data.get("name", "User"),
  812. profile_image_url=user_data.get("picture", "/user.png"),
  813. role=webui_app.state.config.DEFAULT_USER_ROLE,
  814. oauth_sub=provider_sub,
  815. )
  816. if webui_app.state.config.WEBHOOK_URL:
  817. post_webhook(
  818. webui_app.state.config.WEBHOOK_URL,
  819. WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  820. {
  821. "action": "signup",
  822. "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
  823. "user": user.model_dump_json(exclude_none=True),
  824. },
  825. )
  826. else:
  827. raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
  828. jwt_token = create_token(
  829. data={"id": user.id},
  830. expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
  831. )
  832. # Redirect back to the frontend with the JWT token
  833. redirect_url = f"{request.base_url}auth#token={jwt_token}"
  834. return RedirectResponse(url=redirect_url)
  835. @app.get("/manifest.json")
  836. async def get_manifest_json():
  837. return {
  838. "name": WEBUI_NAME,
  839. "short_name": WEBUI_NAME,
  840. "start_url": "/",
  841. "display": "standalone",
  842. "background_color": "#343541",
  843. "theme_color": "#343541",
  844. "orientation": "portrait-primary",
  845. "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
  846. }
  847. @app.get("/opensearch.xml")
  848. async def get_opensearch_xml():
  849. xml_content = rf"""
  850. <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
  851. <ShortName>{WEBUI_NAME}</ShortName>
  852. <Description>Search {WEBUI_NAME}</Description>
  853. <InputEncoding>UTF-8</InputEncoding>
  854. <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
  855. <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
  856. <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
  857. </OpenSearchDescription>
  858. """
  859. return Response(content=xml_content, media_type="application/xml")
@app.get("/health")
async def healthcheck():
    """Health-check endpoint: always answers ``{"status": True}``."""
    return {"status": True}
# Static assets and cached files are served straight from disk.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

# The built frontend, when present, is mounted at "/" with SPAStaticFiles.
if os.path.exists(FRONTEND_BUILD_DIR):
    # Register ".js" as text/javascript before the static files are served.
    mimetypes.add_type("text/javascript", ".js")
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
else:
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )