# tools.py (11 KB)
  1. from fastapi import Depends, FastAPI, HTTPException, status, Request
  2. from datetime import datetime, timedelta
  3. from typing import List, Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.webui.models.users import Users
  8. from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
  9. from apps.webui.utils import load_toolkit_module_by_id
  10. from utils.utils import get_admin_user, get_verified_user
  11. from utils.tools import get_tools_specs
  12. from constants import ERROR_MESSAGES
  13. from importlib import util
  14. import os
  15. from pathlib import Path
  16. from config import DATA_DIR, CACHE_DIR
  17. TOOLS_DIR = f"{DATA_DIR}/tools"
  18. os.makedirs(TOOLS_DIR, exist_ok=True)
  19. router = APIRouter()
  20. ############################
  21. # GetToolkits
  22. ############################
  23. @router.get("/", response_model=List[ToolResponse])
  24. async def get_toolkits(user=Depends(get_verified_user)):
  25. toolkits = [toolkit for toolkit in Tools.get_tools()]
  26. return toolkits
  27. ############################
  28. # ExportToolKits
  29. ############################
  30. @router.get("/export", response_model=List[ToolModel])
  31. async def get_toolkits(user=Depends(get_admin_user)):
  32. toolkits = [toolkit for toolkit in Tools.get_tools()]
  33. return toolkits
  34. ############################
  35. # CreateNewToolKit
  36. ############################
  37. @router.post("/create", response_model=Optional[ToolResponse])
  38. async def create_new_toolkit(
  39. request: Request,
  40. form_data: ToolForm,
  41. user=Depends(get_admin_user),
  42. ):
  43. if not form_data.id.isidentifier():
  44. raise HTTPException(
  45. status_code=status.HTTP_400_BAD_REQUEST,
  46. detail="Only alphanumeric characters and underscores are allowed in the id",
  47. )
  48. form_data.id = form_data.id.lower()
  49. toolkit = Tools.get_tool_by_id(form_data.id)
  50. if toolkit == None:
  51. toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
  52. try:
  53. with open(toolkit_path, "w") as tool_file:
  54. tool_file.write(form_data.content)
  55. toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
  56. form_data.meta.manifest = frontmatter
  57. TOOLS = request.app.state.TOOLS
  58. TOOLS[form_data.id] = toolkit_module
  59. specs = get_tools_specs(TOOLS[form_data.id])
  60. toolkit = Tools.insert_new_tool(user.id, form_data, specs)
  61. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  62. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  63. if toolkit:
  64. return toolkit
  65. else:
  66. raise HTTPException(
  67. status_code=status.HTTP_400_BAD_REQUEST,
  68. detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
  69. )
  70. except Exception as e:
  71. print(e)
  72. raise HTTPException(
  73. status_code=status.HTTP_400_BAD_REQUEST,
  74. detail=ERROR_MESSAGES.DEFAULT(e),
  75. )
  76. else:
  77. raise HTTPException(
  78. status_code=status.HTTP_400_BAD_REQUEST,
  79. detail=ERROR_MESSAGES.ID_TAKEN,
  80. )
  81. ############################
  82. # GetToolkitById
  83. ############################
  84. @router.get("/id/{id}", response_model=Optional[ToolModel])
  85. async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
  86. toolkit = Tools.get_tool_by_id(id)
  87. if toolkit:
  88. return toolkit
  89. else:
  90. raise HTTPException(
  91. status_code=status.HTTP_401_UNAUTHORIZED,
  92. detail=ERROR_MESSAGES.NOT_FOUND,
  93. )
  94. ############################
  95. # UpdateToolkitById
  96. ############################
  97. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  98. async def update_toolkit_by_id(
  99. request: Request,
  100. id: str,
  101. form_data: ToolForm,
  102. user=Depends(get_admin_user),
  103. ):
  104. toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
  105. try:
  106. with open(toolkit_path, "w") as tool_file:
  107. tool_file.write(form_data.content)
  108. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  109. form_data.meta.manifest = frontmatter
  110. TOOLS = request.app.state.TOOLS
  111. TOOLS[id] = toolkit_module
  112. specs = get_tools_specs(TOOLS[id])
  113. updated = {
  114. **form_data.model_dump(exclude={"id"}),
  115. "specs": specs,
  116. }
  117. print(updated)
  118. toolkit = Tools.update_tool_by_id(id, updated)
  119. if toolkit:
  120. return toolkit
  121. else:
  122. raise HTTPException(
  123. status_code=status.HTTP_400_BAD_REQUEST,
  124. detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
  125. )
  126. except Exception as e:
  127. raise HTTPException(
  128. status_code=status.HTTP_400_BAD_REQUEST,
  129. detail=ERROR_MESSAGES.DEFAULT(e),
  130. )
  131. ############################
  132. # DeleteToolkitById
  133. ############################
  134. @router.delete("/id/{id}/delete", response_model=bool)
  135. async def delete_toolkit_by_id(
  136. request: Request, id: str, user=Depends(get_admin_user)
  137. ):
  138. result = Tools.delete_tool_by_id(id)
  139. if result:
  140. TOOLS = request.app.state.TOOLS
  141. if id in TOOLS:
  142. del TOOLS[id]
  143. # delete the toolkit file
  144. toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
  145. os.remove(toolkit_path)
  146. return result
  147. ############################
  148. # GetToolValves
  149. ############################
  150. @router.get("/id/{id}/valves", response_model=Optional[dict])
  151. async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
  152. toolkit = Tools.get_tool_by_id(id)
  153. if toolkit:
  154. try:
  155. valves = Tools.get_tool_valves_by_id(id)
  156. return valves
  157. except Exception as e:
  158. raise HTTPException(
  159. status_code=status.HTTP_400_BAD_REQUEST,
  160. detail=ERROR_MESSAGES.DEFAULT(e),
  161. )
  162. else:
  163. raise HTTPException(
  164. status_code=status.HTTP_401_UNAUTHORIZED,
  165. detail=ERROR_MESSAGES.NOT_FOUND,
  166. )
  167. ############################
  168. # GetToolValvesSpec
  169. ############################
  170. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  171. async def get_toolkit_valves_spec_by_id(
  172. request: Request, id: str, user=Depends(get_admin_user)
  173. ):
  174. toolkit = Tools.get_tool_by_id(id)
  175. if toolkit:
  176. if id in request.app.state.TOOLS:
  177. toolkit_module = request.app.state.TOOLS[id]
  178. else:
  179. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  180. request.app.state.TOOLS[id] = toolkit_module
  181. if hasattr(toolkit_module, "Valves"):
  182. Valves = toolkit_module.Valves
  183. return Valves.schema()
  184. return None
  185. else:
  186. raise HTTPException(
  187. status_code=status.HTTP_401_UNAUTHORIZED,
  188. detail=ERROR_MESSAGES.NOT_FOUND,
  189. )
  190. ############################
  191. # UpdateToolValves
  192. ############################
  193. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  194. async def update_toolkit_valves_by_id(
  195. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  196. ):
  197. toolkit = Tools.get_tool_by_id(id)
  198. if toolkit:
  199. if id in request.app.state.TOOLS:
  200. toolkit_module = request.app.state.TOOLS[id]
  201. else:
  202. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  203. request.app.state.TOOLS[id] = toolkit_module
  204. if hasattr(toolkit_module, "Valves"):
  205. Valves = toolkit_module.Valves
  206. try:
  207. form_data = {k: v for k, v in form_data.items() if v is not None}
  208. valves = Valves(**form_data)
  209. Tools.update_tool_valves_by_id(id, valves.model_dump())
  210. return valves.model_dump()
  211. except Exception as e:
  212. print(e)
  213. raise HTTPException(
  214. status_code=status.HTTP_400_BAD_REQUEST,
  215. detail=ERROR_MESSAGES.DEFAULT(e),
  216. )
  217. else:
  218. raise HTTPException(
  219. status_code=status.HTTP_401_UNAUTHORIZED,
  220. detail=ERROR_MESSAGES.NOT_FOUND,
  221. )
  222. else:
  223. raise HTTPException(
  224. status_code=status.HTTP_401_UNAUTHORIZED,
  225. detail=ERROR_MESSAGES.NOT_FOUND,
  226. )
  227. ############################
  228. # ToolUserValves
  229. ############################
  230. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  231. async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  232. toolkit = Tools.get_tool_by_id(id)
  233. if toolkit:
  234. try:
  235. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  236. return user_valves
  237. except Exception as e:
  238. raise HTTPException(
  239. status_code=status.HTTP_400_BAD_REQUEST,
  240. detail=ERROR_MESSAGES.DEFAULT(e),
  241. )
  242. else:
  243. raise HTTPException(
  244. status_code=status.HTTP_401_UNAUTHORIZED,
  245. detail=ERROR_MESSAGES.NOT_FOUND,
  246. )
  247. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  248. async def get_toolkit_user_valves_spec_by_id(
  249. request: Request, id: str, user=Depends(get_verified_user)
  250. ):
  251. toolkit = Tools.get_tool_by_id(id)
  252. if toolkit:
  253. if id in request.app.state.TOOLS:
  254. toolkit_module = request.app.state.TOOLS[id]
  255. else:
  256. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  257. request.app.state.TOOLS[id] = toolkit_module
  258. if hasattr(toolkit_module, "UserValves"):
  259. UserValves = toolkit_module.UserValves
  260. return UserValves.schema()
  261. return None
  262. else:
  263. raise HTTPException(
  264. status_code=status.HTTP_401_UNAUTHORIZED,
  265. detail=ERROR_MESSAGES.NOT_FOUND,
  266. )
  267. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  268. async def update_toolkit_user_valves_by_id(
  269. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  270. ):
  271. toolkit = Tools.get_tool_by_id(id)
  272. if toolkit:
  273. if id in request.app.state.TOOLS:
  274. toolkit_module = request.app.state.TOOLS[id]
  275. else:
  276. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  277. request.app.state.TOOLS[id] = toolkit_module
  278. if hasattr(toolkit_module, "UserValves"):
  279. UserValves = toolkit_module.UserValves
  280. try:
  281. form_data = {k: v for k, v in form_data.items() if v is not None}
  282. user_valves = UserValves(**form_data)
  283. Tools.update_user_valves_by_id_and_user_id(
  284. id, user.id, user_valves.model_dump()
  285. )
  286. return user_valves.model_dump()
  287. except Exception as e:
  288. print(e)
  289. raise HTTPException(
  290. status_code=status.HTTP_400_BAD_REQUEST,
  291. detail=ERROR_MESSAGES.DEFAULT(e),
  292. )
  293. else:
  294. raise HTTPException(
  295. status_code=status.HTTP_401_UNAUTHORIZED,
  296. detail=ERROR_MESSAGES.NOT_FOUND,
  297. )
  298. else:
  299. raise HTTPException(
  300. status_code=status.HTTP_401_UNAUTHORIZED,
  301. detail=ERROR_MESSAGES.NOT_FOUND,
  302. )