tools.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. from pathlib import Path
  2. from typing import Optional
  3. from open_webui.models.tools import (
  4. ToolForm,
  5. ToolModel,
  6. ToolResponse,
  7. ToolUserResponse,
  8. Tools,
  9. )
  10. from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
  11. from open_webui.config import CACHE_DIR
  12. from open_webui.constants import ERROR_MESSAGES
  13. from fastapi import APIRouter, Depends, HTTPException, Request, status
  14. from open_webui.utils.tools import get_tools_specs
  15. from open_webui.utils.auth import get_admin_user, get_verified_user
  16. from open_webui.utils.access_control import has_access, has_permission
# APIRouter that collects every tool endpoint defined below (CRUD, tool-wide
# valves, and per-user valves). Where it is mounted / its URL prefix is set
# outside this file — confirm in the app setup.
router = APIRouter()
  18. ############################
  19. # GetTools
  20. ############################
  21. @router.get("/", response_model=list[ToolUserResponse])
  22. async def get_tools(user=Depends(get_verified_user)):
  23. if user.role == "admin":
  24. tools = Tools.get_tools()
  25. else:
  26. tools = Tools.get_tools_by_user_id(user.id, "read")
  27. return tools
  28. ############################
  29. # GetToolList
  30. ############################
  31. @router.get("/list", response_model=list[ToolUserResponse])
  32. async def get_tool_list(user=Depends(get_verified_user)):
  33. if user.role == "admin":
  34. tools = Tools.get_tools()
  35. else:
  36. tools = Tools.get_tools_by_user_id(user.id, "write")
  37. return tools
  38. ############################
  39. # ExportTools
  40. ############################
  41. @router.get("/export", response_model=list[ToolModel])
  42. async def export_tools(user=Depends(get_admin_user)):
  43. tools = Tools.get_tools()
  44. return tools
  45. ############################
  46. # CreateNewTools
  47. ############################
  48. @router.post("/create", response_model=Optional[ToolResponse])
  49. async def create_new_tools(
  50. request: Request,
  51. form_data: ToolForm,
  52. user=Depends(get_verified_user),
  53. ):
  54. if user.role != "admin" and not has_permission(
  55. user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
  56. ):
  57. raise HTTPException(
  58. status_code=status.HTTP_401_UNAUTHORIZED,
  59. detail=ERROR_MESSAGES.UNAUTHORIZED,
  60. )
  61. if not form_data.id.isidentifier():
  62. raise HTTPException(
  63. status_code=status.HTTP_400_BAD_REQUEST,
  64. detail="Only alphanumeric characters and underscores are allowed in the id",
  65. )
  66. form_data.id = form_data.id.lower()
  67. tools = Tools.get_tool_by_id(form_data.id)
  68. if tools is None:
  69. try:
  70. form_data.content = replace_imports(form_data.content)
  71. tools_module, frontmatter = load_tools_module_by_id(
  72. form_data.id, content=form_data.content
  73. )
  74. form_data.meta.manifest = frontmatter
  75. TOOLS = request.app.state.TOOLS
  76. TOOLS[form_data.id] = tools_module
  77. specs = get_tools_specs(TOOLS[form_data.id])
  78. tools = Tools.insert_new_tool(user.id, form_data, specs)
  79. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  80. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  81. if tools:
  82. return tools
  83. else:
  84. raise HTTPException(
  85. status_code=status.HTTP_400_BAD_REQUEST,
  86. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  87. )
  88. except Exception as e:
  89. print(e)
  90. raise HTTPException(
  91. status_code=status.HTTP_400_BAD_REQUEST,
  92. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  93. )
  94. else:
  95. raise HTTPException(
  96. status_code=status.HTTP_400_BAD_REQUEST,
  97. detail=ERROR_MESSAGES.ID_TAKEN,
  98. )
  99. ############################
  100. # GetToolsById
  101. ############################
  102. @router.get("/id/{id}", response_model=Optional[ToolModel])
  103. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  104. tools = Tools.get_tool_by_id(id)
  105. if tools:
  106. if (
  107. user.role == "admin"
  108. or tools.user_id == user.id
  109. or has_access(user.id, "read", tools.access_control)
  110. ):
  111. return tools
  112. else:
  113. raise HTTPException(
  114. status_code=status.HTTP_401_UNAUTHORIZED,
  115. detail=ERROR_MESSAGES.NOT_FOUND,
  116. )
  117. ############################
  118. # UpdateToolsById
  119. ############################
  120. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  121. async def update_tools_by_id(
  122. request: Request,
  123. id: str,
  124. form_data: ToolForm,
  125. user=Depends(get_verified_user),
  126. ):
  127. tools = Tools.get_tool_by_id(id)
  128. if not tools:
  129. raise HTTPException(
  130. status_code=status.HTTP_401_UNAUTHORIZED,
  131. detail=ERROR_MESSAGES.NOT_FOUND,
  132. )
  133. if tools.user_id != user.id and user.role != "admin":
  134. raise HTTPException(
  135. status_code=status.HTTP_401_UNAUTHORIZED,
  136. detail=ERROR_MESSAGES.UNAUTHORIZED,
  137. )
  138. try:
  139. form_data.content = replace_imports(form_data.content)
  140. tools_module, frontmatter = load_tools_module_by_id(
  141. id, content=form_data.content
  142. )
  143. form_data.meta.manifest = frontmatter
  144. TOOLS = request.app.state.TOOLS
  145. TOOLS[id] = tools_module
  146. specs = get_tools_specs(TOOLS[id])
  147. updated = {
  148. **form_data.model_dump(exclude={"id"}),
  149. "specs": specs,
  150. }
  151. print(updated)
  152. tools = Tools.update_tool_by_id(id, updated)
  153. if tools:
  154. return tools
  155. else:
  156. raise HTTPException(
  157. status_code=status.HTTP_400_BAD_REQUEST,
  158. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  159. )
  160. except Exception as e:
  161. raise HTTPException(
  162. status_code=status.HTTP_400_BAD_REQUEST,
  163. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  164. )
  165. ############################
  166. # DeleteToolsById
  167. ############################
  168. @router.delete("/id/{id}/delete", response_model=bool)
  169. async def delete_tools_by_id(
  170. request: Request, id: str, user=Depends(get_verified_user)
  171. ):
  172. tools = Tools.get_tool_by_id(id)
  173. if not tools:
  174. raise HTTPException(
  175. status_code=status.HTTP_401_UNAUTHORIZED,
  176. detail=ERROR_MESSAGES.NOT_FOUND,
  177. )
  178. if tools.user_id != user.id and user.role != "admin":
  179. raise HTTPException(
  180. status_code=status.HTTP_401_UNAUTHORIZED,
  181. detail=ERROR_MESSAGES.UNAUTHORIZED,
  182. )
  183. result = Tools.delete_tool_by_id(id)
  184. if result:
  185. TOOLS = request.app.state.TOOLS
  186. if id in TOOLS:
  187. del TOOLS[id]
  188. return result
  189. ############################
  190. # GetToolValves
  191. ############################
  192. @router.get("/id/{id}/valves", response_model=Optional[dict])
  193. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  194. tools = Tools.get_tool_by_id(id)
  195. if tools:
  196. try:
  197. valves = Tools.get_tool_valves_by_id(id)
  198. return valves
  199. except Exception as e:
  200. raise HTTPException(
  201. status_code=status.HTTP_400_BAD_REQUEST,
  202. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  203. )
  204. else:
  205. raise HTTPException(
  206. status_code=status.HTTP_401_UNAUTHORIZED,
  207. detail=ERROR_MESSAGES.NOT_FOUND,
  208. )
  209. ############################
  210. # GetToolValvesSpec
  211. ############################
  212. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  213. async def get_tools_valves_spec_by_id(
  214. request: Request, id: str, user=Depends(get_verified_user)
  215. ):
  216. tools = Tools.get_tool_by_id(id)
  217. if tools:
  218. if id in request.app.state.TOOLS:
  219. tools_module = request.app.state.TOOLS[id]
  220. else:
  221. tools_module, _ = load_tools_module_by_id(id)
  222. request.app.state.TOOLS[id] = tools_module
  223. if hasattr(tools_module, "Valves"):
  224. Valves = tools_module.Valves
  225. return Valves.schema()
  226. return None
  227. else:
  228. raise HTTPException(
  229. status_code=status.HTTP_401_UNAUTHORIZED,
  230. detail=ERROR_MESSAGES.NOT_FOUND,
  231. )
  232. ############################
  233. # UpdateToolValves
  234. ############################
  235. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  236. async def update_tools_valves_by_id(
  237. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  238. ):
  239. tools = Tools.get_tool_by_id(id)
  240. if not tools:
  241. raise HTTPException(
  242. status_code=status.HTTP_401_UNAUTHORIZED,
  243. detail=ERROR_MESSAGES.NOT_FOUND,
  244. )
  245. if id in request.app.state.TOOLS:
  246. tools_module = request.app.state.TOOLS[id]
  247. else:
  248. tools_module, _ = load_tools_module_by_id(id)
  249. request.app.state.TOOLS[id] = tools_module
  250. if not hasattr(tools_module, "Valves"):
  251. raise HTTPException(
  252. status_code=status.HTTP_401_UNAUTHORIZED,
  253. detail=ERROR_MESSAGES.NOT_FOUND,
  254. )
  255. Valves = tools_module.Valves
  256. try:
  257. form_data = {k: v for k, v in form_data.items() if v is not None}
  258. valves = Valves(**form_data)
  259. Tools.update_tool_valves_by_id(id, valves.model_dump())
  260. return valves.model_dump()
  261. except Exception as e:
  262. print(e)
  263. raise HTTPException(
  264. status_code=status.HTTP_400_BAD_REQUEST,
  265. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  266. )
  267. ############################
  268. # ToolUserValves
  269. ############################
  270. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  271. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  272. tools = Tools.get_tool_by_id(id)
  273. if tools:
  274. try:
  275. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  276. return user_valves
  277. except Exception as e:
  278. raise HTTPException(
  279. status_code=status.HTTP_400_BAD_REQUEST,
  280. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  281. )
  282. else:
  283. raise HTTPException(
  284. status_code=status.HTTP_401_UNAUTHORIZED,
  285. detail=ERROR_MESSAGES.NOT_FOUND,
  286. )
  287. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  288. async def get_tools_user_valves_spec_by_id(
  289. request: Request, id: str, user=Depends(get_verified_user)
  290. ):
  291. tools = Tools.get_tool_by_id(id)
  292. if tools:
  293. if id in request.app.state.TOOLS:
  294. tools_module = request.app.state.TOOLS[id]
  295. else:
  296. tools_module, _ = load_tools_module_by_id(id)
  297. request.app.state.TOOLS[id] = tools_module
  298. if hasattr(tools_module, "UserValves"):
  299. UserValves = tools_module.UserValves
  300. return UserValves.schema()
  301. return None
  302. else:
  303. raise HTTPException(
  304. status_code=status.HTTP_401_UNAUTHORIZED,
  305. detail=ERROR_MESSAGES.NOT_FOUND,
  306. )
  307. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  308. async def update_tools_user_valves_by_id(
  309. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  310. ):
  311. tools = Tools.get_tool_by_id(id)
  312. if tools:
  313. if id in request.app.state.TOOLS:
  314. tools_module = request.app.state.TOOLS[id]
  315. else:
  316. tools_module, _ = load_tools_module_by_id(id)
  317. request.app.state.TOOLS[id] = tools_module
  318. if hasattr(tools_module, "UserValves"):
  319. UserValves = tools_module.UserValves
  320. try:
  321. form_data = {k: v for k, v in form_data.items() if v is not None}
  322. user_valves = UserValves(**form_data)
  323. Tools.update_user_valves_by_id_and_user_id(
  324. id, user.id, user_valves.model_dump()
  325. )
  326. return user_valves.model_dump()
  327. except Exception as e:
  328. print(e)
  329. raise HTTPException(
  330. status_code=status.HTTP_400_BAD_REQUEST,
  331. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  332. )
  333. else:
  334. raise HTTPException(
  335. status_code=status.HTTP_401_UNAUTHORIZED,
  336. detail=ERROR_MESSAGES.NOT_FOUND,
  337. )
  338. else:
  339. raise HTTPException(
  340. status_code=status.HTTP_401_UNAUTHORIZED,
  341. detail=ERROR_MESSAGES.NOT_FOUND,
  342. )