main.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215
  1. from contextlib import asynccontextmanager
  2. from bs4 import BeautifulSoup
  3. import json
  4. import markdown
  5. import time
  6. import os
  7. import sys
  8. import logging
  9. import aiohttp
  10. import requests
  11. import mimetypes
  12. import shutil
  13. import os
  14. import asyncio
  15. from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
  16. from fastapi.staticfiles import StaticFiles
  17. from fastapi.responses import JSONResponse
  18. from fastapi import HTTPException
  19. from fastapi.middleware.wsgi import WSGIMiddleware
  20. from fastapi.middleware.cors import CORSMiddleware
  21. from starlette.exceptions import HTTPException as StarletteHTTPException
  22. from starlette.middleware.base import BaseHTTPMiddleware
  23. from starlette.responses import StreamingResponse, Response
  24. from apps.socket.main import app as socket_app
  25. from apps.ollama.main import (
  26. app as ollama_app,
  27. OpenAIChatCompletionForm,
  28. get_all_models as get_ollama_models,
  29. generate_openai_chat_completion as generate_ollama_chat_completion,
  30. )
  31. from apps.openai.main import (
  32. app as openai_app,
  33. get_all_models as get_openai_models,
  34. generate_chat_completion as generate_openai_chat_completion,
  35. )
  36. from apps.audio.main import app as audio_app
  37. from apps.images.main import app as images_app
  38. from apps.rag.main import app as rag_app
  39. from apps.webui.main import app as webui_app
  40. from pydantic import BaseModel
  41. from typing import List, Optional
  42. from apps.webui.models.models import Models, ModelModel
  43. from utils.utils import (
  44. get_admin_user,
  45. get_verified_user,
  46. get_current_user,
  47. get_http_authorization_cred,
  48. )
  49. from utils.task import title_generation_template, search_query_generation_template
  50. from apps.rag.utils import rag_messages
  51. from config import (
  52. CONFIG_DATA,
  53. WEBUI_NAME,
  54. WEBUI_URL,
  55. WEBUI_AUTH,
  56. ENV,
  57. VERSION,
  58. CHANGELOG,
  59. FRONTEND_BUILD_DIR,
  60. CACHE_DIR,
  61. STATIC_DIR,
  62. ENABLE_OPENAI_API,
  63. ENABLE_OLLAMA_API,
  64. ENABLE_MODEL_FILTER,
  65. MODEL_FILTER_LIST,
  66. GLOBAL_LOG_LEVEL,
  67. SRC_LOG_LEVELS,
  68. WEBHOOK_URL,
  69. ENABLE_ADMIN_EXPORT,
  70. WEBUI_BUILD_HASH,
  71. TASK_MODEL,
  72. TASK_MODEL_EXTERNAL,
  73. TITLE_GENERATION_PROMPT_TEMPLATE,
  74. SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
  75. AppConfig,
  76. )
  77. from constants import ERROR_MESSAGES
# Root logging goes to stdout at the globally configured level; this module's
# own logger is further scoped by the per-subsystem "MAIN" level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
  81. class SPAStaticFiles(StaticFiles):
  82. async def get_response(self, path: str, scope):
  83. try:
  84. return await super().get_response(path, scope)
  85. except (HTTPException, StarletteHTTPException) as ex:
  86. if ex.status_code == 404:
  87. return await super().get_response("index.html", scope)
  88. else:
  89. raise ex
# Startup banner, printed once at import time.
print(
    rf"""
  ___                    __        __   _     _   _ ___
 / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || |
 \___/| .__/ \___|_| |_|    \_/\_/ \___|_.__/ \___/|___|
      |_|

v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    # No startup/shutdown work yet; this placeholder keeps the hook wired so
    # initialization/teardown can be added without touching the FastAPI call.
    yield
# API docs are only exposed in dev mode.
app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)

# Seed the persistent runtime configuration from the imported defaults.
app.state.config = AppConfig()

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.WEBHOOK_URL = WEBHOOK_URL

app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)

# id -> model dict; populated lazily by get_all_models() on first request.
app.state.MODELS = {}

# CORS: allow every origin.
origins = ["*"]
  123. # Custom middleware to add security headers
  124. # class SecurityHeadersMiddleware(BaseHTTPMiddleware):
  125. # async def dispatch(self, request: Request, call_next):
  126. # response: Response = await call_next(request)
  127. # response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
  128. # response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
  129. # return response
  130. # app.add_middleware(SecurityHeadersMiddleware)
  131. class RAGMiddleware(BaseHTTPMiddleware):
  132. async def dispatch(self, request: Request, call_next):
  133. return_citations = False
  134. if request.method == "POST" and (
  135. "/ollama/api/chat" in request.url.path
  136. or "/chat/completions" in request.url.path
  137. ):
  138. log.debug(f"request.url.path: {request.url.path}")
  139. # Read the original request body
  140. body = await request.body()
  141. # Decode body to string
  142. body_str = body.decode("utf-8")
  143. # Parse string to JSON
  144. data = json.loads(body_str) if body_str else {}
  145. return_citations = data.get("citations", False)
  146. if "citations" in data:
  147. del data["citations"]
  148. # Example: Add a new key-value pair or modify existing ones
  149. # data["modified"] = True # Example modification
  150. if "docs" in data:
  151. data = {**data}
  152. data["messages"], citations = rag_messages(
  153. docs=data["docs"],
  154. messages=data["messages"],
  155. template=rag_app.state.config.RAG_TEMPLATE,
  156. embedding_function=rag_app.state.EMBEDDING_FUNCTION,
  157. k=rag_app.state.config.TOP_K,
  158. reranking_function=rag_app.state.sentence_transformer_rf,
  159. r=rag_app.state.config.RELEVANCE_THRESHOLD,
  160. hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
  161. )
  162. del data["docs"]
  163. log.debug(
  164. f"data['messages']: {data['messages']}, citations: {citations}"
  165. )
  166. modified_body_bytes = json.dumps(data).encode("utf-8")
  167. # Replace the request body with the modified one
  168. request._body = modified_body_bytes
  169. # Set custom header to ensure content-length matches new body length
  170. request.headers.__dict__["_list"] = [
  171. (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
  172. *[
  173. (k, v)
  174. for k, v in request.headers.raw
  175. if k.lower() != b"content-length"
  176. ],
  177. ]
  178. response = await call_next(request)
  179. if return_citations:
  180. # Inject the citations into the response
  181. if isinstance(response, StreamingResponse):
  182. # If it's a streaming response, inject it as SSE event or NDJSON line
  183. content_type = response.headers.get("Content-Type")
  184. if "text/event-stream" in content_type:
  185. return StreamingResponse(
  186. self.openai_stream_wrapper(response.body_iterator, citations),
  187. )
  188. if "application/x-ndjson" in content_type:
  189. return StreamingResponse(
  190. self.ollama_stream_wrapper(response.body_iterator, citations),
  191. )
  192. return response
  193. async def _receive(self, body: bytes):
  194. return {"type": "http.request", "body": body, "more_body": False}
  195. async def openai_stream_wrapper(self, original_generator, citations):
  196. yield f"data: {json.dumps({'citations': citations})}\n\n"
  197. async for data in original_generator:
  198. yield data
  199. async def ollama_stream_wrapper(self, original_generator, citations):
  200. yield f"{json.dumps({'citations': citations})}\n"
  201. async for data in original_generator:
  202. yield data
  203. app.add_middleware(RAGMiddleware)
  204. def filter_pipeline(payload, user):
  205. user = {"id": user.id, "name": user.name, "role": user.role}
  206. model_id = payload["model"]
  207. filters = [
  208. model
  209. for model in app.state.MODELS.values()
  210. if "pipeline" in model
  211. and "type" in model["pipeline"]
  212. and model["pipeline"]["type"] == "filter"
  213. and (
  214. model["pipeline"]["pipelines"] == ["*"]
  215. or any(
  216. model_id == target_model_id
  217. for target_model_id in model["pipeline"]["pipelines"]
  218. )
  219. )
  220. ]
  221. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  222. model = app.state.MODELS[model_id]
  223. if "pipeline" in model:
  224. sorted_filters.append(model)
  225. for filter in sorted_filters:
  226. r = None
  227. try:
  228. urlIdx = filter["urlIdx"]
  229. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  230. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  231. if key != "":
  232. headers = {"Authorization": f"Bearer {key}"}
  233. r = requests.post(
  234. f"{url}/{filter['id']}/filter/inlet",
  235. headers=headers,
  236. json={
  237. "user": user,
  238. "body": payload,
  239. },
  240. )
  241. r.raise_for_status()
  242. payload = r.json()
  243. except Exception as e:
  244. # Handle connection error here
  245. print(f"Connection error: {e}")
  246. if r is not None:
  247. try:
  248. res = r.json()
  249. if "detail" in res:
  250. return JSONResponse(
  251. status_code=r.status_code,
  252. content=res,
  253. )
  254. except:
  255. pass
  256. else:
  257. pass
  258. if "pipeline" not in app.state.MODELS[model_id]:
  259. if "chat_id" in payload:
  260. del payload["chat_id"]
  261. if "title" in payload:
  262. del payload["title"]
  263. return payload
class PipelineMiddleware(BaseHTTPMiddleware):
    # Runs every chat-completion request body through filter_pipeline()
    # before it reaches the mounted ollama/openai sub-apps.
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            # Resolve the requesting user from the Authorization header so
            # pipeline filters receive the caller's identity.
            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )

            data = filter_pipeline(data, user)

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
            request._body = modified_body_bytes
            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        # Starlette receive-channel stub returning the (possibly modified) body.
        return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(PipelineMiddleware)

# CORS is wide open by default (origins = ["*"]); tighten via deployment config.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  305. @app.middleware("http")
  306. async def check_url(request: Request, call_next):
  307. if len(app.state.MODELS) == 0:
  308. await get_all_models()
  309. else:
  310. pass
  311. start_time = int(time.time())
  312. response = await call_next(request)
  313. process_time = int(time.time()) - start_time
  314. response.headers["X-Process-Time"] = str(process_time)
  315. return response
  316. @app.middleware("http")
  317. async def update_embedding_function(request: Request, call_next):
  318. response = await call_next(request)
  319. if "/embedding/update" in request.url.path:
  320. webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  321. return response
# Mount the sub-applications under their URL prefixes.
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

app.mount("/api/v1", webui_app)

# Initial sync; kept up to date by the update_embedding_function middleware.
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
  330. async def get_all_models():
  331. openai_models = []
  332. ollama_models = []
  333. if app.state.config.ENABLE_OPENAI_API:
  334. openai_models = await get_openai_models()
  335. openai_models = openai_models["data"]
  336. if app.state.config.ENABLE_OLLAMA_API:
  337. ollama_models = await get_ollama_models()
  338. ollama_models = [
  339. {
  340. "id": model["model"],
  341. "name": model["name"],
  342. "object": "model",
  343. "created": int(time.time()),
  344. "owned_by": "ollama",
  345. "ollama": model,
  346. }
  347. for model in ollama_models["models"]
  348. ]
  349. models = openai_models + ollama_models
  350. custom_models = Models.get_all_models()
  351. for custom_model in custom_models:
  352. if custom_model.base_model_id == None:
  353. for model in models:
  354. if (
  355. custom_model.id == model["id"]
  356. or custom_model.id == model["id"].split(":")[0]
  357. ):
  358. model["name"] = custom_model.name
  359. model["info"] = custom_model.model_dump()
  360. else:
  361. owned_by = "openai"
  362. for model in models:
  363. if (
  364. custom_model.base_model_id == model["id"]
  365. or custom_model.base_model_id == model["id"].split(":")[0]
  366. ):
  367. owned_by = model["owned_by"]
  368. break
  369. models.append(
  370. {
  371. "id": custom_model.id,
  372. "name": custom_model.name,
  373. "object": "model",
  374. "created": custom_model.created_at,
  375. "owned_by": owned_by,
  376. "info": custom_model.model_dump(),
  377. "preset": True,
  378. }
  379. )
  380. app.state.MODELS = {model["id"]: model for model in models}
  381. webui_app.state.MODELS = app.state.MODELS
  382. return models
  383. @app.get("/api/models")
  384. async def get_models(user=Depends(get_verified_user)):
  385. models = await get_all_models()
  386. # Filter out filter pipelines
  387. models = [
  388. model
  389. for model in models
  390. if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
  391. ]
  392. if app.state.config.ENABLE_MODEL_FILTER:
  393. if user.role == "user":
  394. models = list(
  395. filter(
  396. lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
  397. models,
  398. )
  399. )
  400. return {"data": models}
  401. return {"data": models}
  402. @app.get("/api/task/config")
  403. async def get_task_config(user=Depends(get_verified_user)):
  404. return {
  405. "TASK_MODEL": app.state.config.TASK_MODEL,
  406. "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
  407. "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  408. "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
  409. }
class TaskConfigForm(BaseModel):
    # Payload for /api/task/config/update; field names mirror the config keys.
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
  415. @app.post("/api/task/config/update")
  416. async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
  417. app.state.config.TASK_MODEL = form_data.TASK_MODEL
  418. app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
  419. app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
  420. form_data.TITLE_GENERATION_PROMPT_TEMPLATE
  421. )
  422. app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
  423. form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
  424. )
  425. return {
  426. "TASK_MODEL": app.state.config.TASK_MODEL,
  427. "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
  428. "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
  429. "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
  430. }
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    """Generate a short chat title via the configured task model.

    The requested model is swapped for TASK_MODEL (Ollama-owned models) or
    TASK_MODEL_EXTERNAL (others) when that task model is registered.
    """
    print("generate_title")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template, form_data["prompt"], user.model_dump()
    )

    # chat_id/title are pipeline bookkeeping; filter_pipeline strips them
    # again for non-pipeline models.
    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
        "title": True,
    }

    print(payload)
    payload = filter_pipeline(payload, user)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**payload), user=user
        )
    else:
        return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
    """Generate a web-search query from a prompt via the configured task model.

    Same task-model substitution rules as generate_title.
    """
    print("generate_search_query")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE

    content = search_query_generation_template(
        template, form_data["prompt"], user.model_dump()
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 30,
    }

    print(payload)
    payload = filter_pipeline(payload, user)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**payload), user=user
        )
    else:
        return await generate_openai_chat_completion(payload, user=user)
  515. @app.post("/api/chat/completions")
  516. async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
  517. model_id = form_data["model"]
  518. if model_id not in app.state.MODELS:
  519. raise HTTPException(
  520. status_code=status.HTTP_404_NOT_FOUND,
  521. detail="Model not found",
  522. )
  523. model = app.state.MODELS[model_id]
  524. print(model)
  525. if model["owned_by"] == "ollama":
  526. return await generate_ollama_chat_completion(
  527. OpenAIChatCompletionForm(**form_data), user=user
  528. )
  529. else:
  530. return await generate_openai_chat_completion(form_data, user=user)
  531. @app.post("/api/chat/completed")
  532. async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
  533. data = form_data
  534. model_id = data["model"]
  535. filters = [
  536. model
  537. for model in app.state.MODELS.values()
  538. if "pipeline" in model
  539. and "type" in model["pipeline"]
  540. and model["pipeline"]["type"] == "filter"
  541. and (
  542. model["pipeline"]["pipelines"] == ["*"]
  543. or any(
  544. model_id == target_model_id
  545. for target_model_id in model["pipeline"]["pipelines"]
  546. )
  547. )
  548. ]
  549. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  550. print(model_id)
  551. if model_id in app.state.MODELS:
  552. model = app.state.MODELS[model_id]
  553. if "pipeline" in model:
  554. sorted_filters = [model] + sorted_filters
  555. for filter in sorted_filters:
  556. r = None
  557. try:
  558. urlIdx = filter["urlIdx"]
  559. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  560. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  561. if key != "":
  562. headers = {"Authorization": f"Bearer {key}"}
  563. r = requests.post(
  564. f"{url}/{filter['id']}/filter/outlet",
  565. headers=headers,
  566. json={
  567. "user": {"id": user.id, "name": user.name, "role": user.role},
  568. "body": data,
  569. },
  570. )
  571. r.raise_for_status()
  572. data = r.json()
  573. except Exception as e:
  574. # Handle connection error here
  575. print(f"Connection error: {e}")
  576. if r is not None:
  577. try:
  578. res = r.json()
  579. if "detail" in res:
  580. return JSONResponse(
  581. status_code=r.status_code,
  582. content=res,
  583. )
  584. except:
  585. pass
  586. else:
  587. pass
  588. return data
  589. @app.get("/api/pipelines/list")
  590. async def get_pipelines_list(user=Depends(get_admin_user)):
  591. responses = await get_openai_models(raw=True)
  592. print(responses)
  593. urlIdxs = [
  594. idx
  595. for idx, response in enumerate(responses)
  596. if response != None and "pipelines" in response
  597. ]
  598. return {
  599. "data": [
  600. {
  601. "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  602. "idx": urlIdx,
  603. }
  604. for urlIdx in urlIdxs
  605. ]
  606. }
  607. @app.post("/api/pipelines/upload")
  608. async def upload_pipeline(
  609. urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
  610. ):
  611. print("upload_pipeline", urlIdx, file.filename)
  612. # Check if the uploaded file is a python file
  613. if not file.filename.endswith(".py"):
  614. raise HTTPException(
  615. status_code=status.HTTP_400_BAD_REQUEST,
  616. detail="Only Python (.py) files are allowed.",
  617. )
  618. upload_folder = f"{CACHE_DIR}/pipelines"
  619. os.makedirs(upload_folder, exist_ok=True)
  620. file_path = os.path.join(upload_folder, file.filename)
  621. try:
  622. # Save the uploaded file
  623. with open(file_path, "wb") as buffer:
  624. shutil.copyfileobj(file.file, buffer)
  625. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  626. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  627. headers = {"Authorization": f"Bearer {key}"}
  628. with open(file_path, "rb") as f:
  629. files = {"file": f}
  630. r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
  631. r.raise_for_status()
  632. data = r.json()
  633. return {**data}
  634. except Exception as e:
  635. # Handle connection error here
  636. print(f"Connection error: {e}")
  637. detail = "Pipeline not found"
  638. if r is not None:
  639. try:
  640. res = r.json()
  641. if "detail" in res:
  642. detail = res["detail"]
  643. except:
  644. pass
  645. raise HTTPException(
  646. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  647. detail=detail,
  648. )
  649. finally:
  650. # Ensure the file is deleted after the upload is completed or on failure
  651. if os.path.exists(file_path):
  652. os.remove(file_path)
class AddPipelineForm(BaseModel):
    # Install the pipeline at `url` on the pipelines server at index `urlIdx`.
    url: str
    urlIdx: int
  656. @app.post("/api/pipelines/add")
  657. async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
  658. r = None
  659. try:
  660. urlIdx = form_data.urlIdx
  661. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  662. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  663. headers = {"Authorization": f"Bearer {key}"}
  664. r = requests.post(
  665. f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
  666. )
  667. r.raise_for_status()
  668. data = r.json()
  669. return {**data}
  670. except Exception as e:
  671. # Handle connection error here
  672. print(f"Connection error: {e}")
  673. detail = "Pipeline not found"
  674. if r is not None:
  675. try:
  676. res = r.json()
  677. if "detail" in res:
  678. detail = res["detail"]
  679. except:
  680. pass
  681. raise HTTPException(
  682. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  683. detail=detail,
  684. )
class DeletePipelineForm(BaseModel):
    # Delete pipeline `id` from the pipelines server at index `urlIdx`.
    id: str
    urlIdx: int
  688. @app.delete("/api/pipelines/delete")
  689. async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
  690. r = None
  691. try:
  692. urlIdx = form_data.urlIdx
  693. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  694. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  695. headers = {"Authorization": f"Bearer {key}"}
  696. r = requests.delete(
  697. f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
  698. )
  699. r.raise_for_status()
  700. data = r.json()
  701. return {**data}
  702. except Exception as e:
  703. # Handle connection error here
  704. print(f"Connection error: {e}")
  705. detail = "Pipeline not found"
  706. if r is not None:
  707. try:
  708. res = r.json()
  709. if "detail" in res:
  710. detail = res["detail"]
  711. except:
  712. pass
  713. raise HTTPException(
  714. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  715. detail=detail,
  716. )
  717. @app.get("/api/pipelines")
  718. async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
  719. r = None
  720. try:
  721. urlIdx
  722. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  723. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  724. headers = {"Authorization": f"Bearer {key}"}
  725. r = requests.get(f"{url}/pipelines", headers=headers)
  726. r.raise_for_status()
  727. data = r.json()
  728. return {**data}
  729. except Exception as e:
  730. # Handle connection error here
  731. print(f"Connection error: {e}")
  732. detail = "Pipeline not found"
  733. if r is not None:
  734. try:
  735. res = r.json()
  736. if "detail" in res:
  737. detail = res["detail"]
  738. except:
  739. pass
  740. raise HTTPException(
  741. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  742. detail=detail,
  743. )
  744. @app.get("/api/pipelines/{pipeline_id}/valves")
  745. async def get_pipeline_valves(
  746. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  747. ):
  748. models = await get_all_models()
  749. r = None
  750. try:
  751. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  752. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  753. headers = {"Authorization": f"Bearer {key}"}
  754. r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)
  755. r.raise_for_status()
  756. data = r.json()
  757. return {**data}
  758. except Exception as e:
  759. # Handle connection error here
  760. print(f"Connection error: {e}")
  761. detail = "Pipeline not found"
  762. if r is not None:
  763. try:
  764. res = r.json()
  765. if "detail" in res:
  766. detail = res["detail"]
  767. except:
  768. pass
  769. raise HTTPException(
  770. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  771. detail=detail,
  772. )
  773. @app.get("/api/pipelines/{pipeline_id}/valves/spec")
  774. async def get_pipeline_valves_spec(
  775. urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
  776. ):
  777. models = await get_all_models()
  778. r = None
  779. try:
  780. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  781. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  782. headers = {"Authorization": f"Bearer {key}"}
  783. r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)
  784. r.raise_for_status()
  785. data = r.json()
  786. return {**data}
  787. except Exception as e:
  788. # Handle connection error here
  789. print(f"Connection error: {e}")
  790. detail = "Pipeline not found"
  791. if r is not None:
  792. try:
  793. res = r.json()
  794. if "detail" in res:
  795. detail = res["detail"]
  796. except:
  797. pass
  798. raise HTTPException(
  799. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  800. detail=detail,
  801. )
  802. @app.post("/api/pipelines/{pipeline_id}/valves/update")
  803. async def update_pipeline_valves(
  804. urlIdx: Optional[int],
  805. pipeline_id: str,
  806. form_data: dict,
  807. user=Depends(get_admin_user),
  808. ):
  809. models = await get_all_models()
  810. r = None
  811. try:
  812. url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  813. key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
  814. headers = {"Authorization": f"Bearer {key}"}
  815. r = requests.post(
  816. f"{url}/{pipeline_id}/valves/update",
  817. headers=headers,
  818. json={**form_data},
  819. )
  820. r.raise_for_status()
  821. data = r.json()
  822. return {**data}
  823. except Exception as e:
  824. # Handle connection error here
  825. print(f"Connection error: {e}")
  826. detail = "Pipeline not found"
  827. if r is not None:
  828. try:
  829. res = r.json()
  830. if "detail" in res:
  831. detail = res["detail"]
  832. except:
  833. pass
  834. raise HTTPException(
  835. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  836. detail=detail,
  837. )
  838. @app.get("/api/config")
  839. async def get_app_config():
  840. # Checking and Handling the Absence of 'ui' in CONFIG_DATA
  841. default_locale = "en-US"
  842. if "ui" in CONFIG_DATA:
  843. default_locale = CONFIG_DATA["ui"].get("default_locale", "en-US")
  844. # The Rest of the Function Now Uses the Variables Defined Above
  845. return {
  846. "status": True,
  847. "name": WEBUI_NAME,
  848. "version": VERSION,
  849. "default_locale": default_locale,
  850. "default_models": webui_app.state.config.DEFAULT_MODELS,
  851. "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
  852. "features": {
  853. "auth": WEBUI_AUTH,
  854. "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
  855. "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
  856. "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
  857. "enable_image_generation": images_app.state.config.ENABLED,
  858. "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
  859. "enable_admin_export": ENABLE_ADMIN_EXPORT,
  860. },
  861. "audio": {
  862. "tts": {
  863. "engine": audio_app.state.config.TTS_ENGINE,
  864. "voice": audio_app.state.config.TTS_VOICE,
  865. },
  866. "stt": {
  867. "engine": audio_app.state.config.STT_ENGINE,
  868. },
  869. },
  870. }
  871. @app.get("/api/config/model/filter")
  872. async def get_model_filter_config(user=Depends(get_admin_user)):
  873. return {
  874. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  875. "models": app.state.config.MODEL_FILTER_LIST,
  876. }
class ModelFilterConfigForm(BaseModel):
    """Request body for POST /api/config/model/filter."""

    # Whether the model filter is active.
    enabled: bool
    # Model IDs the filter applies to.
    models: List[str]
  880. @app.post("/api/config/model/filter")
  881. async def update_model_filter_config(
  882. form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
  883. ):
  884. app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
  885. app.state.config.MODEL_FILTER_LIST = form_data.models
  886. return {
  887. "enabled": app.state.config.ENABLE_MODEL_FILTER,
  888. "models": app.state.config.MODEL_FILTER_LIST,
  889. }
  890. @app.get("/api/webhook")
  891. async def get_webhook_url(user=Depends(get_admin_user)):
  892. return {
  893. "url": app.state.config.WEBHOOK_URL,
  894. }
class UrlForm(BaseModel):
    """Request body carrying a single URL (used by POST /api/webhook)."""

    url: str
  897. @app.post("/api/webhook")
  898. async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
  899. app.state.config.WEBHOOK_URL = form_data.url
  900. webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
  901. return {"url": app.state.config.WEBHOOK_URL}
  902. @app.get("/api/version")
  903. async def get_app_config():
  904. return {
  905. "version": VERSION,
  906. }
  907. @app.get("/api/changelog")
  908. async def get_app_changelog():
  909. return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}
  910. @app.get("/api/version/updates")
  911. async def get_app_latest_release_version():
  912. try:
  913. async with aiohttp.ClientSession(trust_env=True) as session:
  914. async with session.get(
  915. "https://api.github.com/repos/open-webui/open-webui/releases/latest"
  916. ) as response:
  917. response.raise_for_status()
  918. data = await response.json()
  919. latest_version = data["tag_name"]
  920. return {"current": VERSION, "latest": latest_version[1:]}
  921. except aiohttp.ClientError as e:
  922. raise HTTPException(
  923. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  924. detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
  925. )
  926. @app.get("/manifest.json")
  927. async def get_manifest_json():
  928. return {
  929. "name": WEBUI_NAME,
  930. "short_name": WEBUI_NAME,
  931. "start_url": "/",
  932. "display": "standalone",
  933. "background_color": "#343541",
  934. "theme_color": "#343541",
  935. "orientation": "portrait-primary",
  936. "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
  937. }
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    """Serve an OpenSearch description document so browsers can register
    this instance as a search provider."""
    # rf-string: f-interpolation for WEBUI_NAME/WEBUI_URL; the inner
    # {"{searchTerms}"} expression emits the literal OpenSearch placeholder.
    xml_content = rf"""
    <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
    <ShortName>{WEBUI_NAME}</ShortName>
    <Description>Search {WEBUI_NAME}</Description>
    <InputEncoding>UTF-8</InputEncoding>
    <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
    <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
    <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
    </OpenSearchDescription>
    """
    return Response(content=xml_content, media_type="application/xml")
  951. @app.get("/health")
  952. async def healthcheck():
  953. return {"status": True}
# Static assets bundled with the server (logo, fonts, etc.).
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# Generated/cached files served from the cache directory.
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

if os.path.exists(FRONTEND_BUILD_DIR):
    # Ensure .js is served as text/javascript — some platforms' mimetype
    # registries misreport it, which breaks ES-module loading.
    mimetypes.add_type("text/javascript", ".js")
    # NOTE(review): mounted at "/" after the API routes; earlier-registered
    # routes appear to take precedence — confirm against Starlette routing.
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
else:
    # No frontend build present; run as an API-only server.
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )