tools.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. from pathlib import Path
  2. from typing import Optional
  3. from open_webui.models.tools import (
  4. ToolForm,
  5. ToolModel,
  6. ToolResponse,
  7. ToolUserResponse,
  8. Tools,
  9. )
  10. from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
  11. from open_webui.config import CACHE_DIR
  12. from open_webui.constants import ERROR_MESSAGES
  13. from fastapi import APIRouter, Depends, HTTPException, Request, status
  14. from open_webui.utils.tools import get_tools_specs
  15. from open_webui.utils.auth import get_admin_user, get_verified_user
  16. from open_webui.utils.access_control import has_access, has_permission
  # Router holding the tool-management endpoints defined below; presumably
  # included into the main FastAPI application elsewhere (not visible here).
  router = APIRouter()
  ############################
  # GetTools
  ############################
  @router.get("/", response_model=list[ToolUserResponse])
  async def get_tools(user=Depends(get_verified_user)):
      """List the tools visible to the caller.

      Admins receive every stored tool; any other user receives the result of
      ``Tools.get_tools_by_user_id`` for the ``"read"`` permission.
      """
      is_admin = user.role == "admin"
      return (
          Tools.get_tools()
          if is_admin
          else Tools.get_tools_by_user_id(user.id, "read")
      )
  ############################
  # GetToolList
  ############################
  @router.get("/list", response_model=list[ToolUserResponse])
  async def get_tool_list(user=Depends(get_verified_user)):
      """List the tools the caller may edit.

      Admins receive every stored tool; any other user receives the result of
      ``Tools.get_tools_by_user_id`` for the ``"write"`` permission.
      """
      if user.role != "admin":
          return Tools.get_tools_by_user_id(user.id, "write")
      return Tools.get_tools()
  ############################
  # ExportTools
  ############################
  @router.get("/export", response_model=list[ToolModel])
  async def export_tools(user=Depends(get_admin_user)):
      """Return every stored tool for export.

      Access control is delegated to the ``get_admin_user`` dependency
      (presumably rejecting non-admins); ``user`` is not otherwise used.
      """
      return Tools.get_tools()
  ############################
  # CreateNewTools
  ############################
  @router.post("/create", response_model=Optional[ToolResponse])
  async def create_new_tools(
      request: Request,
      form_data: ToolForm,
      user=Depends(get_verified_user),
  ):
      """Create a tool from user-submitted source code.

      Non-admin callers must hold the ``workspace.tools`` permission. The id
      must satisfy ``str.isidentifier()`` and is lower-cased before use. The
      content is rewritten by ``replace_imports``, loaded with
      ``load_tools_module_by_id``, registered in ``request.app.state.TOOLS``,
      and stored together with the specs from ``get_tools_specs``. A cache
      directory ``CACHE_DIR/tools/<id>`` is created for the tool.

      Returns:
          The record returned by ``Tools.insert_new_tool``.

      Raises:
          HTTPException: 401 if the caller lacks permission; 400 if the id is
              invalid or already taken, if loading/inserting fails, or if the
              insert returns nothing.
      """
      # Tool creation is gated by the *tools* workspace permission; the
      # previous "workspace.knowledge" key was a copy-paste slip.
      if user.role != "admin" and not has_permission(
          user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      # The id is handed to the module loader and used as a directory name,
      # so only identifier-shaped ids are accepted.
      if not form_data.id.isidentifier():
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Only alphanumeric characters and underscores are allowed in the id",
          )
      form_data.id = form_data.id.lower()

      if Tools.get_tool_by_id(form_data.id) is not None:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.ID_TAKEN,
          )

      try:
          form_data.content = replace_imports(form_data.content)
          tools_module, frontmatter = load_tools_module_by_id(
              form_data.id, content=form_data.content
          )
          form_data.meta.manifest = frontmatter

          TOOLS = request.app.state.TOOLS
          TOOLS[form_data.id] = tools_module
          specs = get_tools_specs(TOOLS[form_data.id])
          tools = Tools.insert_new_tool(user.id, form_data, specs)

          tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
          tool_cache_dir.mkdir(parents=True, exist_ok=True)
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e

      # Checked outside the try so this message is not caught and re-wrapped
      # by the generic handler above.
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
          )
      return tools
  ############################
  # GetToolsById
  ############################
  @router.get("/id/{id}", response_model=Optional[ToolModel])
  async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
      """Return a single tool if the caller may read it.

      Admins, the tool's creator, and users granted ``"read"`` access via the
      tool's ``access_control`` may read it.

      Raises:
          HTTPException: 401 with ``NOT_FOUND`` when no tool has this id;
              401 with ``UNAUTHORIZED`` when the caller may not read it.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if (
          user.role == "admin"
          or tools.user_id == user.id
          or has_access(user.id, "read", tools.access_control)
      ):
          return tools

      # Previously this path fell through and returned None (HTTP 200 null).
      raise HTTPException(
          status_code=status.HTTP_401_UNAUTHORIZED,
          detail=ERROR_MESSAGES.UNAUTHORIZED,
      )
  ############################
  # UpdateToolsById
  ############################
  @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  async def update_tools_by_id(
      request: Request,
      id: str,
      form_data: ToolForm,
      user=Depends(get_verified_user),
  ):
      """Replace a tool's source code and metadata.

      The new content is rewritten by ``replace_imports``, reloaded with
      ``load_tools_module_by_id``, re-registered in ``request.app.state.TOOLS``,
      and persisted with freshly computed specs. The path ``id`` identifies the
      tool; the form's own ``id`` field is excluded from the update payload.

      Returns:
          The record returned by ``Tools.update_tool_by_id``.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist;
              401 ``UNAUTHORIZED`` if the caller may not write it; 400 if
              loading/updating fails or the update returns nothing.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Is the user the original creator, in a group with write access, or an admin
      if (
          tools.user_id != user.id
          and not has_access(user.id, "write", tools.access_control)
          and user.role != "admin"
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      try:
          form_data.content = replace_imports(form_data.content)
          tools_module, frontmatter = load_tools_module_by_id(
              id, content=form_data.content
          )
          form_data.meta.manifest = frontmatter

          TOOLS = request.app.state.TOOLS
          TOOLS[id] = tools_module
          specs = get_tools_specs(TOOLS[id])

          updated = {
              **form_data.model_dump(exclude={"id"}),
              "specs": specs,
          }
          # (A debug print of the whole payload, including source code, was
          # removed here.)
          tools = Tools.update_tool_by_id(id, updated)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e

      # Checked outside the try so this message is not re-wrapped above.
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
          )
      return tools
  ############################
  # DeleteToolsById
  ############################
  @router.delete("/id/{id}/delete", response_model=bool)
  async def delete_tools_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Delete a tool record and evict its loaded module from the app cache.

      Only the tool's creator or an admin may delete it.

      Returns:
          Whatever ``Tools.delete_tool_by_id`` returns.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist;
              401 ``UNAUTHORIZED`` if the caller is neither creator nor admin.
      """
      existing = Tools.get_tool_by_id(id)
      if not existing:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      may_delete = existing.user_id == user.id or user.role == "admin"
      if not may_delete:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      deleted = Tools.delete_tool_by_id(id)
      if deleted:
          # Drop the in-memory module so a stale copy is not reused.
          loaded_modules = request.app.state.TOOLS
          if id in loaded_modules:
              del loaded_modules[id]
      return deleted
  ############################
  # GetToolValves
  ############################
  @router.get("/id/{id}/valves", response_model=Optional[dict])
  async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
      """Return the stored tool-wide valve values for a tool.

      These valves are stored per tool rather than per user, and may hold
      shared configuration (possibly credentials; assumption, not visible
      here). Reading them is therefore limited to the tool's creator, users
      with ``"write"`` access, and admins. Previously any verified user could
      read them.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist;
              401 ``UNAUTHORIZED`` if the caller may not write the tool;
              400 if loading the valves fails.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if (
          tools.user_id != user.id
          and not has_access(user.id, "write", tools.access_control)
          and user.role != "admin"
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      try:
          return Tools.get_tool_valves_by_id(id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e
  ############################
  # GetToolValvesSpec
  ############################
  @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  async def get_tools_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Return the JSON schema of the tool module's ``Valves`` model.

      The module is taken from ``request.app.state.TOOLS``. If it is missing,
      it is loaded via ``load_tools_module_by_id`` and cached there. Returns
      ``None`` when the module defines no ``Valves`` attribute.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist.
      """
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if id in request.app.state.TOOLS:
          tools_module = request.app.state.TOOLS[id]
      else:
          tools_module, _ = load_tools_module_by_id(id)
          request.app.state.TOOLS[id] = tools_module

      if not hasattr(tools_module, "Valves"):
          return None
      # `.schema()` is deprecated in Pydantic v2 (this file already relies on
      # the v2 `model_dump` API for these models); `model_json_schema()`
      # is the supported replacement.
      return tools_module.Valves.model_json_schema()
  ############################
  # UpdateToolValves
  ############################
  @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  async def update_tools_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      """Validate and store the tool-wide valve values for a tool.

      Entries of ``form_data`` whose value is ``None`` are dropped before the
      rest are validated by the tool module's ``Valves`` model. The valves are
      persisted per tool (not per user), so the caller needs the same rights as
      for editing the tool: creator, ``"write"`` access, or admin.

      Returns:
          The validated valves as a dict.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist or its
              module has no ``Valves``; 401 ``UNAUTHORIZED`` if the caller may
              not write the tool; 400 if validation or persistence fails.
      """
      tools = Tools.get_tool_by_id(id)
      if not tools:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      # Previously any verified user could overwrite these shared valves.
      if (
          tools.user_id != user.id
          and not has_access(user.id, "write", tools.access_control)
          and user.role != "admin"
      ):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.UNAUTHORIZED,
          )

      if id in request.app.state.TOOLS:
          tools_module = request.app.state.TOOLS[id]
      else:
          tools_module, _ = load_tools_module_by_id(id)
          request.app.state.TOOLS[id] = tools_module

      if not hasattr(tools_module, "Valves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      Valves = tools_module.Valves

      try:
          form_data = {k: v for k, v in form_data.items() if v is not None}
          valves = Valves(**form_data)
          Tools.update_tool_valves_by_id(id, valves.model_dump())
          return valves.model_dump()
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          ) from e
  ############################
  # ToolUserValves
  ############################
  @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
      """Return the calling user's own valve values for a tool.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist;
              400 if ``Tools.get_user_valves_by_id_and_user_id`` raises.
      """
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      try:
          return Tools.get_user_valves_by_id_and_user_id(id, user.id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )
  @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  async def get_tools_user_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Return the JSON schema of the tool module's ``UserValves`` model.

      The module is taken from ``request.app.state.TOOLS``. If it is missing,
      it is loaded via ``load_tools_module_by_id`` and cached there. Returns
      ``None`` when the module defines no ``UserValves`` attribute.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist.
      """
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      if id in request.app.state.TOOLS:
          tools_module = request.app.state.TOOLS[id]
      else:
          tools_module, _ = load_tools_module_by_id(id)
          request.app.state.TOOLS[id] = tools_module

      if not hasattr(tools_module, "UserValves"):
          return None
      # `.schema()` is deprecated in Pydantic v2 (this file already relies on
      # the v2 `model_dump` API for these models); `model_json_schema()`
      # is the supported replacement.
      return tools_module.UserValves.model_json_schema()
  @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  async def update_tools_user_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      """Validate and store the calling user's own valves for a tool.

      Entries of ``form_data`` whose value is ``None`` are dropped. The rest are
      validated by the tool module's ``UserValves`` model and persisted for
      ``user.id``.

      Returns:
          The validated user valves as a dict.

      Raises:
          HTTPException: 401 ``NOT_FOUND`` if the tool does not exist or its
              module has no ``UserValves``; 400 if validation or persistence
              fails.
      """
      if not Tools.get_tool_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      registry = request.app.state.TOOLS
      if id in registry:
          tools_module = registry[id]
      else:
          tools_module, _ = load_tools_module_by_id(id)
          registry[id] = tools_module

      if not hasattr(tools_module, "UserValves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      user_valves_model = tools_module.UserValves

      try:
          provided = {
              key: value for key, value in form_data.items() if value is not None
          }
          user_valves = user_valves_model(**provided)
          Tools.update_user_valves_by_id_and_user_id(
              id, user.id, user_valves.model_dump()
          )
          return user_valves.model_dump()
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(str(e)),
          )