tools.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. from pathlib import Path
  2. from typing import Optional
  3. from open_webui.models.tools import (
  4. ToolForm,
  5. ToolModel,
  6. ToolResponse,
  7. ToolUserResponse,
  8. Tools,
  9. )
  10. from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
  11. from open_webui.config import CACHE_DIR
  12. from open_webui.constants import ERROR_MESSAGES
  13. from fastapi import APIRouter, Depends, HTTPException, Request, status
  14. from open_webui.utils.tools import get_tools_specs
  15. from open_webui.utils.auth import get_admin_user, get_verified_user
  16. from open_webui.utils.access_control import has_access, has_permission
  17. router = APIRouter()
  18. ############################
  19. # GetTools
  20. ############################
  21. @router.get("/", response_model=list[ToolUserResponse])
  22. async def get_tools(user=Depends(get_verified_user)):
  23. if user.role == "admin":
  24. tools = Tools.get_tools()
  25. else:
  26. tools = Tools.get_tools_by_user_id(user.id, "read")
  27. return tools
  28. ############################
  29. # GetToolList
  30. ############################
  31. @router.get("/list", response_model=list[ToolUserResponse])
  32. async def get_tool_list(user=Depends(get_verified_user)):
  33. if user.role == "admin":
  34. tools = Tools.get_tools()
  35. else:
  36. tools = Tools.get_tools_by_user_id(user.id, "write")
  37. return tools
  38. ############################
  39. # ExportTools
  40. ############################
  41. @router.get("/export", response_model=list[ToolModel])
  42. async def export_tools(user=Depends(get_admin_user)):
  43. tools = Tools.get_tools()
  44. return tools
  45. ############################
  46. # CreateNewTools
  47. ############################
  48. @router.post("/create", response_model=Optional[ToolResponse])
  49. async def create_new_tools(
  50. request: Request,
  51. form_data: ToolForm,
  52. user=Depends(get_verified_user),
  53. ):
  54. if user.role != "admin" and not has_permission(
  55. user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
  56. ):
  57. raise HTTPException(
  58. status_code=status.HTTP_401_UNAUTHORIZED,
  59. detail=ERROR_MESSAGES.UNAUTHORIZED,
  60. )
  61. if not form_data.id.isidentifier():
  62. raise HTTPException(
  63. status_code=status.HTTP_400_BAD_REQUEST,
  64. detail="Only alphanumeric characters and underscores are allowed in the id",
  65. )
  66. form_data.id = form_data.id.lower()
  67. tools = Tools.get_tool_by_id(form_data.id)
  68. if tools is None:
  69. try:
  70. form_data.content = replace_imports(form_data.content)
  71. tools_module, frontmatter = load_tools_module_by_id(
  72. form_data.id, content=form_data.content
  73. )
  74. form_data.meta.manifest = frontmatter
  75. TOOLS = request.app.state.TOOLS
  76. TOOLS[form_data.id] = tools_module
  77. specs = get_tools_specs(TOOLS[form_data.id])
  78. tools = Tools.insert_new_tool(user.id, form_data, specs)
  79. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  80. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  81. if tools:
  82. return tools
  83. else:
  84. raise HTTPException(
  85. status_code=status.HTTP_400_BAD_REQUEST,
  86. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  87. )
  88. except Exception as e:
  89. print(e)
  90. raise HTTPException(
  91. status_code=status.HTTP_400_BAD_REQUEST,
  92. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  93. )
  94. else:
  95. raise HTTPException(
  96. status_code=status.HTTP_400_BAD_REQUEST,
  97. detail=ERROR_MESSAGES.ID_TAKEN,
  98. )
  99. ############################
  100. # GetToolsById
  101. ############################
  102. @router.get("/id/{id}", response_model=Optional[ToolModel])
  103. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  104. tools = Tools.get_tool_by_id(id)
  105. if tools:
  106. if (
  107. user.role == "admin"
  108. or tools.user_id == user.id
  109. or has_access(user.id, "read", tools.access_control)
  110. ):
  111. return tools
  112. else:
  113. raise HTTPException(
  114. status_code=status.HTTP_401_UNAUTHORIZED,
  115. detail=ERROR_MESSAGES.NOT_FOUND,
  116. )
  117. ############################
  118. # UpdateToolsById
  119. ############################
  120. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  121. async def update_tools_by_id(
  122. request: Request,
  123. id: str,
  124. form_data: ToolForm,
  125. user=Depends(get_verified_user),
  126. ):
  127. tools = Tools.get_tool_by_id(id)
  128. if not tools:
  129. raise HTTPException(
  130. status_code=status.HTTP_401_UNAUTHORIZED,
  131. detail=ERROR_MESSAGES.NOT_FOUND,
  132. )
  133. # Is the user the original creator, in a group with write access, or an admin
  134. if (
  135. tools.user_id != user.id
  136. and not has_access(user.id, "write", tools.access_control)
  137. and user.role != "admin"
  138. ):
  139. raise HTTPException(
  140. status_code=status.HTTP_401_UNAUTHORIZED,
  141. detail=ERROR_MESSAGES.UNAUTHORIZED,
  142. )
  143. try:
  144. form_data.content = replace_imports(form_data.content)
  145. tools_module, frontmatter = load_tools_module_by_id(
  146. id, content=form_data.content
  147. )
  148. form_data.meta.manifest = frontmatter
  149. TOOLS = request.app.state.TOOLS
  150. TOOLS[id] = tools_module
  151. specs = get_tools_specs(TOOLS[id])
  152. updated = {
  153. **form_data.model_dump(exclude={"id"}),
  154. "specs": specs,
  155. }
  156. print(updated)
  157. tools = Tools.update_tool_by_id(id, updated)
  158. if tools:
  159. return tools
  160. else:
  161. raise HTTPException(
  162. status_code=status.HTTP_400_BAD_REQUEST,
  163. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  164. )
  165. except Exception as e:
  166. raise HTTPException(
  167. status_code=status.HTTP_400_BAD_REQUEST,
  168. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  169. )
  170. ############################
  171. # DeleteToolsById
  172. ############################
  173. @router.delete("/id/{id}/delete", response_model=bool)
  174. async def delete_tools_by_id(
  175. request: Request, id: str, user=Depends(get_verified_user)
  176. ):
  177. tools = Tools.get_tool_by_id(id)
  178. if not tools:
  179. raise HTTPException(
  180. status_code=status.HTTP_401_UNAUTHORIZED,
  181. detail=ERROR_MESSAGES.NOT_FOUND,
  182. )
  183. if (
  184. tools.user_id != user.id
  185. and not has_access(user.id, "write", tools.access_control)
  186. and user.role != "admin"
  187. ):
  188. raise HTTPException(
  189. status_code=status.HTTP_401_UNAUTHORIZED,
  190. detail=ERROR_MESSAGES.UNAUTHORIZED,
  191. )
  192. result = Tools.delete_tool_by_id(id)
  193. if result:
  194. TOOLS = request.app.state.TOOLS
  195. if id in TOOLS:
  196. del TOOLS[id]
  197. return result
  198. ############################
  199. # GetToolValves
  200. ############################
  201. @router.get("/id/{id}/valves", response_model=Optional[dict])
  202. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  203. tools = Tools.get_tool_by_id(id)
  204. if tools:
  205. try:
  206. valves = Tools.get_tool_valves_by_id(id)
  207. return valves
  208. except Exception as e:
  209. raise HTTPException(
  210. status_code=status.HTTP_400_BAD_REQUEST,
  211. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  212. )
  213. else:
  214. raise HTTPException(
  215. status_code=status.HTTP_401_UNAUTHORIZED,
  216. detail=ERROR_MESSAGES.NOT_FOUND,
  217. )
  218. ############################
  219. # GetToolValvesSpec
  220. ############################
  221. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  222. async def get_tools_valves_spec_by_id(
  223. request: Request, id: str, user=Depends(get_verified_user)
  224. ):
  225. tools = Tools.get_tool_by_id(id)
  226. if tools:
  227. if id in request.app.state.TOOLS:
  228. tools_module = request.app.state.TOOLS[id]
  229. else:
  230. tools_module, _ = load_tools_module_by_id(id)
  231. request.app.state.TOOLS[id] = tools_module
  232. if hasattr(tools_module, "Valves"):
  233. Valves = tools_module.Valves
  234. return Valves.schema()
  235. return None
  236. else:
  237. raise HTTPException(
  238. status_code=status.HTTP_401_UNAUTHORIZED,
  239. detail=ERROR_MESSAGES.NOT_FOUND,
  240. )
  241. ############################
  242. # UpdateToolValves
  243. ############################
  244. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  245. async def update_tools_valves_by_id(
  246. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  247. ):
  248. tools = Tools.get_tool_by_id(id)
  249. if not tools:
  250. raise HTTPException(
  251. status_code=status.HTTP_401_UNAUTHORIZED,
  252. detail=ERROR_MESSAGES.NOT_FOUND,
  253. )
  254. if (
  255. tools.user_id != user.id
  256. and not has_access(user.id, "write", tools.access_control)
  257. and user.role != "admin"
  258. ):
  259. raise HTTPException(
  260. status_code=status.HTTP_400_BAD_REQUEST,
  261. detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
  262. )
  263. if id in request.app.state.TOOLS:
  264. tools_module = request.app.state.TOOLS[id]
  265. else:
  266. tools_module, _ = load_tools_module_by_id(id)
  267. request.app.state.TOOLS[id] = tools_module
  268. if not hasattr(tools_module, "Valves"):
  269. raise HTTPException(
  270. status_code=status.HTTP_401_UNAUTHORIZED,
  271. detail=ERROR_MESSAGES.NOT_FOUND,
  272. )
  273. Valves = tools_module.Valves
  274. try:
  275. form_data = {k: v for k, v in form_data.items() if v is not None}
  276. valves = Valves(**form_data)
  277. Tools.update_tool_valves_by_id(id, valves.model_dump())
  278. return valves.model_dump()
  279. except Exception as e:
  280. print(e)
  281. raise HTTPException(
  282. status_code=status.HTTP_400_BAD_REQUEST,
  283. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  284. )
  285. ############################
  286. # ToolUserValves
  287. ############################
  288. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  289. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  290. tools = Tools.get_tool_by_id(id)
  291. if tools:
  292. try:
  293. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  294. return user_valves
  295. except Exception as e:
  296. raise HTTPException(
  297. status_code=status.HTTP_400_BAD_REQUEST,
  298. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  299. )
  300. else:
  301. raise HTTPException(
  302. status_code=status.HTTP_401_UNAUTHORIZED,
  303. detail=ERROR_MESSAGES.NOT_FOUND,
  304. )
  305. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  306. async def get_tools_user_valves_spec_by_id(
  307. request: Request, id: str, user=Depends(get_verified_user)
  308. ):
  309. tools = Tools.get_tool_by_id(id)
  310. if tools:
  311. if id in request.app.state.TOOLS:
  312. tools_module = request.app.state.TOOLS[id]
  313. else:
  314. tools_module, _ = load_tools_module_by_id(id)
  315. request.app.state.TOOLS[id] = tools_module
  316. if hasattr(tools_module, "UserValves"):
  317. UserValves = tools_module.UserValves
  318. return UserValves.schema()
  319. return None
  320. else:
  321. raise HTTPException(
  322. status_code=status.HTTP_401_UNAUTHORIZED,
  323. detail=ERROR_MESSAGES.NOT_FOUND,
  324. )
  325. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  326. async def update_tools_user_valves_by_id(
  327. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  328. ):
  329. tools = Tools.get_tool_by_id(id)
  330. if tools:
  331. if id in request.app.state.TOOLS:
  332. tools_module = request.app.state.TOOLS[id]
  333. else:
  334. tools_module, _ = load_tools_module_by_id(id)
  335. request.app.state.TOOLS[id] = tools_module
  336. if hasattr(tools_module, "UserValves"):
  337. UserValves = tools_module.UserValves
  338. try:
  339. form_data = {k: v for k, v in form_data.items() if v is not None}
  340. user_valves = UserValves(**form_data)
  341. Tools.update_user_valves_by_id_and_user_id(
  342. id, user.id, user_valves.model_dump()
  343. )
  344. return user_valves.model_dump()
  345. except Exception as e:
  346. print(e)
  347. raise HTTPException(
  348. status_code=status.HTTP_400_BAD_REQUEST,
  349. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  350. )
  351. else:
  352. raise HTTPException(
  353. status_code=status.HTTP_401_UNAUTHORIZED,
  354. detail=ERROR_MESSAGES.NOT_FOUND,
  355. )
  356. else:
  357. raise HTTPException(
  358. status_code=status.HTTP_401_UNAUTHORIZED,
  359. detail=ERROR_MESSAGES.NOT_FOUND,
  360. )