tools.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
import logging
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status

from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.models.tools import (
    ToolForm,
    ToolModel,
    ToolResponse,
    ToolUserResponse,
    Tools,
)
from open_webui.utils.access_control import has_access, has_permission
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
from open_webui.utils.tools import get_tools_specs
  17. router = APIRouter()
  18. ############################
  19. # GetTools
  20. ############################
  21. @router.get("/", response_model=list[ToolUserResponse])
  22. async def get_tools(user=Depends(get_verified_user)):
  23. if user.role == "admin":
  24. tools = Tools.get_tools()
  25. else:
  26. tools = Tools.get_tools_by_user_id(user.id, "read")
  27. return tools
  28. ############################
  29. # GetToolList
  30. ############################
  31. @router.get("/list", response_model=list[ToolUserResponse])
  32. async def get_tool_list(user=Depends(get_verified_user)):
  33. if user.role == "admin":
  34. tools = Tools.get_tools()
  35. else:
  36. tools = Tools.get_tools_by_user_id(user.id, "write")
  37. return tools
  38. ############################
  39. # ExportTools
  40. ############################
  41. @router.get("/export", response_model=list[ToolModel])
  42. async def export_tools(user=Depends(get_admin_user)):
  43. tools = Tools.get_tools()
  44. return tools
  45. ############################
  46. # CreateNewTools
  47. ############################
  48. @router.post("/create", response_model=Optional[ToolResponse])
  49. async def create_new_tools(
  50. request: Request,
  51. form_data: ToolForm,
  52. user=Depends(get_verified_user),
  53. ):
  54. if user.role != "admin" and not has_permission(
  55. user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
  56. ):
  57. raise HTTPException(
  58. status_code=status.HTTP_401_UNAUTHORIZED,
  59. detail=ERROR_MESSAGES.UNAUTHORIZED,
  60. )
  61. if not form_data.id.isidentifier():
  62. raise HTTPException(
  63. status_code=status.HTTP_400_BAD_REQUEST,
  64. detail="Only alphanumeric characters and underscores are allowed in the id",
  65. )
  66. form_data.id = form_data.id.lower()
  67. tools = Tools.get_tool_by_id(form_data.id)
  68. if tools is None:
  69. try:
  70. form_data.content = replace_imports(form_data.content)
  71. tools_module, frontmatter = load_tools_module_by_id(
  72. form_data.id, content=form_data.content
  73. )
  74. form_data.meta.manifest = frontmatter
  75. TOOLS = request.app.state.TOOLS
  76. TOOLS[form_data.id] = tools_module
  77. specs = get_tools_specs(TOOLS[form_data.id])
  78. tools = Tools.insert_new_tool(user.id, form_data, specs)
  79. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  80. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  81. if tools:
  82. return tools
  83. else:
  84. raise HTTPException(
  85. status_code=status.HTTP_400_BAD_REQUEST,
  86. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  87. )
  88. except Exception as e:
  89. print(e)
  90. raise HTTPException(
  91. status_code=status.HTTP_400_BAD_REQUEST,
  92. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  93. )
  94. else:
  95. raise HTTPException(
  96. status_code=status.HTTP_400_BAD_REQUEST,
  97. detail=ERROR_MESSAGES.ID_TAKEN,
  98. )
  99. ############################
  100. # GetToolsById
  101. ############################
  102. @router.get("/id/{id}", response_model=Optional[ToolModel])
  103. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  104. tools = Tools.get_tool_by_id(id)
  105. if tools:
  106. if (
  107. user.role == "admin"
  108. or tools.user_id == user.id
  109. or has_access(user.id, "read", tools.access_control)
  110. ):
  111. return tools
  112. else:
  113. raise HTTPException(
  114. status_code=status.HTTP_401_UNAUTHORIZED,
  115. detail=ERROR_MESSAGES.NOT_FOUND,
  116. )
  117. ############################
  118. # UpdateToolsById
  119. ############################
  120. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  121. async def update_tools_by_id(
  122. request: Request,
  123. id: str,
  124. form_data: ToolForm,
  125. user=Depends(get_verified_user),
  126. ):
  127. tools = Tools.get_tool_by_id(id)
  128. if not tools:
  129. raise HTTPException(
  130. status_code=status.HTTP_401_UNAUTHORIZED,
  131. detail=ERROR_MESSAGES.NOT_FOUND,
  132. )
  133. # Is the user the original creator, in a group with write access, or an admin
  134. if tools.user_id != user.id and not has_access(user.id, "write", tools.access_control) and user.role != "admin":
  135. raise HTTPException(
  136. status_code=status.HTTP_401_UNAUTHORIZED,
  137. detail=ERROR_MESSAGES.UNAUTHORIZED,
  138. )
  139. try:
  140. form_data.content = replace_imports(form_data.content)
  141. tools_module, frontmatter = load_tools_module_by_id(
  142. id, content=form_data.content
  143. )
  144. form_data.meta.manifest = frontmatter
  145. TOOLS = request.app.state.TOOLS
  146. TOOLS[id] = tools_module
  147. specs = get_tools_specs(TOOLS[id])
  148. updated = {
  149. **form_data.model_dump(exclude={"id"}),
  150. "specs": specs,
  151. }
  152. print(updated)
  153. tools = Tools.update_tool_by_id(id, updated)
  154. if tools:
  155. return tools
  156. else:
  157. raise HTTPException(
  158. status_code=status.HTTP_400_BAD_REQUEST,
  159. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  160. )
  161. except Exception as e:
  162. raise HTTPException(
  163. status_code=status.HTTP_400_BAD_REQUEST,
  164. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  165. )
  166. ############################
  167. # DeleteToolsById
  168. ############################
  169. @router.delete("/id/{id}/delete", response_model=bool)
  170. async def delete_tools_by_id(
  171. request: Request, id: str, user=Depends(get_verified_user)
  172. ):
  173. tools = Tools.get_tool_by_id(id)
  174. if not tools:
  175. raise HTTPException(
  176. status_code=status.HTTP_401_UNAUTHORIZED,
  177. detail=ERROR_MESSAGES.NOT_FOUND,
  178. )
  179. if tools.user_id != user.id and user.role != "admin":
  180. raise HTTPException(
  181. status_code=status.HTTP_401_UNAUTHORIZED,
  182. detail=ERROR_MESSAGES.UNAUTHORIZED,
  183. )
  184. result = Tools.delete_tool_by_id(id)
  185. if result:
  186. TOOLS = request.app.state.TOOLS
  187. if id in TOOLS:
  188. del TOOLS[id]
  189. return result
  190. ############################
  191. # GetToolValves
  192. ############################
  193. @router.get("/id/{id}/valves", response_model=Optional[dict])
  194. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  195. tools = Tools.get_tool_by_id(id)
  196. if tools:
  197. try:
  198. valves = Tools.get_tool_valves_by_id(id)
  199. return valves
  200. except Exception as e:
  201. raise HTTPException(
  202. status_code=status.HTTP_400_BAD_REQUEST,
  203. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  204. )
  205. else:
  206. raise HTTPException(
  207. status_code=status.HTTP_401_UNAUTHORIZED,
  208. detail=ERROR_MESSAGES.NOT_FOUND,
  209. )
  210. ############################
  211. # GetToolValvesSpec
  212. ############################
  213. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  214. async def get_tools_valves_spec_by_id(
  215. request: Request, id: str, user=Depends(get_verified_user)
  216. ):
  217. tools = Tools.get_tool_by_id(id)
  218. if tools:
  219. if id in request.app.state.TOOLS:
  220. tools_module = request.app.state.TOOLS[id]
  221. else:
  222. tools_module, _ = load_tools_module_by_id(id)
  223. request.app.state.TOOLS[id] = tools_module
  224. if hasattr(tools_module, "Valves"):
  225. Valves = tools_module.Valves
  226. return Valves.schema()
  227. return None
  228. else:
  229. raise HTTPException(
  230. status_code=status.HTTP_401_UNAUTHORIZED,
  231. detail=ERROR_MESSAGES.NOT_FOUND,
  232. )
  233. ############################
  234. # UpdateToolValves
  235. ############################
  236. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  237. async def update_tools_valves_by_id(
  238. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  239. ):
  240. tools = Tools.get_tool_by_id(id)
  241. if not tools:
  242. raise HTTPException(
  243. status_code=status.HTTP_401_UNAUTHORIZED,
  244. detail=ERROR_MESSAGES.NOT_FOUND,
  245. )
  246. if id in request.app.state.TOOLS:
  247. tools_module = request.app.state.TOOLS[id]
  248. else:
  249. tools_module, _ = load_tools_module_by_id(id)
  250. request.app.state.TOOLS[id] = tools_module
  251. if not hasattr(tools_module, "Valves"):
  252. raise HTTPException(
  253. status_code=status.HTTP_401_UNAUTHORIZED,
  254. detail=ERROR_MESSAGES.NOT_FOUND,
  255. )
  256. Valves = tools_module.Valves
  257. try:
  258. form_data = {k: v for k, v in form_data.items() if v is not None}
  259. valves = Valves(**form_data)
  260. Tools.update_tool_valves_by_id(id, valves.model_dump())
  261. return valves.model_dump()
  262. except Exception as e:
  263. print(e)
  264. raise HTTPException(
  265. status_code=status.HTTP_400_BAD_REQUEST,
  266. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  267. )
  268. ############################
  269. # ToolUserValves
  270. ############################
  271. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  272. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  273. tools = Tools.get_tool_by_id(id)
  274. if tools:
  275. try:
  276. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  277. return user_valves
  278. except Exception as e:
  279. raise HTTPException(
  280. status_code=status.HTTP_400_BAD_REQUEST,
  281. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  282. )
  283. else:
  284. raise HTTPException(
  285. status_code=status.HTTP_401_UNAUTHORIZED,
  286. detail=ERROR_MESSAGES.NOT_FOUND,
  287. )
  288. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  289. async def get_tools_user_valves_spec_by_id(
  290. request: Request, id: str, user=Depends(get_verified_user)
  291. ):
  292. tools = Tools.get_tool_by_id(id)
  293. if tools:
  294. if id in request.app.state.TOOLS:
  295. tools_module = request.app.state.TOOLS[id]
  296. else:
  297. tools_module, _ = load_tools_module_by_id(id)
  298. request.app.state.TOOLS[id] = tools_module
  299. if hasattr(tools_module, "UserValves"):
  300. UserValves = tools_module.UserValves
  301. return UserValves.schema()
  302. return None
  303. else:
  304. raise HTTPException(
  305. status_code=status.HTTP_401_UNAUTHORIZED,
  306. detail=ERROR_MESSAGES.NOT_FOUND,
  307. )
  308. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  309. async def update_tools_user_valves_by_id(
  310. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  311. ):
  312. tools = Tools.get_tool_by_id(id)
  313. if tools:
  314. if id in request.app.state.TOOLS:
  315. tools_module = request.app.state.TOOLS[id]
  316. else:
  317. tools_module, _ = load_tools_module_by_id(id)
  318. request.app.state.TOOLS[id] = tools_module
  319. if hasattr(tools_module, "UserValves"):
  320. UserValves = tools_module.UserValves
  321. try:
  322. form_data = {k: v for k, v in form_data.items() if v is not None}
  323. user_valves = UserValves(**form_data)
  324. Tools.update_user_valves_by_id_and_user_id(
  325. id, user.id, user_valves.model_dump()
  326. )
  327. return user_valves.model_dump()
  328. except Exception as e:
  329. print(e)
  330. raise HTTPException(
  331. status_code=status.HTTP_400_BAD_REQUEST,
  332. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  333. )
  334. else:
  335. raise HTTPException(
  336. status_code=status.HTTP_401_UNAUTHORIZED,
  337. detail=ERROR_MESSAGES.NOT_FOUND,
  338. )
  339. else:
  340. raise HTTPException(
  341. status_code=status.HTTP_401_UNAUTHORIZED,
  342. detail=ERROR_MESSAGES.NOT_FOUND,
  343. )