pipelines.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. from fastapi import (
  2. Depends,
  3. FastAPI,
  4. File,
  5. Form,
  6. HTTPException,
  7. Request,
  8. UploadFile,
  9. status,
  10. APIRouter,
  11. )
  12. import os
  13. import logging
  14. import shutil
  15. import requests
  16. from pydantic import BaseModel
  17. from starlette.responses import FileResponse
  18. from typing import Optional
  19. from open_webui.env import SRC_LOG_LEVELS
  20. from open_webui.config import CACHE_DIR
  21. from open_webui.constants import ERROR_MESSAGES
  22. from open_webui.routers.openai import get_all_models_responses
  23. from open_webui.utils.auth import get_admin_user
  24. log = logging.getLogger(__name__)
  25. log.setLevel(SRC_LOG_LEVELS["MAIN"])
  26. ##################################
  27. #
  28. # Pipeline Middleware
  29. #
  30. ##################################
  31. def get_sorted_filters(model_id, models):
  32. filters = [
  33. model
  34. for model in models.values()
  35. if "pipeline" in model
  36. and "type" in model["pipeline"]
  37. and model["pipeline"]["type"] == "filter"
  38. and (
  39. model["pipeline"]["pipelines"] == ["*"]
  40. or any(
  41. model_id == target_model_id
  42. for target_model_id in model["pipeline"]["pipelines"]
  43. )
  44. )
  45. ]
  46. sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
  47. return sorted_filters
  48. def process_pipeline_inlet_filter(request, payload, user, models):
  49. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  50. model_id = payload["model"]
  51. sorted_filters = get_sorted_filters(model_id, models)
  52. model = models[model_id]
  53. if "pipeline" in model:
  54. sorted_filters.append(model)
  55. for filter in sorted_filters:
  56. r = None
  57. try:
  58. urlIdx = filter["urlIdx"]
  59. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  60. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  61. if key == "":
  62. continue
  63. headers = {"Authorization": f"Bearer {key}"}
  64. r = requests.post(
  65. f"{url}/{filter['id']}/filter/inlet",
  66. headers=headers,
  67. json={
  68. "user": user,
  69. "body": payload,
  70. },
  71. )
  72. r.raise_for_status()
  73. payload = r.json()
  74. except Exception as e:
  75. # Handle connection error here
  76. print(f"Connection error: {e}")
  77. if r is not None:
  78. res = r.json()
  79. if "detail" in res:
  80. raise Exception(r.status_code, res["detail"])
  81. return payload
  82. def process_pipeline_outlet_filter(request, payload, user, models):
  83. user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
  84. model_id = payload["model"]
  85. sorted_filters = get_sorted_filters(model_id, models)
  86. model = models[model_id]
  87. if "pipeline" in model:
  88. sorted_filters = [model] + sorted_filters
  89. for filter in sorted_filters:
  90. r = None
  91. try:
  92. urlIdx = filter["urlIdx"]
  93. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  94. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  95. if key != "":
  96. r = requests.post(
  97. f"{url}/{filter['id']}/filter/outlet",
  98. headers={"Authorization": f"Bearer {key}"},
  99. json={
  100. "user": user,
  101. "body": payload,
  102. },
  103. )
  104. r.raise_for_status()
  105. data = r.json()
  106. payload = data
  107. except Exception as e:
  108. # Handle connection error here
  109. print(f"Connection error: {e}")
  110. if r is not None:
  111. try:
  112. res = r.json()
  113. if "detail" in res:
  114. return Exception(r.status_code, res)
  115. except Exception:
  116. pass
  117. else:
  118. pass
  119. return payload
  120. ##################################
  121. #
  122. # Pipelines Endpoints
  123. #
  124. ##################################
  125. router = APIRouter()
  126. @router.get("/list")
  127. async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
  128. responses = await get_all_models_responses(request)
  129. log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
  130. urlIdxs = [
  131. idx
  132. for idx, response in enumerate(responses)
  133. if response is not None and "pipelines" in response
  134. ]
  135. return {
  136. "data": [
  137. {
  138. "url": request.app.state.config.OPENAI_API_BASE_URLS[urlIdx],
  139. "idx": urlIdx,
  140. }
  141. for urlIdx in urlIdxs
  142. ]
  143. }
  144. @router.post("/upload")
  145. async def upload_pipeline(
  146. request: Request,
  147. urlIdx: int = Form(...),
  148. file: UploadFile = File(...),
  149. user=Depends(get_admin_user),
  150. ):
  151. print("upload_pipeline", urlIdx, file.filename)
  152. # Check if the uploaded file is a python file
  153. if not (file.filename and file.filename.endswith(".py")):
  154. raise HTTPException(
  155. status_code=status.HTTP_400_BAD_REQUEST,
  156. detail="Only Python (.py) files are allowed.",
  157. )
  158. upload_folder = f"{CACHE_DIR}/pipelines"
  159. os.makedirs(upload_folder, exist_ok=True)
  160. file_path = os.path.join(upload_folder, file.filename)
  161. r = None
  162. try:
  163. # Save the uploaded file
  164. with open(file_path, "wb") as buffer:
  165. shutil.copyfileobj(file.file, buffer)
  166. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  167. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  168. with open(file_path, "rb") as f:
  169. files = {"file": f}
  170. r = requests.post(
  171. f"{url}/pipelines/upload",
  172. headers={"Authorization": f"Bearer {key}"},
  173. files=files,
  174. )
  175. r.raise_for_status()
  176. data = r.json()
  177. return {**data}
  178. except Exception as e:
  179. # Handle connection error here
  180. print(f"Connection error: {e}")
  181. detail = None
  182. status_code = status.HTTP_404_NOT_FOUND
  183. if r is not None:
  184. status_code = r.status_code
  185. try:
  186. res = r.json()
  187. if "detail" in res:
  188. detail = res["detail"]
  189. except Exception:
  190. pass
  191. raise HTTPException(
  192. status_code=status_code,
  193. detail=detail if detail else "Pipeline not found",
  194. )
  195. finally:
  196. # Ensure the file is deleted after the upload is completed or on failure
  197. if os.path.exists(file_path):
  198. os.remove(file_path)
  199. class AddPipelineForm(BaseModel):
  200. url: str
  201. urlIdx: int
  202. @router.post("/add")
  203. async def add_pipeline(
  204. request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
  205. ):
  206. r = None
  207. try:
  208. urlIdx = form_data.urlIdx
  209. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  210. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  211. r = requests.post(
  212. f"{url}/pipelines/add",
  213. headers={"Authorization": f"Bearer {key}"},
  214. json={"url": form_data.url},
  215. )
  216. r.raise_for_status()
  217. data = r.json()
  218. return {**data}
  219. except Exception as e:
  220. # Handle connection error here
  221. print(f"Connection error: {e}")
  222. detail = None
  223. if r is not None:
  224. try:
  225. res = r.json()
  226. if "detail" in res:
  227. detail = res["detail"]
  228. except Exception:
  229. pass
  230. raise HTTPException(
  231. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  232. detail=detail if detail else "Pipeline not found",
  233. )
  234. class DeletePipelineForm(BaseModel):
  235. id: str
  236. urlIdx: int
  237. @router.delete("/delete")
  238. async def delete_pipeline(
  239. request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
  240. ):
  241. r = None
  242. try:
  243. urlIdx = form_data.urlIdx
  244. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  245. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  246. r = requests.delete(
  247. f"{url}/pipelines/delete",
  248. headers={"Authorization": f"Bearer {key}"},
  249. json={"id": form_data.id},
  250. )
  251. r.raise_for_status()
  252. data = r.json()
  253. return {**data}
  254. except Exception as e:
  255. # Handle connection error here
  256. print(f"Connection error: {e}")
  257. detail = None
  258. if r is not None:
  259. try:
  260. res = r.json()
  261. if "detail" in res:
  262. detail = res["detail"]
  263. except Exception:
  264. pass
  265. raise HTTPException(
  266. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  267. detail=detail if detail else "Pipeline not found",
  268. )
  269. @router.get("/")
  270. async def get_pipelines(
  271. request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
  272. ):
  273. r = None
  274. try:
  275. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  276. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  277. r = requests.get(f"{url}/pipelines", headers={"Authorization": f"Bearer {key}"})
  278. r.raise_for_status()
  279. data = r.json()
  280. return {**data}
  281. except Exception as e:
  282. # Handle connection error here
  283. print(f"Connection error: {e}")
  284. detail = None
  285. if r is not None:
  286. try:
  287. res = r.json()
  288. if "detail" in res:
  289. detail = res["detail"]
  290. except Exception:
  291. pass
  292. raise HTTPException(
  293. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  294. detail=detail if detail else "Pipeline not found",
  295. )
  296. @router.get("/{pipeline_id}/valves")
  297. async def get_pipeline_valves(
  298. request: Request,
  299. urlIdx: Optional[int],
  300. pipeline_id: str,
  301. user=Depends(get_admin_user),
  302. ):
  303. r = None
  304. try:
  305. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  306. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  307. r = requests.get(
  308. f"{url}/{pipeline_id}/valves", headers={"Authorization": f"Bearer {key}"}
  309. )
  310. r.raise_for_status()
  311. data = r.json()
  312. return {**data}
  313. except Exception as e:
  314. # Handle connection error here
  315. print(f"Connection error: {e}")
  316. detail = None
  317. if r is not None:
  318. try:
  319. res = r.json()
  320. if "detail" in res:
  321. detail = res["detail"]
  322. except Exception:
  323. pass
  324. raise HTTPException(
  325. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  326. detail=detail if detail else "Pipeline not found",
  327. )
  328. @router.get("/{pipeline_id}/valves/spec")
  329. async def get_pipeline_valves_spec(
  330. request: Request,
  331. urlIdx: Optional[int],
  332. pipeline_id: str,
  333. user=Depends(get_admin_user),
  334. ):
  335. r = None
  336. try:
  337. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  338. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  339. r = requests.get(
  340. f"{url}/{pipeline_id}/valves/spec",
  341. headers={"Authorization": f"Bearer {key}"},
  342. )
  343. r.raise_for_status()
  344. data = r.json()
  345. return {**data}
  346. except Exception as e:
  347. # Handle connection error here
  348. print(f"Connection error: {e}")
  349. detail = None
  350. if r is not None:
  351. try:
  352. res = r.json()
  353. if "detail" in res:
  354. detail = res["detail"]
  355. except Exception:
  356. pass
  357. raise HTTPException(
  358. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  359. detail=detail if detail else "Pipeline not found",
  360. )
  361. @router.post("/{pipeline_id}/valves/update")
  362. async def update_pipeline_valves(
  363. request: Request,
  364. urlIdx: Optional[int],
  365. pipeline_id: str,
  366. form_data: dict,
  367. user=Depends(get_admin_user),
  368. ):
  369. r = None
  370. try:
  371. url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
  372. key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
  373. r = requests.post(
  374. f"{url}/{pipeline_id}/valves/update",
  375. headers={"Authorization": f"Bearer {key}"},
  376. json={**form_data},
  377. )
  378. r.raise_for_status()
  379. data = r.json()
  380. return {**data}
  381. except Exception as e:
  382. # Handle connection error here
  383. print(f"Connection error: {e}")
  384. detail = None
  385. if r is not None:
  386. try:
  387. res = r.json()
  388. if "detail" in res:
  389. detail = res["detail"]
  390. except Exception:
  391. pass
  392. raise HTTPException(
  393. status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
  394. detail=detail if detail else "Pipeline not found",
  395. )