# pipelines.py
  1. from fastapi import (
  2. Depends,
  3. FastAPI,
  4. File,
  5. Form,
  6. HTTPException,
  7. Request,
  8. UploadFile,
  9. status,
  10. APIRouter,
  11. )
  12. import os
  13. import logging
  14. import shutil
  15. import requests
  16. from pydantic import BaseModel
  17. from starlette.responses import FileResponse
  18. from typing import Optional
  19. from open_webui.env import SRC_LOG_LEVELS
  20. from open_webui.config import CACHE_DIR
  21. from open_webui.constants import ERROR_MESSAGES
  22. from open_webui.routers.openai import get_all_models_responses
  23. from open_webui.utils.auth import get_admin_user
  24. log = logging.getLogger(__name__)
  25. log.setLevel(SRC_LOG_LEVELS["MAIN"])
##################################
#
# Pipeline Middleware
#
##################################
  31. def get_sorted_filters(model_id, models):
  32. filters = [
  33. model
  34. for model in models.values()
  35. if "pipeline" in model
  36. and "type" in model["pipeline"]
  37. and model["pipeline"]["type"] == "filter"
  38. and (
  39. model["pipeline"]["pipelines"] == ["*"]
  40. or any(
  41. model_id == target_model_id
  42. for target_model_id in model["pipeline"]["pipelines"]
  43. )
  44. )
  45. ]
  46. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  47. return sorted_filters
  48. def process_pipeline_inlet_filter(request, payload, user, models):
  49. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  50. model_id = payload["model"]
  51. sorted_filters = get_sorted_filters(model_id, models)
  52. model = models[model_id]
  53. if "pipeline" in model:
  54. sorted_filters.append(model)
  55. for filter in sorted_filters:
  56. r = None
  57. try:
  58. urlIdx = filter["urlIdx"]
  59. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  60. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  61. if key == "":
  62. continue
  63. headers = {"Authorization": f"Bearer {key}"}
  64. r = requests.post(
  65. f"{url}/{filter['id']}/filter/inlet",
  66. headers=headers,
  67. json={
  68. "user": user,
  69. "body": payload,
  70. },
  71. )
  72. r.raise_for_status()
  73. payload = r.json()
  74. except Exception as e:
  75. # Handle connection error here
  76. print(f"Connection error: {e}")
  77. if r is not None:
  78. res = r.json()
  79. if "detail" in res:
  80. raise Exception(r.status_code, res["detail"])
  81. return payload
  82. def process_pipeline_outlet_filter(request, payload, user, models):
  83. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  84. model_id = payload["model"]
  85. sorted_filters = get_sorted_filters(model_id, models)
  86. model = models[model_id]
  87. if "pipeline" in model:
  88. sorted_filters = [model] + sorted_filters
  89. for filter in sorted_filters:
  90. r = None
  91. try:
  92. urlIdx = filter["urlIdx"]
  93. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  94. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  95. if key != "":
  96. r = requests.post(
  97. f"{url}/{filter['id']}/filter/outlet",
  98. headers={"Authorization": f"Bearer {key}"},
  99. json={
  100. "user": {
  101. "id": user.id,
  102. "name": user.name,
  103. "email": user.email,
  104. "role": user.role,
  105. },
  106. "body": data,
  107. },
  108. )
  109. r.raise_for_status()
  110. data = r.json()
  111. except Exception as e:
  112. # Handle connection error here
  113. print(f"Connection error: {e}")
  114. if r is not None:
  115. try:
  116. res = r.json()
  117. if "detail" in res:
  118. return Exception(r.status_code, res)
  119. except Exception:
  120. pass
  121. else:
  122. pass
  123. return payload
##################################
#
# Pipelines Endpoints
#
##################################
  129. router = APIRouter()
  130. @router.get("/list")
  131. async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
  132. responses = await get_all_models_responses(request)
  133. log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
  134. urlIdxs = [
  135. idx
  136. for idx, response in enumerate(responses)
  137. if response is not None and "pipelines" in response
  138. ]
  139. return {
  140. "data": [
  141. {
  142. "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  143. "idx": urlIdx,
  144. }
  145. for urlIdx in urlIdxs
  146. ]
  147. }
  148. @router.post("/upload")
  149. async def upload_pipeline(
  150. request: Request,
  151. urlIdx: int = Form(...),
  152. file: UploadFile = File(...),
  153. user=Depends(get_admin_user),
  154. ):
  155. print("upload_pipeline", urlIdx, file.filename)
  156. # Check if the uploaded file is a python file
  157. if not (file.filename and file.filename.endswith(".py")):
  158. raise HTTPException(
  159. status_code=status.HTTP_400_BAD_REQUEST,
  160. detail="Only Python (.py) files are allowed.",
  161. )
  162. upload_folder = f"{CACHE_DIR}/pipelines"
  163. os.makedirs(upload_folder, exist_ok=True)
  164. file_path = os.path.join(upload_folder, file.filename)
  165. r = None
  166. try:
  167. # Save the uploaded file
  168. with open(file_path, "wb") as buffer:
  169. shutil.copyfileobj(file.file, buffer)
  170. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  171. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  172. with open(file_path, "rb") as f:
  173. files = {"file": f}
  174. r = requests.post(
  175. f"{url}/pipelines/upload",
  176. headers={"Authorization": f"Bearer {key}"},
  177. files=files,
  178. )
  179. r.raise_for_status()
  180. data = r.json()
  181. return {**data}
  182. except Exception as e:
  183. # Handle connection error here
  184. print(f"Connection error: {e}")
  185. detail = None
  186. status_code = status.HTTP_404_NOT_FOUND
  187. if r is not None:
  188. status_code = r.status_code
  189. try:
  190. res = r.json()
  191. if "detail" in res:
  192. detail = res["detail"]
  193. except Exception:
  194. pass
  195. raise HTTPException(
  196. status_code=status_code,
  197. detail=detail if detail else "Pipeline not found",
  198. )
  199. finally:
  200. # Ensure the file is deleted after the upload is completed or on failure
  201. if os.path.exists(file_path):
  202. os.remove(file_path)
  203. class AddPipelineForm(BaseModel):
  204. url: str
  205. urlIdx: int
  206. @router.post("/add")
  207. async def add_pipeline(
  208. request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
  209. ):
  210. r = None
  211. try:
  212. urlIdx = form_data.urlIdx
  213. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  214. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  215. r = requests.post(
  216. f"{url}/pipelines/add",
  217. headers={"Authorization": f"Bearer {key}"},
  218. json={"url": form_data.url},
  219. )
  220. r.raise_for_status()
  221. data = r.json()
  222. return {**data}
  223. except Exception as e:
  224. # Handle connection error here
  225. print(f"Connection error: {e}")
  226. detail = None
  227. if r is not None:
  228. try:
  229. res = r.json()
  230. if "detail" in res:
  231. detail = res["detail"]
  232. except Exception:
  233. pass
  234. raise HTTPException(
  235. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  236. detail=detail if detail else "Pipeline not found",
  237. )
  238. class DeletePipelineForm(BaseModel):
  239. id: str
  240. urlIdx: int
  241. @router.delete("/delete")
  242. async def delete_pipeline(
  243. request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
  244. ):
  245. r = None
  246. try:
  247. urlIdx = form_data.urlIdx
  248. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  249. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  250. r = requests.delete(
  251. f"{url}/pipelines/delete",
  252. headers={"Authorization": f"Bearer {key}"},
  253. json={"id": form_data.id},
  254. )
  255. r.raise_for_status()
  256. data = r.json()
  257. return {**data}
  258. except Exception as e:
  259. # Handle connection error here
  260. print(f"Connection error: {e}")
  261. detail = None
  262. if r is not None:
  263. try:
  264. res = r.json()
  265. if "detail" in res:
  266. detail = res["detail"]
  267. except Exception:
  268. pass
  269. raise HTTPException(
  270. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  271. detail=detail if detail else "Pipeline not found",
  272. )
  273. @router.get("/")
  274. async def get_pipelines(
  275. request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
  276. ):
  277. r = None
  278. try:
  279. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  280. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  281. r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"})
  282. r.raise_for_status()
  283. data = r.json()
  284. return {**data}
  285. except Exception as e:
  286. # Handle connection error here
  287. print(f"Connection error: {e}")
  288. detail = None
  289. if r is not None:
  290. try:
  291. res = r.json()
  292. if "detail" in res:
  293. detail = res["detail"]
  294. except Exception:
  295. pass
  296. raise HTTPException(
  297. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  298. detail=detail if detail else "Pipeline not found",
  299. )
  300. @router.get("/{pipeline_id}/valves")
  301. async def get_pipeline_valves(
  302. request: Request,
  303. urlIdx: Optional[int],
  304. pipeline_id: str,
  305. user=Depends(get_admin_user),
  306. ):
  307. r = None
  308. try:
  309. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  310. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  311. r = requests.get(
  312. f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"}
  313. )
  314. r.raise_for_status()
  315. data = r.json()
  316. return {**data}
  317. except Exception as e:
  318. # Handle connection error here
  319. print(f"Connection error: {e}")
  320. detail = None
  321. if r is not None:
  322. try:
  323. res = r.json()
  324. if "detail" in res:
  325. detail = res["detail"]
  326. except Exception:
  327. pass
  328. raise HTTPException(
  329. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  330. detail=detail if detail else "Pipeline not found",
  331. )
  332. @router.get("/{pipeline_id}/valves/spec")
  333. async def get_pipeline_valves_spec(
  334. request: Request,
  335. urlIdx: Optional[int],
  336. pipeline_id: str,
  337. user=Depends(get_admin_user),
  338. ):
  339. r = None
  340. try:
  341. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  342. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  343. r = requests.get(
  344. f"{url}/{pipeline_id}/valves/spec",
  345. headers={"Authorization": f"Bearer {key}"},
  346. )
  347. r.raise_for_status()
  348. data = r.json()
  349. return {**data}
  350. except Exception as e:
  351. # Handle connection error here
  352. print(f"Connection error: {e}")
  353. detail = None
  354. if r is not None:
  355. try:
  356. res = r.json()
  357. if "detail" in res:
  358. detail = res["detail"]
  359. except Exception:
  360. pass
  361. raise HTTPException(
  362. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  363. detail=detail if detail else "Pipeline not found",
  364. )
  365. @router.post("/{pipeline_id}/valves/update")
  366. async def update_pipeline_valves(
  367. request: Request,
  368. urlIdx: Optional[int],
  369. pipeline_id: str,
  370. form_data: dict,
  371. user=Depends(get_admin_user),
  372. ):
  373. r = None
  374. try:
  375. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  376. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  377. r = requests.post(
  378. f"{url}/{pipeline_id}/valves/update",
  379. headers={"Authorization": f"Bearer {key}"},
  380. json={**form_data},
  381. )
  382. r.raise_for_status()
  383. data = r.json()
  384. return {**data}
  385. except Exception as e:
  386. # Handle connection error here
  387. print(f"Connection error: {e}")
  388. detail = None
  389. if r is not None:
  390. try:
  391. res = r.json()
  392. if "detail" in res:
  393. detail = res["detail"]
  394. except Exception:
  395. pass
  396. raise HTTPException(
  397. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  398. detail=detail if detail else "Pipeline not found",
  399. )