tools.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
import json
import logging
import os
from datetime import datetime, timedelta
from importlib import util
from pathlib import Path
from typing import List, Optional, Union

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from pydantic import BaseModel

from apps.webui.internal.db import get_db
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from config import DATA_DIR, CACHE_DIR
  18. TOOLS_DIR = f"{DATA_DIR}/tools"
  19. os.makedirs(TOOLS_DIR, exist_ok=True)
  20. router = APIRouter()
  21. ############################
  22. # GetToolkits
  23. ############################
  24. @router.get("/", response_model=List[ToolResponse])
  25. async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
  26. toolkits = [toolkit for toolkit in Tools.get_tools()]
  27. return toolkits
  28. ############################
  29. # ExportToolKits
  30. ############################
  31. @router.get("/export", response_model=List[ToolModel])
  32. async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)):
  33. toolkits = [toolkit for toolkit in Tools.get_tools(db)]
  34. return toolkits
  35. ############################
  36. # CreateNewToolKit
  37. ############################
  38. @router.post("/create", response_model=Optional[ToolResponse])
  39. async def create_new_toolkit(
  40. request: Request,
  41. form_data: ToolForm,
  42. user=Depends(get_admin_user),
  43. db=Depends(get_db),
  44. ):
  45. if not form_data.id.isidentifier():
  46. raise HTTPException(
  47. status_code=status.HTTP_400_BAD_REQUEST,
  48. detail="Only alphanumeric characters and underscores are allowed in the id",
  49. )
  50. form_data.id = form_data.id.lower()
  51. toolkit = Tools.get_tool_by_id(db, form_data.id)
  52. if toolkit == None:
  53. toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
  54. try:
  55. with open(toolkit_path, "w") as tool_file:
  56. tool_file.write(form_data.content)
  57. toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
  58. form_data.meta.manifest = frontmatter
  59. TOOLS = request.app.state.TOOLS
  60. TOOLS[form_data.id] = toolkit_module
  61. specs = get_tools_specs(TOOLS[form_data.id])
  62. toolkit = Tools.insert_new_tool(db, user.id, form_data, specs)
  63. tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
  64. tool_cache_dir.mkdir(parents=True, exist_ok=True)
  65. if toolkit:
  66. return toolkit
  67. else:
  68. raise HTTPException(
  69. status_code=status.HTTP_400_BAD_REQUEST,
  70. detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
  71. )
  72. except Exception as e:
  73. print(e)
  74. raise HTTPException(
  75. status_code=status.HTTP_400_BAD_REQUEST,
  76. detail=ERROR_MESSAGES.DEFAULT(e),
  77. )
  78. else:
  79. raise HTTPException(
  80. status_code=status.HTTP_400_BAD_REQUEST,
  81. detail=ERROR_MESSAGES.ID_TAKEN,
  82. )
  83. ############################
  84. # GetToolkitById
  85. ############################
  86. @router.get("/id/{id}", response_model=Optional[ToolModel])
  87. async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
  88. toolkit = Tools.get_tool_by_id(db, id)
  89. if toolkit:
  90. return toolkit
  91. else:
  92. raise HTTPException(
  93. status_code=status.HTTP_401_UNAUTHORIZED,
  94. detail=ERROR_MESSAGES.NOT_FOUND,
  95. )
  96. ############################
  97. # UpdateToolkitById
  98. ############################
  99. @router.post("/id/{id}/update", response_model=Optional[ToolModel])
  100. async def update_toolkit_by_id(
  101. request: Request,
  102. id: str,
  103. form_data: ToolForm,
  104. user=Depends(get_admin_user),
  105. db=Depends(get_db),
  106. ):
  107. toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
  108. try:
  109. with open(toolkit_path, "w") as tool_file:
  110. tool_file.write(form_data.content)
  111. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  112. form_data.meta.manifest = frontmatter
  113. TOOLS = request.app.state.TOOLS
  114. TOOLS[id] = toolkit_module
  115. specs = get_tools_specs(TOOLS[id])
  116. updated = {
  117. **form_data.model_dump(exclude={"id"}),
  118. "specs": specs,
  119. }
  120. print(updated)
  121. toolkit = Tools.update_tool_by_id(db, id, updated)
  122. if toolkit:
  123. return toolkit
  124. else:
  125. raise HTTPException(
  126. status_code=status.HTTP_400_BAD_REQUEST,
  127. detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
  128. )
  129. except Exception as e:
  130. raise HTTPException(
  131. status_code=status.HTTP_400_BAD_REQUEST,
  132. detail=ERROR_MESSAGES.DEFAULT(e),
  133. )
  134. ############################
  135. # DeleteToolkitById
  136. ############################
  137. @router.delete("/id/{id}/delete", response_model=bool)
  138. async def delete_toolkit_by_id(
  139. request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
  140. ):
  141. result = Tools.delete_tool_by_id(db, id)
  142. if result:
  143. TOOLS = request.app.state.TOOLS
  144. if id in TOOLS:
  145. del TOOLS[id]
  146. # delete the toolkit file
  147. toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
  148. os.remove(toolkit_path)
  149. return result
  150. ############################
  151. # GetToolValves
  152. ############################
  153. @router.get("/id/{id}/valves", response_model=Optional[dict])
  154. async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
  155. toolkit = Tools.get_tool_by_id(id)
  156. if toolkit:
  157. try:
  158. valves = Tools.get_tool_valves_by_id(id)
  159. return valves
  160. except Exception as e:
  161. raise HTTPException(
  162. status_code=status.HTTP_400_BAD_REQUEST,
  163. detail=ERROR_MESSAGES.DEFAULT(e),
  164. )
  165. else:
  166. raise HTTPException(
  167. status_code=status.HTTP_401_UNAUTHORIZED,
  168. detail=ERROR_MESSAGES.NOT_FOUND,
  169. )
  170. ############################
  171. # GetToolValvesSpec
  172. ############################
  173. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  174. async def get_toolkit_valves_spec_by_id(
  175. request: Request, id: str, user=Depends(get_admin_user)
  176. ):
  177. toolkit = Tools.get_tool_by_id(id)
  178. if toolkit:
  179. if id in request.app.state.TOOLS:
  180. toolkit_module = request.app.state.TOOLS[id]
  181. else:
  182. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  183. request.app.state.TOOLS[id] = toolkit_module
  184. if hasattr(toolkit_module, "Valves"):
  185. Valves = toolkit_module.Valves
  186. return Valves.schema()
  187. return None
  188. else:
  189. raise HTTPException(
  190. status_code=status.HTTP_401_UNAUTHORIZED,
  191. detail=ERROR_MESSAGES.NOT_FOUND,
  192. )
  193. ############################
  194. # UpdateToolValves
  195. ############################
  196. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  197. async def update_toolkit_valves_by_id(
  198. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  199. ):
  200. toolkit = Tools.get_tool_by_id(id)
  201. if toolkit:
  202. if id in request.app.state.TOOLS:
  203. toolkit_module = request.app.state.TOOLS[id]
  204. else:
  205. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  206. request.app.state.TOOLS[id] = toolkit_module
  207. if hasattr(toolkit_module, "Valves"):
  208. Valves = toolkit_module.Valves
  209. try:
  210. form_data = {k: v for k, v in form_data.items() if v is not None}
  211. valves = Valves(**form_data)
  212. Tools.update_tool_valves_by_id(id, valves.model_dump())
  213. return valves.model_dump()
  214. except Exception as e:
  215. print(e)
  216. raise HTTPException(
  217. status_code=status.HTTP_400_BAD_REQUEST,
  218. detail=ERROR_MESSAGES.DEFAULT(e),
  219. )
  220. else:
  221. raise HTTPException(
  222. status_code=status.HTTP_401_UNAUTHORIZED,
  223. detail=ERROR_MESSAGES.NOT_FOUND,
  224. )
  225. else:
  226. raise HTTPException(
  227. status_code=status.HTTP_401_UNAUTHORIZED,
  228. detail=ERROR_MESSAGES.NOT_FOUND,
  229. )
  230. ############################
  231. # ToolUserValves
  232. ############################
  233. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  234. async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  235. toolkit = Tools.get_tool_by_id(id)
  236. if toolkit:
  237. try:
  238. user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
  239. return user_valves
  240. except Exception as e:
  241. raise HTTPException(
  242. status_code=status.HTTP_400_BAD_REQUEST,
  243. detail=ERROR_MESSAGES.DEFAULT(e),
  244. )
  245. else:
  246. raise HTTPException(
  247. status_code=status.HTTP_401_UNAUTHORIZED,
  248. detail=ERROR_MESSAGES.NOT_FOUND,
  249. )
  250. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  251. async def get_toolkit_user_valves_spec_by_id(
  252. request: Request, id: str, user=Depends(get_verified_user)
  253. ):
  254. toolkit = Tools.get_tool_by_id(id)
  255. if toolkit:
  256. if id in request.app.state.TOOLS:
  257. toolkit_module = request.app.state.TOOLS[id]
  258. else:
  259. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  260. request.app.state.TOOLS[id] = toolkit_module
  261. if hasattr(toolkit_module, "UserValves"):
  262. UserValves = toolkit_module.UserValves
  263. return UserValves.schema()
  264. return None
  265. else:
  266. raise HTTPException(
  267. status_code=status.HTTP_401_UNAUTHORIZED,
  268. detail=ERROR_MESSAGES.NOT_FOUND,
  269. )
  270. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  271. async def update_toolkit_user_valves_by_id(
  272. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  273. ):
  274. toolkit = Tools.get_tool_by_id(id)
  275. if toolkit:
  276. if id in request.app.state.TOOLS:
  277. toolkit_module = request.app.state.TOOLS[id]
  278. else:
  279. toolkit_module, frontmatter = load_toolkit_module_by_id(id)
  280. request.app.state.TOOLS[id] = toolkit_module
  281. if hasattr(toolkit_module, "UserValves"):
  282. UserValves = toolkit_module.UserValves
  283. try:
  284. form_data = {k: v for k, v in form_data.items() if v is not None}
  285. user_valves = UserValves(**form_data)
  286. Tools.update_user_valves_by_id_and_user_id(
  287. id, user.id, user_valves.model_dump()
  288. )
  289. return user_valves.model_dump()
  290. except Exception as e:
  291. print(e)
  292. raise HTTPException(
  293. status_code=status.HTTP_400_BAD_REQUEST,
  294. detail=ERROR_MESSAGES.DEFAULT(e),
  295. )
  296. else:
  297. raise HTTPException(
  298. status_code=status.HTTP_401_UNAUTHORIZED,
  299. detail=ERROR_MESSAGES.NOT_FOUND,
  300. )
  301. else:
  302. raise HTTPException(
  303. status_code=status.HTTP_401_UNAUTHORIZED,
  304. detail=ERROR_MESSAGES.NOT_FOUND,
  305. )