pipelines.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. from fastapi import (
  2. Depends,
  3. FastAPI,
  4. File,
  5. Form,
  6. HTTPException,
  7. Request,
  8. UploadFile,
  9. status,
  10. APIRouter,
  11. )
  12. import aiohttp
  13. import os
  14. import logging
  15. import shutil
  16. import requests
  17. from pydantic import BaseModel
  18. from starlette.responses import FileResponse
  19. from typing import Optional
  20. from open_webui.env import SRC_LOG_LEVELS
  21. from open_webui.config import CACHE_DIR
  22. from open_webui.constants import ERROR_MESSAGES
  23. from open_webui.routers.openai import get_all_models_responses
  24. from open_webui.utils.auth import get_admin_user
  25. log = logging.getLogger(__name__)
  26. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  27. ##################################
  28. #
  29. # Pipeline Middleware
  30. #
  31. ##################################
  32. def get_sorted_filters(model_id, models):
  33. filters = [
  34. model
  35. for model in models.values()
  36. if "pipeline" in model
  37. and "type" in model["pipeline"]
  38. and model["pipeline"]["type"] == "filter"
  39. and (
  40. model["pipeline"]["pipelines"] == ["*"]
  41. or any(
  42. model_id == target_model_id
  43. for target_model_id in model["pipeline"]["pipelines"]
  44. )
  45. )
  46. ]
  47. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  48. return sorted_filters
  49. async def process_pipeline_inlet_filter(request, payload, user, models):
  50. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  51. model_id = payload["model"]
  52. sorted_filters = get_sorted_filters(model_id, models)
  53. model = models[model_id]
  54. if "pipeline" in model:
  55. sorted_filters.append(model)
  56. async with aiohttp.ClientSession() as session:
  57. for filter in sorted_filters:
  58. urlIdx = filter.get("urlIdx")
  59. if urlIdx is None:
  60. continue
  61. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  62. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  63. if not key:
  64. continue
  65. headers = {"Authorization": f"Bearer {key}"}
  66. request_data = {
  67. "user": user,
  68. "body": payload,
  69. }
  70. try:
  71. async with session.post(
  72. f"{url}/{filter['id']}/filter/inlet",
  73. headers=headers,
  74. json=request_data,
  75. ) as response:
  76. response.raise_for_status()
  77. payload = await response.json()
  78. except aiohttp.ClientResponseError as e:
  79. res = (
  80. await response.json()
  81. if response.content_type == "application/json"
  82. else {}
  83. )
  84. if "detail" in res:
  85. raise Exception(response.status, res["detail"])
  86. except Exception as e:
  87. log.exception(f"Connection error: {e}")
  88. return payload
  89. async def process_pipeline_outlet_filter(request, payload, user, models):
  90. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  91. model_id = payload["model"]
  92. sorted_filters = get_sorted_filters(model_id, models)
  93. model = models[model_id]
  94. if "pipeline" in model:
  95. sorted_filters = [model] + sorted_filters
  96. async with aiohttp.ClientSession() as session:
  97. for filter in sorted_filters:
  98. urlIdx = filter.get("urlIdx")
  99. if urlIdx is None:
  100. continue
  101. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  102. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  103. if not key:
  104. continue
  105. headers = {"Authorization": f"Bearer {key}"}
  106. request_data = {
  107. "user": user,
  108. "body": payload,
  109. }
  110. try:
  111. async with session.post(
  112. f"{url}/{filter['id']}/filter/outlet",
  113. headers=headers,
  114. json=request_data,
  115. ) as response:
  116. response.raise_for_status()
  117. payload = await response.json()
  118. except aiohttp.ClientResponseError as e:
  119. try:
  120. res = (
  121. await response.json()
  122. if "application/json" in response.content_type
  123. else {}
  124. )
  125. if "detail" in res:
  126. raise Exception(response.status, res)
  127. except Exception:
  128. pass
  129. except Exception as e:
  130. log.exception(f"Connection error: {e}")
  131. return payload
  132. ##################################
  133. #
  134. # Pipelines Endpoints
  135. #
  136. ##################################
  137. router = APIRouter()
  138. @router.get("/list")
  139. async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
  140. responses = await get_all_models_responses(request, user)
  141. log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
  142. urlIdxs = [
  143. idx
  144. for idx, response in enumerate(responses)
  145. if response is not None and "pipelines" in response
  146. ]
  147. return {
  148. "data": [
  149. {
  150. "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  151. "idx": urlIdx,
  152. }
  153. for urlIdx in urlIdxs
  154. ]
  155. }
  156. @router.post("/upload")
  157. async def upload_pipeline(
  158. request: Request,
  159. urlIdx: int = Form(...),
  160. file: UploadFile = File(...),
  161. user=Depends(get_admin_user),
  162. ):
  163. log.info(f"upload_pipeline: urlIdx={urlIdx}, filename={file.filename}")
  164. # Check if the uploaded file is a python file
  165. if not (file.filename and file.filename.endswith(".py")):
  166. raise HTTPException(
  167. status_code=status.HTTP_400_BAD_REQUEST,
  168. detail="Only Python (.py) files are allowed.",
  169. )
  170. upload_folder = f"{CACHE_DIR}/pipelines"
  171. os.makedirs(upload_folder, exist_ok=True)
  172. file_path = os.path.join(upload_folder, file.filename)
  173. r = None
  174. try:
  175. # Save the uploaded file
  176. with open(file_path, "wb") as buffer:
  177. shutil.copyfileobj(file.file, buffer)
  178. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  179. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  180. with open(file_path, "rb") as f:
  181. files = {"file": f}
  182. r = requests.post(
  183. f"{url}/pipelines/upload",
  184. headers={"Authorization": f"Bearer {key}"},
  185. files=files,
  186. )
  187. r.raise_for_status()
  188. data = r.json()
  189. return {**data}
  190. except Exception as e:
  191. # Handle connection error here
  192. log.exception(f"Connection error: {e}")
  193. detail = None
  194. status_code = status.HTTP_404_NOT_FOUND
  195. if r is not None:
  196. status_code = r.status_code
  197. try:
  198. res = r.json()
  199. if "detail" in res:
  200. detail = res["detail"]
  201. except Exception:
  202. pass
  203. raise HTTPException(
  204. status_code=status_code,
  205. detail=detail if detail else "Pipeline not found",
  206. )
  207. finally:
  208. # Ensure the file is deleted after the upload is completed or on failure
  209. if os.path.exists(file_path):
  210. os.remove(file_path)
  211. class AddPipelineForm(BaseModel):
  212. url: str
  213. urlIdx: int
  214. @router.post("/add")
  215. async def add_pipeline(
  216. request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
  217. ):
  218. r = None
  219. try:
  220. urlIdx = form_data.urlIdx
  221. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  222. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  223. r = requests.post(
  224. f"{url}/pipelines/add",
  225. headers={"Authorization": f"Bearer {key}"},
  226. json={"url": form_data.url},
  227. )
  228. r.raise_for_status()
  229. data = r.json()
  230. return {**data}
  231. except Exception as e:
  232. # Handle connection error here
  233. log.exception(f"Connection error: {e}")
  234. detail = None
  235. if r is not None:
  236. try:
  237. res = r.json()
  238. if "detail" in res:
  239. detail = res["detail"]
  240. except Exception:
  241. pass
  242. raise HTTPException(
  243. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  244. detail=detail if detail else "Pipeline not found",
  245. )
  246. class DeletePipelineForm(BaseModel):
  247. id: str
  248. urlIdx: int
  249. @router.delete("/delete")
  250. async def delete_pipeline(
  251. request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
  252. ):
  253. r = None
  254. try:
  255. urlIdx = form_data.urlIdx
  256. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  257. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  258. r = requests.delete(
  259. f"{url}/pipelines/delete",
  260. headers={"Authorization": f"Bearer {key}"},
  261. json={"id": form_data.id},
  262. )
  263. r.raise_for_status()
  264. data = r.json()
  265. return {**data}
  266. except Exception as e:
  267. # Handle connection error here
  268. log.exception(f"Connection error: {e}")
  269. detail = None
  270. if r is not None:
  271. try:
  272. res = r.json()
  273. if "detail" in res:
  274. detail = res["detail"]
  275. except Exception:
  276. pass
  277. raise HTTPException(
  278. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  279. detail=detail if detail else "Pipeline not found",
  280. )
  281. @router.get("/")
  282. async def get_pipelines(
  283. request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
  284. ):
  285. r = None
  286. try:
  287. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  288. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  289. r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"})
  290. r.raise_for_status()
  291. data = r.json()
  292. return {**data}
  293. except Exception as e:
  294. # Handle connection error here
  295. log.exception(f"Connection error: {e}")
  296. detail = None
  297. if r is not None:
  298. try:
  299. res = r.json()
  300. if "detail" in res:
  301. detail = res["detail"]
  302. except Exception:
  303. pass
  304. raise HTTPException(
  305. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  306. detail=detail if detail else "Pipeline not found",
  307. )
  308. @router.get("/{pipeline_id}/valves")
  309. async def get_pipeline_valves(
  310. request: Request,
  311. urlIdx: Optional[int],
  312. pipeline_id: str,
  313. user=Depends(get_admin_user),
  314. ):
  315. r = None
  316. try:
  317. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  318. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  319. r = requests.get(
  320. f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"}
  321. )
  322. r.raise_for_status()
  323. data = r.json()
  324. return {**data}
  325. except Exception as e:
  326. # Handle connection error here
  327. log.exception(f"Connection error: {e}")
  328. detail = None
  329. if r is not None:
  330. try:
  331. res = r.json()
  332. if "detail" in res:
  333. detail = res["detail"]
  334. except Exception:
  335. pass
  336. raise HTTPException(
  337. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  338. detail=detail if detail else "Pipeline not found",
  339. )
  340. @router.get("/{pipeline_id}/valves/spec")
  341. async def get_pipeline_valves_spec(
  342. request: Request,
  343. urlIdx: Optional[int],
  344. pipeline_id: str,
  345. user=Depends(get_admin_user),
  346. ):
  347. r = None
  348. try:
  349. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  350. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  351. r = requests.get(
  352. f"{url}/{pipeline_id}/valves/spec",
  353. headers={"Authorization": f"Bearer {key}"},
  354. )
  355. r.raise_for_status()
  356. data = r.json()
  357. return {**data}
  358. except Exception as e:
  359. # Handle connection error here
  360. log.exception(f"Connection error: {e}")
  361. detail = None
  362. if r is not None:
  363. try:
  364. res = r.json()
  365. if "detail" in res:
  366. detail = res["detail"]
  367. except Exception:
  368. pass
  369. raise HTTPException(
  370. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  371. detail=detail if detail else "Pipeline not found",
  372. )
  373. @router.post("/{pipeline_id}/valves/update")
  374. async def update_pipeline_valves(
  375. request: Request,
  376. urlIdx: Optional[int],
  377. pipeline_id: str,
  378. form_data: dict,
  379. user=Depends(get_admin_user),
  380. ):
  381. r = None
  382. try:
  383. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  384. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  385. r = requests.post(
  386. f"{url}/{pipeline_id}/valves/update",
  387. headers={"Authorization": f"Bearer {key}"},
  388. json={**form_data},
  389. )
  390. r.raise_for_status()
  391. data = r.json()
  392. return {**data}
  393. except Exception as e:
  394. # Handle connection error here
  395. log.exception(f"Connection error: {e}")
  396. detail = None
  397. if r is not None:
  398. try:
  399. res = r.json()
  400. if "detail" in res:
  401. detail = res["detail"]
  402. except Exception:
  403. pass
  404. raise HTTPException(
  405. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  406. detail=detail if detail else "Pipeline not found",
  407. )