tools.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. from fastapi import Depends, HTTPException, status, Request
  2. from typing import Optional
  3. from fastapi import APIRouter
  4. from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
  5. from apps.webui.utils import load_toolkit_module_by_id
  6. from utils.utils import get_admin_user, get_verified_user
  7. from utils.tools import get_tools_specs
  8. from constants import ERROR_MESSAGES
  9. import os
  10. from pathlib import Path
  11. from config import DATA_DIR, CACHE_DIR
# Directory holding toolkit source files; handlers below write one "<id>.py" per toolkit here.
TOOLS_DIR = f"{DATA_DIR}/tools"
# Create the directory at import time so later file writes do not hit a missing path.
os.makedirs(TOOLS_DIR, exist_ok=True)
# Router on which every toolkit endpoint in this module is registered.
router = APIRouter()
  15. ############################
  16. # GetToolkits
  17. ############################
  18. @router.get("/", response_model=list[ToolResponse])
  19. async def get_toolkits(user=Depends(get_verified_user)):
  20. toolkits = [toolkit for toolkit in Tools.get_tools()]
  21. return toolkits
  22. ############################
  23. # ExportToolKits
  24. ############################
  25. @router.get("/export", response_model=list[ToolModel])
  26. async def get_toolkits(user=Depends(get_admin_user)):
  27. toolkits = [toolkit for toolkit in Tools.get_tools()]
  28. return toolkits
  29. ############################
  30. # CreateNewToolKit
  31. ############################
  32. @router.post("/create", response_model=Optional[ToolResponse])
  33. async def create_new_toolkit(
  34. request: Request,
  35. form_data: ToolForm,
  36. user=Depends(get_admin_user),
  37. ):
  38. if not form_data.id.isidentifier():
  39. raise HTTPException(
  40. status_code=status.HTTP_400_BAD_REQUEST,
  41. detail="Only alphanumeric characters and underscores are allowed in the id",
  42. )
  43. form_data.id = form_data.id.lower()
  44. toolkit = Tools.get_tool_by_id(form_data.id)
  45. if toolkit is None:
  46. toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
  47. try:
  48. with open(toolkit_path, "w") as tool_file:
  49. tool_file.write(form_data.content)
  50. toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
  51. form_data.meta.manifest = frontmatter
  52. TOOLS = request.app.state.TOOLS
  53. TOOLS[form_data.id] = toolkit_module
  54. specs = get_tools_specs(TOOLS[form_data.id])
  55. toolkit = Tools.insert_new_tool(user.id, form_data, specs)
  56. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  57. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  58. if toolkit:
  59. return toolkit
  60. else:
  61. raise HTTPException(
  62. status_code=status.HTTP_400_BAD_REQUEST,
  63. detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
  64. )
  65. except Exception as e:
  66. print(e)
  67. raise HTTPException(
  68. status_code=status.HTTP_400_BAD_REQUEST,
  69. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  70. )
  71. else:
  72. raise HTTPException(
  73. status_code=status.HTTP_400_BAD_REQUEST,
  74. detail=ERROR_MESSAGES.ID_TAKEN,
  75. )
  76. ############################
  77. # GetToolkitById
  78. ############################
  79. @router.get("/id/{id}", response_model=Optional[ToolModel])
  80. async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
  81. toolkit = Tools.get_tool_by_id(id)
  82. if toolkit:
  83. return toolkit
  84. else:
  85. raise HTTPException(
  86. status_code=status.HTTP_401_UNAUTHORIZED,
  87. detail=ERROR_MESSAGES.NOT_FOUND,
  88. )
  89. ############################
  90. # UpdateToolkitById
  91. ############################
  92. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  93. async def update_toolkit_by_id(
  94. request: Request,
  95. id: str,
  96. form_data: ToolForm,
  97. user=Depends(get_admin_user),
  98. ):
  99. toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
  100. try:
  101. with open(toolkit_path, "w") as tool_file:
  102. tool_file.write(form_data.content)
  103. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  104. form_data.meta.manifest = frontmatter
  105. TOOLS = request.app.state.TOOLS
  106. TOOLS[id] = toolkit_module
  107. specs = get_tools_specs(TOOLS[id])
  108. updated = {
  109. **form_data.model_dump(exclude={"id"}),
  110. "specs": specs,
  111. }
  112. print(updated)
  113. toolkit = Tools.update_tool_by_id(id, updated)
  114. if toolkit:
  115. return toolkit
  116. else:
  117. raise HTTPException(
  118. status_code=status.HTTP_400_BAD_REQUEST,
  119. detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
  120. )
  121. except Exception as e:
  122. raise HTTPException(
  123. status_code=status.HTTP_400_BAD_REQUEST,
  124. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  125. )
  126. ############################
  127. # DeleteToolkitById
  128. ############################
  129. @router.delete("/id/{id}/delete", response_model=bool)
  130. async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
  131. result = Tools.delete_tool_by_id(id)
  132. if result:
  133. TOOLS = request.app.state.TOOLS
  134. if id in TOOLS:
  135. del TOOLS[id]
  136. # delete the toolkit file
  137. toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
  138. os.remove(toolkit_path)
  139. return result
  140. ############################
  141. # GetToolValves
  142. ############################
  143. @router.get("/id/{id}/valves", response_model=Optional[dict])
  144. async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
  145. toolkit = Tools.get_tool_by_id(id)
  146. if toolkit:
  147. try:
  148. valves = Tools.get_tool_valves_by_id(id)
  149. return valves
  150. except Exception as e:
  151. raise HTTPException(
  152. status_code=status.HTTP_400_BAD_REQUEST,
  153. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  154. )
  155. else:
  156. raise HTTPException(
  157. status_code=status.HTTP_401_UNAUTHORIZED,
  158. detail=ERROR_MESSAGES.NOT_FOUND,
  159. )
  160. ############################
  161. # GetToolValvesSpec
  162. ############################
  163. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  164. async def get_toolkit_valves_spec_by_id(
  165. request: Request, id: str, user=Depends(get_admin_user)
  166. ):
  167. toolkit = Tools.get_tool_by_id(id)
  168. if toolkit:
  169. if id in request.app.state.TOOLS:
  170. toolkit_module = request.app.state.TOOLS[id]
  171. else:
  172. toolkit_module, _ = load_toolkit_module_by_id(id)
  173. request.app.state.TOOLS[id] = toolkit_module
  174. if hasattr(toolkit_module, "Valves"):
  175. Valves = toolkit_module.Valves
  176. return Valves.schema()
  177. return None
  178. else:
  179. raise HTTPException(
  180. status_code=status.HTTP_401_UNAUTHORIZED,
  181. detail=ERROR_MESSAGES.NOT_FOUND,
  182. )
  183. ############################
  184. # UpdateToolValves
  185. ############################
  186. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  187. async def update_toolkit_valves_by_id(
  188. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  189. ):
  190. toolkit = Tools.get_tool_by_id(id)
  191. if toolkit:
  192. if id in request.app.state.TOOLS:
  193. toolkit_module = request.app.state.TOOLS[id]
  194. else:
  195. toolkit_module, _ = load_toolkit_module_by_id(id)
  196. request.app.state.TOOLS[id] = toolkit_module
  197. if hasattr(toolkit_module, "Valves"):
  198. Valves = toolkit_module.Valves
  199. try:
  200. form_data = {k: v for k, v in form_data.items() if v is not None}
  201. valves = Valves(**form_data)
  202. Tools.update_tool_valves_by_id(id, valves.model_dump())
  203. return valves.model_dump()
  204. except Exception as e:
  205. print(e)
  206. raise HTTPException(
  207. status_code=status.HTTP_400_BAD_REQUEST,
  208. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  209. )
  210. else:
  211. raise HTTPException(
  212. status_code=status.HTTP_401_UNAUTHORIZED,
  213. detail=ERROR_MESSAGES.NOT_FOUND,
  214. )
  215. else:
  216. raise HTTPException(
  217. status_code=status.HTTP_401_UNAUTHORIZED,
  218. detail=ERROR_MESSAGES.NOT_FOUND,
  219. )
  220. ############################
  221. # ToolUserValves
  222. ############################
  223. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  224. async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  225. toolkit = Tools.get_tool_by_id(id)
  226. if toolkit:
  227. try:
  228. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  229. return user_valves
  230. except Exception as e:
  231. raise HTTPException(
  232. status_code=status.HTTP_400_BAD_REQUEST,
  233. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  234. )
  235. else:
  236. raise HTTPException(
  237. status_code=status.HTTP_401_UNAUTHORIZED,
  238. detail=ERROR_MESSAGES.NOT_FOUND,
  239. )
  240. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  241. async def get_toolkit_user_valves_spec_by_id(
  242. request: Request, id: str, user=Depends(get_verified_user)
  243. ):
  244. toolkit = Tools.get_tool_by_id(id)
  245. if toolkit:
  246. if id in request.app.state.TOOLS:
  247. toolkit_module = request.app.state.TOOLS[id]
  248. else:
  249. toolkit_module, _ = load_toolkit_module_by_id(id)
  250. request.app.state.TOOLS[id] = toolkit_module
  251. if hasattr(toolkit_module, "UserValves"):
  252. UserValves = toolkit_module.UserValves
  253. return UserValves.schema()
  254. return None
  255. else:
  256. raise HTTPException(
  257. status_code=status.HTTP_401_UNAUTHORIZED,
  258. detail=ERROR_MESSAGES.NOT_FOUND,
  259. )
  260. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  261. async def update_toolkit_user_valves_by_id(
  262. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  263. ):
  264. toolkit = Tools.get_tool_by_id(id)
  265. if toolkit:
  266. if id in request.app.state.TOOLS:
  267. toolkit_module = request.app.state.TOOLS[id]
  268. else:
  269. toolkit_module, _ = load_toolkit_module_by_id(id)
  270. request.app.state.TOOLS[id] = toolkit_module
  271. if hasattr(toolkit_module, "UserValves"):
  272. UserValves = toolkit_module.UserValves
  273. try:
  274. form_data = {k: v for k, v in form_data.items() if v is not None}
  275. user_valves = UserValves(**form_data)
  276. Tools.update_user_valves_by_id_and_user_id(
  277. id, user.id, user_valves.model_dump()
  278. )
  279. return user_valves.model_dump()
  280. except Exception as e:
  281. print(e)
  282. raise HTTPException(
  283. status_code=status.HTTP_400_BAD_REQUEST,
  284. detail=ERROR_MESSAGES.DEFAULT(str(e)),
  285. )
  286. else:
  287. raise HTTPException(
  288. status_code=status.HTTP_401_UNAUTHORIZED,
  289. detail=ERROR_MESSAGES.NOT_FOUND,
  290. )
  291. else:
  292. raise HTTPException(
  293. status_code=status.HTTP_401_UNAUTHORIZED,
  294. detail=ERROR_MESSAGES.NOT_FOUND,
  295. )