# tools.py
import logging
import os
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status

from open_webui.apps.webui.models.tools import (
    ToolForm,
    ToolModel,
    ToolResponse,
    ToolUserResponse,
    Tools,
)
from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, DATA_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.access_control import has_access, has_permission
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.utils import get_admin_user, get_verified_user
  18. router = APIRouter()
  19. ############################
  20. # GetTools
  21. ############################
  22. @router.get("/", response_model=list[ToolUserResponse])
  23. async def get_tools(user=Depends(get_verified_user)):
  24. if user.role == "admin":
  25. tools = Tools.get_tools()
  26. else:
  27. tools = Tools.get_tools_by_user_id(user.id, "read")
  28. return tools
  29. ############################
  30. # GetToolList
  31. ############################
  32. @router.get("/list", response_model=list[ToolUserResponse])
  33. async def get_tool_list(user=Depends(get_verified_user)):
  34. if user.role == "admin":
  35. tools = Tools.get_tools()
  36. else:
  37. tools = Tools.get_tools_by_user_id(user.id, "write")
  38. return tools
  39. ############################
  40. # ExportTools
  41. ############################
  42. @router.get("/export", response_model=list[ToolModel])
  43. async def export_tools(user=Depends(get_admin_user)):
  44. tools = Tools.get_tools()
  45. return tools
  46. ############################
  47. # CreateNewTools
  48. ############################
  49. @router.post("/create", response_model=Optional[ToolResponse])
  50. async def create_new_tools(
  51. request: Request,
  52. form_data: ToolForm,
  53. user=Depends(get_verified_user),
  54. ):
  55. if user.role != "admin" and not has_permission(
  56. user.id, "workspace.knowledge", request.app.state.config.USER_PERMISSIONS
  57. ):
  58. raise HTTPException(
  59. status_code=status.HTTP_401_UNAUTHORIZED,
  60. detail=ERROR_MESSAGES.UNAUTHORIZED,
  61. )
  62. if not form_data.id.isidentifier():
  63. raise HTTPException(
  64. status_code=status.HTTP_400_BAD_REQUEST,
  65. detail="Only alphanumeric characters and underscores are allowed in the id",
  66. )
  67. form_data.id = form_data.id.lower()
  68. tools = Tools.get_tool_by_id(form_data.id)
  69. if tools is None:
  70. try:
  71. form_data.content = replace_imports(form_data.content)
  72. tools_module, frontmatter = load_tools_module_by_id(
  73. form_data.id, content=form_data.content
  74. )
  75. form_data.meta.manifest = frontmatter
  76. TOOLS = request.app.state.TOOLS
  77. TOOLS[form_data.id] = tools_module
  78. specs = get_tools_specs(TOOLS[form_data.id])
  79. tools = Tools.insert_new_tool(user.id, form_data, specs)
  80. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  81. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  82. if tools:
  83. return tools
  84. else:
  85. raise HTTPException(
  86. status_code=status.HTTP_400_BAD_REQUEST,
  87. detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
  88. )
  89. except Exception as e:
  90. print(e)
  91. raise HTTPException(
  92. status_code=status.HTTP_400_BAD_REQUEST,
  93. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  94. )
  95. else:
  96. raise HTTPException(
  97. status_code=status.HTTP_400_BAD_REQUEST,
  98. detail=ERROR_MESSAGES.ID_TAKEN,
  99. )
  100. ############################
  101. # GetToolsById
  102. ############################
  103. @router.get("/id/{id}", response_model=Optional[ToolModel])
  104. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  105. tools = Tools.get_tool_by_id(id)
  106. if tools:
  107. if (
  108. user.role == "admin"
  109. or tools.user_id == user.id
  110. or has_access(user.id, "read", tools.access_control)
  111. ):
  112. return tools
  113. else:
  114. raise HTTPException(
  115. status_code=status.HTTP_401_UNAUTHORIZED,
  116. detail=ERROR_MESSAGES.NOT_FOUND,
  117. )
  118. ############################
  119. # UpdateToolsById
  120. ############################
  121. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  122. async def update_tools_by_id(
  123. request: Request,
  124. id: str,
  125. form_data: ToolForm,
  126. user=Depends(get_verified_user),
  127. ):
  128. tools = Tools.get_tool_by_id(id)
  129. if not tools:
  130. raise HTTPException(
  131. status_code=status.HTTP_401_UNAUTHORIZED,
  132. detail=ERROR_MESSAGES.NOT_FOUND,
  133. )
  134. if tools.user_id != user.id and user.role != "admin":
  135. raise HTTPException(
  136. status_code=status.HTTP_401_UNAUTHORIZED,
  137. detail=ERROR_MESSAGES.UNAUTHORIZED,
  138. )
  139. try:
  140. form_data.content = replace_imports(form_data.content)
  141. tools_module, frontmatter = load_tools_module_by_id(
  142. id, content=form_data.content
  143. )
  144. form_data.meta.manifest = frontmatter
  145. TOOLS = request.app.state.TOOLS
  146. TOOLS[id] = tools_module
  147. specs = get_tools_specs(TOOLS[id])
  148. updated = {
  149. **form_data.model_dump(exclude={"id"}),
  150. "specs": specs,
  151. }
  152. print(updated)
  153. tools = Tools.update_tool_by_id(id, updated)
  154. if tools:
  155. return tools
  156. else:
  157. raise HTTPException(
  158. status_code=status.HTTP_400_BAD_REQUEST,
  159. detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
  160. )
  161. except Exception as e:
  162. raise HTTPException(
  163. status_code=status.HTTP_400_BAD_REQUEST,
  164. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  165. )
  166. ############################
  167. # DeleteToolsById
  168. ############################
  169. @router.delete("/id/{id}/delete", response_model=bool)
  170. async def delete_tools_by_id(
  171. request: Request, id: str, user=Depends(get_verified_user)
  172. ):
  173. tools = Tools.get_tool_by_id(id)
  174. if not tools:
  175. raise HTTPException(
  176. status_code=status.HTTP_401_UNAUTHORIZED,
  177. detail=ERROR_MESSAGES.NOT_FOUND,
  178. )
  179. if tools.user_id != user.id and user.role != "admin":
  180. raise HTTPException(
  181. status_code=status.HTTP_401_UNAUTHORIZED,
  182. detail=ERROR_MESSAGES.UNAUTHORIZED,
  183. )
  184. result = Tools.delete_tool_by_id(id)
  185. if result:
  186. TOOLS = request.app.state.TOOLS
  187. if id in TOOLS:
  188. del TOOLS[id]
  189. return result
  190. ############################
  191. # GetToolValves
  192. ############################
  193. @router.get("/id/{id}/valves", response_model=Optional[dict])
  194. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  195. tools = Tools.get_tool_by_id(id)
  196. if tools:
  197. try:
  198. valves = Tools.get_tool_valves_by_id(id)
  199. return valves
  200. except Exception as e:
  201. raise HTTPException(
  202. status_code=status.HTTP_400_BAD_REQUEST,
  203. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  204. )
  205. else:
  206. raise HTTPException(
  207. status_code=status.HTTP_401_UNAUTHORIZED,
  208. detail=ERROR_MESSAGES.NOT_FOUND,
  209. )
  210. ############################
  211. # GetToolValvesSpec
  212. ############################
  213. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  214. async def get_tools_valves_spec_by_id(
  215. request: Request, id: str, user=Depends(get_verified_user)
  216. ):
  217. tools = Tools.get_tool_by_id(id)
  218. if tools:
  219. if id in request.app.state.TOOLS:
  220. tools_module = request.app.state.TOOLS[id]
  221. else:
  222. tools_module, _ = load_tools_module_by_id(id)
  223. request.app.state.TOOLS[id] = tools_module
  224. if hasattr(tools_module, "Valves"):
  225. Valves = tools_module.Valves
  226. return Valves.schema()
  227. return None
  228. else:
  229. raise HTTPException(
  230. status_code=status.HTTP_401_UNAUTHORIZED,
  231. detail=ERROR_MESSAGES.NOT_FOUND,
  232. )
  233. ############################
  234. # UpdateToolValves
  235. ############################
  236. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  237. async def update_tools_valves_by_id(
  238. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  239. ):
  240. tools = Tools.get_tool_by_id(id)
  241. if tools:
  242. if id in request.app.state.TOOLS:
  243. tools_module = request.app.state.TOOLS[id]
  244. else:
  245. tools_module, _ = load_tools_module_by_id(id)
  246. request.app.state.TOOLS[id] = tools_module
  247. if hasattr(tools_module, "Valves"):
  248. Valves = tools_module.Valves
  249. try:
  250. form_data = {k: v for k, v in form_data.items() if v is not None}
  251. valves = Valves(**form_data)
  252. Tools.update_tool_valves_by_id(id, valves.model_dump())
  253. return valves.model_dump()
  254. except Exception as e:
  255. print(e)
  256. raise HTTPException(
  257. status_code=status.HTTP_400_BAD_REQUEST,
  258. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  259. )
  260. else:
  261. raise HTTPException(
  262. status_code=status.HTTP_401_UNAUTHORIZED,
  263. detail=ERROR_MESSAGES.NOT_FOUND,
  264. )
  265. else:
  266. raise HTTPException(
  267. status_code=status.HTTP_401_UNAUTHORIZED,
  268. detail=ERROR_MESSAGES.NOT_FOUND,
  269. )
  270. ############################
  271. # ToolUserValves
  272. ############################
  273. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  274. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  275. tools = Tools.get_tool_by_id(id)
  276. if tools:
  277. try:
  278. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  279. return user_valves
  280. except Exception as e:
  281. raise HTTPException(
  282. status_code=status.HTTP_400_BAD_REQUEST,
  283. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  284. )
  285. else:
  286. raise HTTPException(
  287. status_code=status.HTTP_401_UNAUTHORIZED,
  288. detail=ERROR_MESSAGES.NOT_FOUND,
  289. )
  290. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  291. async def get_tools_user_valves_spec_by_id(
  292. request: Request, id: str, user=Depends(get_verified_user)
  293. ):
  294. tools = Tools.get_tool_by_id(id)
  295. if tools:
  296. if id in request.app.state.TOOLS:
  297. tools_module = request.app.state.TOOLS[id]
  298. else:
  299. tools_module, _ = load_tools_module_by_id(id)
  300. request.app.state.TOOLS[id] = tools_module
  301. if hasattr(tools_module, "UserValves"):
  302. UserValves = tools_module.UserValves
  303. return UserValves.schema()
  304. return None
  305. else:
  306. raise HTTPException(
  307. status_code=status.HTTP_401_UNAUTHORIZED,
  308. detail=ERROR_MESSAGES.NOT_FOUND,
  309. )
  310. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  311. async def update_tools_user_valves_by_id(
  312. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  313. ):
  314. tools = Tools.get_tool_by_id(id)
  315. if tools:
  316. if id in request.app.state.TOOLS:
  317. tools_module = request.app.state.TOOLS[id]
  318. else:
  319. tools_module, _ = load_tools_module_by_id(id)
  320. request.app.state.TOOLS[id] = tools_module
  321. if hasattr(tools_module, "UserValves"):
  322. UserValves = tools_module.UserValves
  323. try:
  324. form_data = {k: v for k, v in form_data.items() if v is not None}
  325. user_valves = UserValves(**form_data)
  326. Tools.update_user_valves_by_id_and_user_id(
  327. id, user.id, user_valves.model_dump()
  328. )
  329. return user_valves.model_dump()
  330. except Exception as e:
  331. print(e)
  332. raise HTTPException(
  333. status_code=status.HTTP_400_BAD_REQUEST,
  334. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  335. )
  336. else:
  337. raise HTTPException(
  338. status_code=status.HTTP_401_UNAUTHORIZED,
  339. detail=ERROR_MESSAGES.NOT_FOUND,
  340. )
  341. else:
  342. raise HTTPException(
  343. status_code=status.HTTP_401_UNAUTHORIZED,
  344. detail=ERROR_MESSAGES.NOT_FOUND,
  345. )