from contextlib import asynccontextmanager
from bs4 import BeautifulSoup
import json
import markdown
import time
import os
import sys
import logging
import aiohttp
import requests
import mimetypes
import shutil
import uuid
import inspect
import asyncio
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response

from apps.socket.main import app as socket_app
from apps.ollama.main import (
    app as ollama_app,
    OpenAIChatCompletionForm,
    get_all_models as get_ollama_models,
    generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.openai.main import (
    app as openai_app,
    get_all_models as get_openai_models,
    generate_chat_completion as generate_openai_chat_completion,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app

from pydantic import BaseModel
from typing import List, Optional

from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id

from utils.utils import (
    get_admin_user,
    get_verified_user,
    get_current_user,
    get_http_authorization_cred,
)
from utils.task import (
    title_generation_template,
    search_query_generation_template,
    tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message

from apps.rag.utils import get_rag_context, rag_template

from config import (
    CONFIG_DATA,
    WEBUI_NAME,
    WEBUI_URL,
    WEBUI_AUTH,
    ENV,
    VERSION,
    CHANGELOG,
    FRONTEND_BUILD_DIR,
    UPLOAD_DIR,
    CACHE_DIR,
    STATIC_DIR,
    ENABLE_OPENAI_API,
    ENABLE_OLLAMA_API,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    GLOBAL_LOG_LEVEL,
    SRC_LOG_LEVELS,
    WEBHOOK_URL,
    ENABLE_ADMIN_EXPORT,
    WEBUI_BUILD_HASH,
    TASK_MODEL,
    TASK_MODEL_EXTERNAL,
    TITLE_GENERATION_PROMPT_TEMPLATE,
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    AppConfig,
)
from constants import ERROR_MESSAGES

logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
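
# Any 404 from the static mount below falls back to index.html so the
# frontend's client-side router can resolve the path.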
class SPAStaticFiles(StaticFiles):
    async def get_response(self, path: str, scope):
        try:
            return await super().get_response(path, scope)
        except (HTTPException, StarletteHTTPException) as ex:
            if ex.status_code == 404:
                return await super().get_response("index.html", scope)
            else:
                raise ex

print(
    rf"""
  ___                    __        __   _     _   _ ___
 / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) |  __/ | | |   \ V  V /  __/ |_) | |_| || |
 \___/| .__/ \___|_| |_|    \_/\_/ \___|_.__/ \___/|___|
      |_|


v{VERSION} - building the best open-source AI user interface.
{f"Commit: {WEBUI_BUILD_HASH}" if WEBUI_BUILD_HASH != "dev-build" else ""}
https://github.com/open-webui/open-webui
"""
)

@asynccontextmanager
async def lifespan(app: FastAPI):
    yield


app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)
app.state.config = AppConfig()

app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API

app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.WEBHOOK_URL = WEBHOOK_URL

app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)

app.state.MODELS = {}

origins = ["*"]
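
# Tool/function-calling flow: render the tool's JSON specs into the
# function-calling prompt template, ask the task model which function to
# invoke, then call the matching method on the toolkit module. The reserved
# parameters (__user__, __messages__, __files__) are injected only when the
# function signature declares them. Returns (result, file_handler) on
# success, (None, False) otherwise.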
async def get_function_call_response(
    messages, files, tool_id, template, task_model_id, user
):
    tool = Tools.get_tool_by_id(tool_id)
    tools_specs = json.dumps(tool.specs, indent=2)
    content = tools_function_calling_generation_template(template, tools_specs)

    user_message = get_last_user_message(messages)
    prompt = (
        "History:\n"
        + "\n".join(
            [
                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
                for message in messages[::-1][:4]
            ]
        )
        + f"\nQuery: {user_message}"
    )

    print(prompt)

    payload = {
        "model": task_model_id,
        "messages": [
            {"role": "system", "content": content},
            {"role": "user", "content": f"Query: {prompt}"},
        ],
        "stream": False,
    }

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        raise e

    model = app.state.MODELS[task_model_id]

    response = None
    try:
        if model["owned_by"] == "ollama":
            response = await generate_ollama_chat_completion(payload, user=user)
        else:
            response = await generate_openai_chat_completion(payload, user=user)

        content = None

        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]

        # Parse the function response
        if content is not None:
            print(f"content: {content}")
            result = json.loads(content)
            print(result)

            # Call the function
            if "name" in result:
                if tool_id in webui_app.state.TOOLS:
                    toolkit_module = webui_app.state.TOOLS[tool_id]
                else:
                    toolkit_module = load_toolkit_module_by_id(tool_id)
                    webui_app.state.TOOLS[tool_id] = toolkit_module

                file_handler = False
                # check if toolkit_module has file_handler self variable
                if hasattr(toolkit_module, "file_handler"):
                    file_handler = True
                    print("file_handler: ", file_handler)

                function = getattr(toolkit_module, result["name"])
                function_result = None
                try:
                    # Get the signature of the function
                    sig = inspect.signature(function)
                    params = result["parameters"]

                    if "__user__" in sig.parameters:
                        # Call the function with the '__user__' parameter included
                        params = {
                            **params,
                            "__user__": {
                                "id": user.id,
                                "email": user.email,
                                "name": user.name,
                                "role": user.role,
                            },
                        }

                    if "__messages__" in sig.parameters:
                        # Call the function with the '__messages__' parameter included
                        params = {
                            **params,
                            "__messages__": messages,
                        }

                    if "__files__" in sig.parameters:
                        # Call the function with the '__files__' parameter included
                        params = {
                            **params,
                            "__files__": files,
                        }

                    function_result = function(**params)
                except Exception as e:
                    print(e)

                # Add the function result to the system prompt
                if function_result is not None:
                    return function_result, file_handler
    except Exception as e:
        print(f"Error: {e}")

    return None, False
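
# Intercepts chat-completion POSTs before they reach the ollama/openai
# sub-apps: runs any requested tools, gathers RAG context for attached
# files, folds both into the system message, and rewrites the request body
# (and its content-length header) in place. When the client asked for
# citations, they are re-injected as the first SSE event or NDJSON line of
# the streamed response.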
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        return_citations = False
        citations = []  # ensure defined even when no files are attached

        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )

            # Remove the citations from the body
            return_citations = data.get("citations", False)
            if "citations" in data:
                del data["citations"]

            # Set the task model
            task_model_id = data["model"]
            if task_model_id not in app.state.MODELS:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Model not found",
                )

            # Check if the user has a custom task model
            # If the user has a custom task model, use that model
            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
                if (
                    app.state.config.TASK_MODEL
                    and app.state.config.TASK_MODEL in app.state.MODELS
                ):
                    task_model_id = app.state.config.TASK_MODEL
            else:
                if (
                    app.state.config.TASK_MODEL_EXTERNAL
                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
                ):
                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL

            prompt = get_last_user_message(data["messages"])
            context = ""

            # If tool_ids field is present, call the functions
            skip_files = False
            if "tool_ids" in data:
                print(data["tool_ids"])
                for tool_id in data["tool_ids"]:
                    print(tool_id)
                    try:
                        response, file_handler = await get_function_call_response(
                            messages=data["messages"],
                            files=data.get("files", []),
                            tool_id=tool_id,
                            template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                            task_model_id=task_model_id,
                            user=user,
                        )

                        print(file_handler)
                        if isinstance(response, str):
                            context += ("\n" if context != "" else "") + response

                        if file_handler:
                            skip_files = True

                    except Exception as e:
                        print(f"Error: {e}")

                del data["tool_ids"]
                print(f"tool_context: {context}")

            # If files field is present, generate RAG completions
            # If skip_files is True, skip the RAG completions
            if "files" in data:
                if not skip_files:
                    data = {**data}
                    rag_context, citations = get_rag_context(
                        files=data["files"],
                        messages=data["messages"],
                        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
                        k=rag_app.state.config.TOP_K,
                        reranking_function=rag_app.state.sentence_transformer_rf,
                        r=rag_app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                    )
                    if rag_context:
                        context += ("\n" if context != "" else "") + rag_context

                    log.debug(f"rag_context: {rag_context}, citations: {citations}")
                else:
                    return_citations = False

                del data["files"]

            if context != "":
                system_prompt = rag_template(
                    rag_app.state.config.RAG_TEMPLATE, context, prompt
                )
                print(system_prompt)
                data["messages"] = add_or_update_system_message(
                    f"\n{system_prompt}", data["messages"]
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")

            # Replace the request body with the modified one
            request._body = modified_body_bytes

            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)

        if return_citations:
            # Inject the citations into the response
            if isinstance(response, StreamingResponse):
                # If it's a streaming response, inject it as SSE event or NDJSON line
                content_type = response.headers.get("Content-Type") or ""
                if "text/event-stream" in content_type:
                    return StreamingResponse(
                        self.openai_stream_wrapper(response.body_iterator, citations),
                    )
                if "application/x-ndjson" in content_type:
                    return StreamingResponse(
                        self.ollama_stream_wrapper(response.body_iterator, citations),
                    )

        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}

    async def openai_stream_wrapper(self, original_generator, citations):
        yield f"data: {json.dumps({'citations': citations})}\n\n"
        async for data in original_generator:
            yield data

    async def ollama_stream_wrapper(self, original_generator, citations):
        yield f"{json.dumps({'citations': citations})}\n"
        async for data in original_generator:
            yield data

app.add_middleware(ChatCompletionMiddleware)
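
# Runs the payload through every "filter"-type pipeline that targets this
# model (lowest priority value first, then the model's own pipeline) via
# each pipeline's /{id}/filter/inlet endpoint. Internal bookkeeping keys
# (chat_id, title, task) are stripped before the payload reaches a
# non-pipeline model.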
def filter_pipeline(payload, user):
    user = {"id": user.id, "name": user.name, "role": user.role}
    model_id = payload["model"]
    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    model = app.state.MODELS[model_id]

    if "pipeline" in model:
        sorted_filters.append(model)

    for filter in sorted_filters:
        r = None
        try:
            urlIdx = filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter['id']}/filter/inlet",
                    headers=headers,
                    json={
                        "user": user,
                        "body": payload,
                    },
                )

                r.raise_for_status()
                payload = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")

            if r is not None:
                res = None  # guard against r.json() failing below
                try:
                    res = r.json()
                except Exception:
                    pass
                if res is not None and "detail" in res:
                    raise Exception(r.status_code, res["detail"])

    if "pipeline" not in app.state.MODELS[model_id]:
        if "chat_id" in payload:
            del payload["chat_id"]

        if "title" in payload:
            del payload["title"]

        if "task" in payload:
            del payload["task"]

    return payload
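
# Applies the inlet filters above to every incoming chat-completion request,
# using the same body-replacement technique as ChatCompletionMiddleware.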
class PipelineMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/ollama/api/chat" in request.url.path
            or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )

            try:
                data = filter_pipeline(data, user)
            except Exception as e:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
            request._body = modified_body_bytes
            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
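
# Lazily populate the model registry on the first request and report each
# request's wall-clock duration via the X-Process-Time header.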
@app.middleware("http")
async def check_url(request: Request, call_next):
    if len(app.state.MODELS) == 0:
        await get_all_models()

    start_time = int(time.time())
    response = await call_next(request)
    process_time = int(time.time()) - start_time
    response.headers["X-Process-Time"] = str(process_time)

    return response

@app.middleware("http")
async def update_embedding_function(request: Request, call_next):
    response = await call_next(request)
    if "/embedding/update" in request.url.path:
        webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
    return response


app.mount("/ws", socket_app)

app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)

app.mount("/api/v1", webui_app)

webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
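
# Aggregate models from the OpenAI and Ollama sub-apps, normalized to the
# OpenAI "model" object shape, then overlay custom models from the database:
# presets without a base_model_id rename/annotate an existing model, the
# rest are appended as standalone presets. The result is cached in
# app.state.MODELS, keyed by model id.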
async def get_all_models():
    openai_models = []
    ollama_models = []

    if app.state.config.ENABLE_OPENAI_API:
        openai_models = await get_openai_models()
        openai_models = openai_models["data"]

    if app.state.config.ENABLE_OLLAMA_API:
        ollama_models = await get_ollama_models()
        ollama_models = [
            {
                "id": model["model"],
                "name": model["name"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "ollama",
                "ollama": model,
            }
            for model in ollama_models["models"]
        ]

    models = openai_models + ollama_models
    custom_models = Models.get_all_models()

    for custom_model in custom_models:
        if custom_model.base_model_id is None:
            for model in models:
                if (
                    custom_model.id == model["id"]
                    or custom_model.id == model["id"].split(":")[0]
                ):
                    model["name"] = custom_model.name
                    model["info"] = custom_model.model_dump()
        else:
            owned_by = "openai"
            for model in models:
                if (
                    custom_model.base_model_id == model["id"]
                    or custom_model.base_model_id == model["id"].split(":")[0]
                ):
                    owned_by = model["owned_by"]
                    break

            models.append(
                {
                    "id": custom_model.id,
                    "name": custom_model.name,
                    "object": "model",
                    "created": custom_model.created_at,
                    "owned_by": owned_by,
                    "info": custom_model.model_dump(),
                    "preset": True,
                }
            )

    app.state.MODELS = {model["id"]: model for model in models}
    webui_app.state.MODELS = app.state.MODELS

    return models
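
# Public model list: filter-type pipelines are hidden, and the model
# whitelist (MODEL_FILTER_LIST) is enforced for accounts with the "user"
# role.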
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
    models = await get_all_models()

    # Filter out filter pipelines
    models = [
        model
        for model in models
        if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
    ]

    if app.state.config.ENABLE_MODEL_FILTER:
        if user.role == "user":
            models = list(
                filter(
                    lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
                    models,
                )
            )
            return {"data": models}

    return {"data": models}

@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
    return {
        "TASK_MODEL": app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }


class TaskConfigForm(BaseModel):
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str


@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
    app.state.config.TASK_MODEL = form_data.TASK_MODEL
    app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )
    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
    )
    app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
        form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
    )
    app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }
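
# The task endpoints below share one pattern: validate the requested model,
# reroute to TASK_MODEL (for ollama-owned models) or TASK_MODEL_EXTERNAL
# (otherwise) when one is configured, render the relevant prompt template,
# run the payload through the inlet filters, and dispatch to the matching
# backend.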
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    print("generate_title")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
        "title": True,
    }

    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)

@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
    print("generate_search_query")

    if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
        )

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE

    content = search_query_generation_template(
        template, form_data["prompt"], {"name": user.name}
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 30,
        "task": True,
    }

    print(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)

@app.post("/api/task/emoji/completions")
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
    print("generate_emoji")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).

Message: """{{prompt}}"""
'''

    content = title_generation_template(
        template,
        form_data["prompt"],
        {
            "name": user.name,
            "location": user.info.get("location") if user.info else None,
        },
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 4,
        "chat_id": form_data.get("chat_id", None),
        "task": True,
    }

    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)

@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
    print("get_tools_function_calling")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    try:
        context, file_handler = await get_function_call_response(
            form_data["messages"],
            form_data.get("files", []),
            form_data["tool_id"],
            template,
            model_id,
            user,
        )
        return context
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )
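
# Thin dispatcher for chat completions: ollama-owned models go to the Ollama
# sub-app, everything else to the OpenAI-compatible one. A minimal request
# body looks like the following (illustrative values only):
#   {"model": "llama3:latest", "messages": [{"role": "user", "content": "Hi"}]}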
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    model = app.state.MODELS[model_id]
    print(model)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(form_data, user=user)
    else:
        return await generate_openai_chat_completion(form_data, user=user)
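
# Outlet counterpart to filter_pipeline: once a chat finishes, the completed
# payload is posted to each matching pipeline's /{id}/filter/outlet endpoint
# so filters can post-process the conversation.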
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    data = form_data
    model_id = data["model"]

    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    print(model_id)
    if model_id in app.state.MODELS:
        model = app.state.MODELS[model_id]
        if "pipeline" in model:
            sorted_filters = [model] + sorted_filters

    for filter in sorted_filters:
        r = None
        try:
            urlIdx = filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter['id']}/filter/outlet",
                    headers=headers,
                    json={
                        "user": {"id": user.id, "name": user.name, "role": user.role},
                        "body": data,
                    },
                )

                r.raise_for_status()
                data = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")

            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except Exception:
                    pass

    return data
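
# Admin endpoints for managing pipeline servers. Each proxies to one of the
# configured OpenAI-compatible base URLs (selected by urlIdx) using the
# matching API key.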
@app.get("/api/pipelines/list")
async def get_pipelines_list(user=Depends(get_admin_user)):
    responses = await get_openai_models(raw=True)
    print(responses)

    urlIdxs = [
        idx
        for idx, response in enumerate(responses)
        if response is not None and "pipelines" in response
    ]

    return {
        "data": [
            {
                "url": openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx],
                "idx": urlIdx,
            }
            for urlIdx in urlIdxs
        ]
    }

@app.post("/api/pipelines/upload")
async def upload_pipeline(
    urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
    print("upload_pipeline", urlIdx, file.filename)
    # Check if the uploaded file is a python file
    if not file.filename.endswith(".py"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only Python (.py) files are allowed.",
        )

    upload_folder = f"{CACHE_DIR}/pipelines"
    os.makedirs(upload_folder, exist_ok=True)
    file_path = os.path.join(upload_folder, file.filename)

    r = None  # initialize so the except block can safely inspect it
    try:
        # Save the uploaded file
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}

        with open(file_path, "rb") as f:
            files = {"file": f}
            r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )
    finally:
        # Ensure the file is deleted after the upload is completed or on failure
        if os.path.exists(file_path):
            os.remove(file_path)

class AddPipelineForm(BaseModel):
    url: str
    urlIdx: int


@app.post("/api/pipelines/add")
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
    r = None
    try:
        urlIdx = form_data.urlIdx

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.post(
            f"{url}/pipelines/add", headers=headers, json={"url": form_data.url}
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )

class DeletePipelineForm(BaseModel):
    id: str
    urlIdx: int


@app.delete("/api/pipelines/delete")
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
    r = None
    try:
        urlIdx = form_data.urlIdx

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.delete(
            f"{url}/pipelines/delete", headers=headers, json={"id": form_data.id}
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )

@app.get("/api/pipelines")
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/pipelines", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )

@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
):
    models = await get_all_models()
    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/{pipeline_id}/valves", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )


@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
):
    models = await get_all_models()
    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.get(f"{url}/{pipeline_id}/valves/spec", headers=headers)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )

@app.post("/api/pipelines/{pipeline_id}/valves/update")
async def update_pipeline_valves(
    urlIdx: Optional[int],
    pipeline_id: str,
    form_data: dict,
    user=Depends(get_admin_user),
):
    models = await get_all_models()
    r = None
    try:
        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}
        r = requests.post(
            f"{url}/{pipeline_id}/valves/update",
            headers=headers,
            json={**form_data},
        )

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except Exception:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )
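
# Frontend bootstrap config: feature flags and defaults the UI reads on load.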
@app.get("/api/config")
async def get_app_config():
    # Fall back to a default locale if the config has no 'ui' section
    default_locale = "en-US"
    if "ui" in CONFIG_DATA:
        default_locale = CONFIG_DATA["ui"].get("default_locale", "en-US")

    return {
        "status": True,
        "name": WEBUI_NAME,
        "version": VERSION,
        "default_locale": default_locale,
        "default_models": webui_app.state.config.DEFAULT_MODELS,
        "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
        "features": {
            "auth": WEBUI_AUTH,
            "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
            "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
            "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
            "enable_image_generation": images_app.state.config.ENABLED,
            "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
            "enable_admin_export": ENABLE_ADMIN_EXPORT,
        },
        "audio": {
            "tts": {
                "engine": audio_app.state.config.TTS_ENGINE,
                "voice": audio_app.state.config.TTS_VOICE,
            },
            "stt": {
                "engine": audio_app.state.config.STT_ENGINE,
            },
        },
    }

@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
    return {
        "enabled": app.state.config.ENABLE_MODEL_FILTER,
        "models": app.state.config.MODEL_FILTER_LIST,
    }


class ModelFilterConfigForm(BaseModel):
    enabled: bool
    models: List[str]


@app.post("/api/config/model/filter")
async def update_model_filter_config(
    form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
    app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
    app.state.config.MODEL_FILTER_LIST = form_data.models

    return {
        "enabled": app.state.config.ENABLE_MODEL_FILTER,
        "models": app.state.config.MODEL_FILTER_LIST,
    }


@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    return {
        "url": app.state.config.WEBHOOK_URL,
    }


class UrlForm(BaseModel):
    url: str


@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    app.state.config.WEBHOOK_URL = form_data.url
    webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
    return {"url": app.state.config.WEBHOOK_URL}

@app.get("/api/version")
async def get_app_version():
    return {
        "version": VERSION,
    }

@app.get("/api/changelog")
async def get_app_changelog():
    return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5}


@app.get("/api/version/updates")
async def get_app_latest_release_version():
    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
                response.raise_for_status()
                data = await response.json()
                latest_version = data["tag_name"]

                return {"current": VERSION, "latest": latest_version[1:]}
    except aiohttp.ClientError:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
        )

@app.get("/manifest.json")
async def get_manifest_json():
    return {
        "name": WEBUI_NAME,
        "short_name": WEBUI_NAME,
        "start_url": "/",
        "display": "standalone",
        "background_color": "#343541",
        "theme_color": "#343541",
        "orientation": "portrait-primary",
        "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
    }


@app.get("/opensearch.xml")
async def get_opensearch_xml():
    xml_content = rf"""
    <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
    <ShortName>{WEBUI_NAME}</ShortName>
    <Description>Search {WEBUI_NAME}</Description>
    <InputEncoding>UTF-8</InputEncoding>
    <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
    <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
    <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
    </OpenSearchDescription>
    """
    return Response(content=xml_content, media_type="application/xml")

@app.get("/health")
async def healthcheck():
    return {"status": True}


app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")

if os.path.exists(FRONTEND_BUILD_DIR):
    mimetypes.add_type("text/javascript", ".js")
    app.mount(
        "/",
        SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
        name="spa-static-files",
    )
else:
    log.warning(
        f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
    )