functions.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. from fastapi import Depends, FastAPI, HTTPException, status, Request
  2. from datetime import datetime, timedelta
  3. from typing import Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.webui.models.functions import (
  8. Functions,
  9. FunctionForm,
  10. FunctionModel,
  11. FunctionResponse,
  12. )
  13. from apps.webui.utils import load_function_module_by_id
  14. from utils.utils import get_verified_user, get_admin_user
  15. from constants import ERROR_MESSAGES
  16. from importlib import util
  17. import os
  18. from pathlib import Path
  19. from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
# Router holding the function-management endpoints defined below. It is
# presumably included into the web UI app elsewhere (mount prefix not visible
# in this file).
router = APIRouter()
  21. ############################
  22. # GetFunctions
  23. ############################
  24. @router.get("/", response_model=list[FunctionResponse])
  25. async def get_functions(user=Depends(get_verified_user)):
  26. return Functions.get_functions()
  27. ############################
  28. # ExportFunctions
  29. ############################
  30. @router.get("/export", response_model=list[FunctionModel])
  31. async def get_functions(user=Depends(get_admin_user)):
  32. return Functions.get_functions()
  33. ############################
  34. # CreateNewFunction
  35. ############################
  36. @router.post("/create", response_model=Optional[FunctionResponse])
  37. async def create_new_function(
  38. request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
  39. ):
  40. if not form_data.id.isidentifier():
  41. raise HTTPException(
  42. status_code=status.HTTP_400_BAD_REQUEST,
  43. detail="Only alphanumeric characters and underscores are allowed in the id",
  44. )
  45. form_data.id = form_data.id.lower()
  46. function = Functions.get_function_by_id(form_data.id)
  47. if function is None:
  48. function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
  49. try:
  50. with open(function_path, "w") as function_file:
  51. function_file.write(form_data.content)
  52. function_module, function_type, frontmatter = load_function_module_by_id(
  53. form_data.id
  54. )
  55. form_data.meta.manifest = frontmatter
  56. FUNCTIONS = request.app.state.FUNCTIONS
  57. FUNCTIONS[form_data.id] = function_module
  58. function = Functions.insert_new_function(user.id, function_type, form_data)
  59. function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
  60. function_cache_dir.mkdir(parents=True, exist_ok=True)
  61. if function:
  62. return function
  63. else:
  64. raise HTTPException(
  65. status_code=status.HTTP_400_BAD_REQUEST,
  66. detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
  67. )
  68. except Exception as e:
  69. print(e)
  70. raise HTTPException(
  71. status_code=status.HTTP_400_BAD_REQUEST,
  72. detail=ERROR_MESSAGES.DEFAULT(e),
  73. )
  74. else:
  75. raise HTTPException(
  76. status_code=status.HTTP_400_BAD_REQUEST,
  77. detail=ERROR_MESSAGES.ID_TAKEN,
  78. )
  79. ############################
  80. # GetFunctionById
  81. ############################
  82. @router.get("/id/{id}", response_model=Optional[FunctionModel])
  83. async def get_function_by_id(id: str, user=Depends(get_admin_user)):
  84. function = Functions.get_function_by_id(id)
  85. if function:
  86. return function
  87. else:
  88. raise HTTPException(
  89. status_code=status.HTTP_401_UNAUTHORIZED,
  90. detail=ERROR_MESSAGES.NOT_FOUND,
  91. )
  92. ############################
  93. # ToggleFunctionById
  94. ############################
  95. @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
  96. async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
  97. function = Functions.get_function_by_id(id)
  98. if function:
  99. function = Functions.update_function_by_id(
  100. id, {"is_active": not function.is_active}
  101. )
  102. if function:
  103. return function
  104. else:
  105. raise HTTPException(
  106. status_code=status.HTTP_400_BAD_REQUEST,
  107. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  108. )
  109. else:
  110. raise HTTPException(
  111. status_code=status.HTTP_401_UNAUTHORIZED,
  112. detail=ERROR_MESSAGES.NOT_FOUND,
  113. )
  114. ############################
  115. # ToggleGlobalById
  116. ############################
  117. @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
  118. async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
  119. function = Functions.get_function_by_id(id)
  120. if function:
  121. function = Functions.update_function_by_id(
  122. id, {"is_global": not function.is_global}
  123. )
  124. if function:
  125. return function
  126. else:
  127. raise HTTPException(
  128. status_code=status.HTTP_400_BAD_REQUEST,
  129. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  130. )
  131. else:
  132. raise HTTPException(
  133. status_code=status.HTTP_401_UNAUTHORIZED,
  134. detail=ERROR_MESSAGES.NOT_FOUND,
  135. )
  136. ############################
  137. # UpdateFunctionById
  138. ############################
  139. @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
  140. async def update_function_by_id(
  141. request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
  142. ):
  143. function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
  144. try:
  145. with open(function_path, "w") as function_file:
  146. function_file.write(form_data.content)
  147. function_module, function_type, frontmatter = load_function_module_by_id(id)
  148. form_data.meta.manifest = frontmatter
  149. FUNCTIONS = request.app.state.FUNCTIONS
  150. FUNCTIONS[id] = function_module
  151. updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
  152. print(updated)
  153. function = Functions.update_function_by_id(id, updated)
  154. if function:
  155. return function
  156. else:
  157. raise HTTPException(
  158. status_code=status.HTTP_400_BAD_REQUEST,
  159. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  160. )
  161. except Exception as e:
  162. raise HTTPException(
  163. status_code=status.HTTP_400_BAD_REQUEST,
  164. detail=ERROR_MESSAGES.DEFAULT(e),
  165. )
  166. ############################
  167. # DeleteFunctionById
  168. ############################
  169. @router.delete("/id/{id}/delete", response_model=bool)
  170. async def delete_function_by_id(
  171. request: Request, id: str, user=Depends(get_admin_user)
  172. ):
  173. result = Functions.delete_function_by_id(id)
  174. if result:
  175. FUNCTIONS = request.app.state.FUNCTIONS
  176. if id in FUNCTIONS:
  177. del FUNCTIONS[id]
  178. # delete the function file
  179. function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
  180. try:
  181. os.remove(function_path)
  182. except Exception:
  183. pass
  184. return result
  185. ############################
  186. # GetFunctionValves
  187. ############################
  188. @router.get("/id/{id}/valves", response_model=Optional[dict])
  189. async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
  190. function = Functions.get_function_by_id(id)
  191. if function:
  192. try:
  193. valves = Functions.get_function_valves_by_id(id)
  194. return valves
  195. except Exception as e:
  196. raise HTTPException(
  197. status_code=status.HTTP_400_BAD_REQUEST,
  198. detail=ERROR_MESSAGES.DEFAULT(e),
  199. )
  200. else:
  201. raise HTTPException(
  202. status_code=status.HTTP_401_UNAUTHORIZED,
  203. detail=ERROR_MESSAGES.NOT_FOUND,
  204. )
  205. ############################
  206. # GetFunctionValvesSpec
  207. ############################
  208. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  209. async def get_function_valves_spec_by_id(
  210. request: Request, id: str, user=Depends(get_admin_user)
  211. ):
  212. function = Functions.get_function_by_id(id)
  213. if function:
  214. if id in request.app.state.FUNCTIONS:
  215. function_module = request.app.state.FUNCTIONS[id]
  216. else:
  217. function_module, function_type, frontmatter = load_function_module_by_id(id)
  218. request.app.state.FUNCTIONS[id] = function_module
  219. if hasattr(function_module, "Valves"):
  220. Valves = function_module.Valves
  221. return Valves.schema()
  222. return None
  223. else:
  224. raise HTTPException(
  225. status_code=status.HTTP_401_UNAUTHORIZED,
  226. detail=ERROR_MESSAGES.NOT_FOUND,
  227. )
  228. ############################
  229. # UpdateFunctionValves
  230. ############################
  231. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  232. async def update_function_valves_by_id(
  233. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  234. ):
  235. function = Functions.get_function_by_id(id)
  236. if function:
  237. if id in request.app.state.FUNCTIONS:
  238. function_module = request.app.state.FUNCTIONS[id]
  239. else:
  240. function_module, function_type, frontmatter = load_function_module_by_id(id)
  241. request.app.state.FUNCTIONS[id] = function_module
  242. if hasattr(function_module, "Valves"):
  243. Valves = function_module.Valves
  244. try:
  245. form_data = {k: v for k, v in form_data.items() if v is not None}
  246. valves = Valves(**form_data)
  247. Functions.update_function_valves_by_id(id, valves.model_dump())
  248. return valves.model_dump()
  249. except Exception as e:
  250. print(e)
  251. raise HTTPException(
  252. status_code=status.HTTP_400_BAD_REQUEST,
  253. detail=ERROR_MESSAGES.DEFAULT(e),
  254. )
  255. else:
  256. raise HTTPException(
  257. status_code=status.HTTP_401_UNAUTHORIZED,
  258. detail=ERROR_MESSAGES.NOT_FOUND,
  259. )
  260. else:
  261. raise HTTPException(
  262. status_code=status.HTTP_401_UNAUTHORIZED,
  263. detail=ERROR_MESSAGES.NOT_FOUND,
  264. )
  265. ############################
  266. # FunctionUserValves
  267. ############################
  268. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  269. async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  270. function = Functions.get_function_by_id(id)
  271. if function:
  272. try:
  273. user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
  274. return user_valves
  275. except Exception as e:
  276. raise HTTPException(
  277. status_code=status.HTTP_400_BAD_REQUEST,
  278. detail=ERROR_MESSAGES.DEFAULT(e),
  279. )
  280. else:
  281. raise HTTPException(
  282. status_code=status.HTTP_401_UNAUTHORIZED,
  283. detail=ERROR_MESSAGES.NOT_FOUND,
  284. )
  285. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  286. async def get_function_user_valves_spec_by_id(
  287. request: Request, id: str, user=Depends(get_verified_user)
  288. ):
  289. function = Functions.get_function_by_id(id)
  290. if function:
  291. if id in request.app.state.FUNCTIONS:
  292. function_module = request.app.state.FUNCTIONS[id]
  293. else:
  294. function_module, function_type, frontmatter = load_function_module_by_id(id)
  295. request.app.state.FUNCTIONS[id] = function_module
  296. if hasattr(function_module, "UserValves"):
  297. UserValves = function_module.UserValves
  298. return UserValves.schema()
  299. return None
  300. else:
  301. raise HTTPException(
  302. status_code=status.HTTP_401_UNAUTHORIZED,
  303. detail=ERROR_MESSAGES.NOT_FOUND,
  304. )
  305. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  306. async def update_function_user_valves_by_id(
  307. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  308. ):
  309. function = Functions.get_function_by_id(id)
  310. if function:
  311. if id in request.app.state.FUNCTIONS:
  312. function_module = request.app.state.FUNCTIONS[id]
  313. else:
  314. function_module, function_type, frontmatter = load_function_module_by_id(id)
  315. request.app.state.FUNCTIONS[id] = function_module
  316. if hasattr(function_module, "UserValves"):
  317. UserValves = function_module.UserValves
  318. try:
  319. form_data = {k: v for k, v in form_data.items() if v is not None}
  320. user_valves = UserValves(**form_data)
  321. Functions.update_user_valves_by_id_and_user_id(
  322. id, user.id, user_valves.model_dump()
  323. )
  324. return user_valves.model_dump()
  325. except Exception as e:
  326. print(e)
  327. raise HTTPException(
  328. status_code=status.HTTP_400_BAD_REQUEST,
  329. detail=ERROR_MESSAGES.DEFAULT(e),
  330. )
  331. else:
  332. raise HTTPException(
  333. status_code=status.HTTP_401_UNAUTHORIZED,
  334. detail=ERROR_MESSAGES.NOT_FOUND,
  335. )
  336. else:
  337. raise HTTPException(
  338. status_code=status.HTTP_401_UNAUTHORIZED,
  339. detail=ERROR_MESSAGES.NOT_FOUND,
  340. )