tools.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
import logging
import os
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status

from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, DATA_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.tools import get_tools_specs
from open_webui.utils.utils import get_admin_user, get_verified_user
# FastAPI router that every tool endpoint below is registered on.
router = APIRouter()
  12. ############################
  13. # GetTools
  14. ############################
  15. @router.get("/", response_model=list[ToolResponse])
  16. async def get_tools(user=Depends(get_verified_user)):
  17. if user.role == "admin":
  18. tools = Tools.get_tools()
  19. else:
  20. tools = Tools.get_tools_by_user_id(user.id, "read")
  21. return tools
  22. ############################
  23. # GetToolList
  24. ############################
  25. @router.get("/list", response_model=list[ToolResponse])
  26. async def get_tool_list(user=Depends(get_verified_user)):
  27. if user.role == "admin":
  28. tools = Tools.get_tools()
  29. else:
  30. tools = Tools.get_tools_by_user_id(user.id, "write")
  31. return tools
  32. ############################
  33. # ExportTools
  34. ############################
  35. @router.get("/export", response_model=list[ToolModel])
  36. async def export_tools(user=Depends(get_admin_user)):
  37. tools = Tools.get_tools()
  38. return tools
  39. ############################
  40. # CreateNewTools
  41. ############################
  42. @router.post("/create", response_model=Optional[ToolResponse])
  43. async def create_new_tools(
  44. request: Request,
  45. form_data: ToolForm,
  46. user=Depends(get_verified_user),
  47. ):
  48. if not form_data.id.isidentifier():
  49. raise HTTPException(
  50. status_code=status.HTTP_400_BAD_REQUEST,
  51. detail="Only alphanumeric characters and underscores are allowed in the id",
  52. )
  53. form_data.id = form_data.id.lower()
  54. toolkit = Tools.get_tool_by_id(form_data.id)
  55. if toolkit is None:
  56. try:
  57. form_data.content = replace_imports(form_data.content)
  58. toolkit_module, frontmatter = load_toolkit_module_by_id(
  59. form_data.id, content=form_data.content
  60. )
  61. form_data.meta.manifest = frontmatter
  62. TOOLS = request.app.state.TOOLS
  63. TOOLS[form_data.id] = toolkit_module
  64. specs = get_tools_specs(TOOLS[form_data.id])
  65. toolkit = Tools.insert_new_tool(user.id, form_data, specs)
  66. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  67. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  68. if toolkit:
  69. return toolkit
  70. else:
  71. raise HTTPException(
  72. status_code=status.HTTP_400_BAD_REQUEST,
  73. detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
  74. )
  75. except Exception as e:
  76. print(e)
  77. raise HTTPException(
  78. status_code=status.HTTP_400_BAD_REQUEST,
  79. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  80. )
  81. else:
  82. raise HTTPException(
  83. status_code=status.HTTP_400_BAD_REQUEST,
  84. detail=ERROR_MESSAGES.ID_TAKEN,
  85. )
  86. ############################
  87. # GetToolsById
  88. ############################
  89. @router.get("/id/{id}", response_model=Optional[ToolModel])
  90. async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
  91. toolkit = Tools.get_tool_by_id(id)
  92. if toolkit:
  93. return toolkit
  94. else:
  95. raise HTTPException(
  96. status_code=status.HTTP_401_UNAUTHORIZED,
  97. detail=ERROR_MESSAGES.NOT_FOUND,
  98. )
  99. ############################
  100. # UpdateToolsById
  101. ############################
  102. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  103. async def update_tools_by_id(
  104. request: Request,
  105. id: str,
  106. form_data: ToolForm,
  107. user=Depends(get_verified_user),
  108. ):
  109. try:
  110. form_data.content = replace_imports(form_data.content)
  111. toolkit_module, frontmatter = load_toolkit_module_by_id(
  112. id, content=form_data.content
  113. )
  114. form_data.meta.manifest = frontmatter
  115. TOOLS = request.app.state.TOOLS
  116. TOOLS[id] = toolkit_module
  117. specs = get_tools_specs(TOOLS[id])
  118. updated = {
  119. **form_data.model_dump(exclude={"id"}),
  120. "specs": specs,
  121. }
  122. print(updated)
  123. toolkit = Tools.update_tool_by_id(id, updated)
  124. if toolkit:
  125. return toolkit
  126. else:
  127. raise HTTPException(
  128. status_code=status.HTTP_400_BAD_REQUEST,
  129. detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
  130. )
  131. except Exception as e:
  132. raise HTTPException(
  133. status_code=status.HTTP_400_BAD_REQUEST,
  134. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  135. )
  136. ############################
  137. # DeleteToolsById
  138. ############################
  139. @router.delete("/id/{id}/delete", response_model=bool)
  140. async def delete_tools_by_id(
  141. request: Request, id: str, user=Depends(get_verified_user)
  142. ):
  143. result = Tools.delete_tool_by_id(id)
  144. if result:
  145. TOOLS = request.app.state.TOOLS
  146. if id in TOOLS:
  147. del TOOLS[id]
  148. return result
  149. ############################
  150. # GetToolValves
  151. ############################
  152. @router.get("/id/{id}/valves", response_model=Optional[dict])
  153. async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
  154. toolkit = Tools.get_tool_by_id(id)
  155. if toolkit:
  156. try:
  157. valves = Tools.get_tool_valves_by_id(id)
  158. return valves
  159. except Exception as e:
  160. raise HTTPException(
  161. status_code=status.HTTP_400_BAD_REQUEST,
  162. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  163. )
  164. else:
  165. raise HTTPException(
  166. status_code=status.HTTP_401_UNAUTHORIZED,
  167. detail=ERROR_MESSAGES.NOT_FOUND,
  168. )
  169. ############################
  170. # GetToolValvesSpec
  171. ############################
  172. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  173. async def get_tools_valves_spec_by_id(
  174. request: Request, id: str, user=Depends(get_verified_user)
  175. ):
  176. toolkit = Tools.get_tool_by_id(id)
  177. if toolkit:
  178. if id in request.app.state.TOOLS:
  179. toolkit_module = request.app.state.TOOLS[id]
  180. else:
  181. toolkit_module, _ = load_toolkit_module_by_id(id)
  182. request.app.state.TOOLS[id] = toolkit_module
  183. if hasattr(toolkit_module, "Valves"):
  184. Valves = toolkit_module.Valves
  185. return Valves.schema()
  186. return None
  187. else:
  188. raise HTTPException(
  189. status_code=status.HTTP_401_UNAUTHORIZED,
  190. detail=ERROR_MESSAGES.NOT_FOUND,
  191. )
  192. ############################
  193. # UpdateToolValves
  194. ############################
  195. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  196. async def update_tools_valves_by_id(
  197. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  198. ):
  199. toolkit = Tools.get_tool_by_id(id)
  200. if toolkit:
  201. if id in request.app.state.TOOLS:
  202. toolkit_module = request.app.state.TOOLS[id]
  203. else:
  204. toolkit_module, _ = load_toolkit_module_by_id(id)
  205. request.app.state.TOOLS[id] = toolkit_module
  206. if hasattr(toolkit_module, "Valves"):
  207. Valves = toolkit_module.Valves
  208. try:
  209. form_data = {k: v for k, v in form_data.items() if v is not None}
  210. valves = Valves(**form_data)
  211. Tools.update_tool_valves_by_id(id, valves.model_dump())
  212. return valves.model_dump()
  213. except Exception as e:
  214. print(e)
  215. raise HTTPException(
  216. status_code=status.HTTP_400_BAD_REQUEST,
  217. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  218. )
  219. else:
  220. raise HTTPException(
  221. status_code=status.HTTP_401_UNAUTHORIZED,
  222. detail=ERROR_MESSAGES.NOT_FOUND,
  223. )
  224. else:
  225. raise HTTPException(
  226. status_code=status.HTTP_401_UNAUTHORIZED,
  227. detail=ERROR_MESSAGES.NOT_FOUND,
  228. )
  229. ############################
  230. # ToolUserValves
  231. ############################
  232. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  233. async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  234. toolkit = Tools.get_tool_by_id(id)
  235. if toolkit:
  236. try:
  237. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  238. return user_valves
  239. except Exception as e:
  240. raise HTTPException(
  241. status_code=status.HTTP_400_BAD_REQUEST,
  242. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  243. )
  244. else:
  245. raise HTTPException(
  246. status_code=status.HTTP_401_UNAUTHORIZED,
  247. detail=ERROR_MESSAGES.NOT_FOUND,
  248. )
  249. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  250. async def get_tools_user_valves_spec_by_id(
  251. request: Request, id: str, user=Depends(get_verified_user)
  252. ):
  253. toolkit = Tools.get_tool_by_id(id)
  254. if toolkit:
  255. if id in request.app.state.TOOLS:
  256. toolkit_module = request.app.state.TOOLS[id]
  257. else:
  258. toolkit_module, _ = load_toolkit_module_by_id(id)
  259. request.app.state.TOOLS[id] = toolkit_module
  260. if hasattr(toolkit_module, "UserValves"):
  261. UserValves = toolkit_module.UserValves
  262. return UserValves.schema()
  263. return None
  264. else:
  265. raise HTTPException(
  266. status_code=status.HTTP_401_UNAUTHORIZED,
  267. detail=ERROR_MESSAGES.NOT_FOUND,
  268. )
  269. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  270. async def update_tools_user_valves_by_id(
  271. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  272. ):
  273. toolkit = Tools.get_tool_by_id(id)
  274. if toolkit:
  275. if id in request.app.state.TOOLS:
  276. toolkit_module = request.app.state.TOOLS[id]
  277. else:
  278. toolkit_module, _ = load_toolkit_module_by_id(id)
  279. request.app.state.TOOLS[id] = toolkit_module
  280. if hasattr(toolkit_module, "UserValves"):
  281. UserValves = toolkit_module.UserValves
  282. try:
  283. form_data = {k: v for k, v in form_data.items() if v is not None}
  284. user_valves = UserValves(**form_data)
  285. Tools.update_user_valves_by_id_and_user_id(
  286. id, user.id, user_valves.model_dump()
  287. )
  288. return user_valves.model_dump()
  289. except Exception as e:
  290. print(e)
  291. raise HTTPException(
  292. status_code=status.HTTP_400_BAD_REQUEST,
  293. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  294. )
  295. else:
  296. raise HTTPException(
  297. status_code=status.HTTP_401_UNAUTHORIZED,
  298. detail=ERROR_MESSAGES.NOT_FOUND,
  299. )
  300. else:
  301. raise HTTPException(
  302. status_code=status.HTTP_401_UNAUTHORIZED,
  303. detail=ERROR_MESSAGES.NOT_FOUND,
  304. )