functions.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. import os
  2. from pathlib import Path
  3. from typing import Optional
  4. from open_webui.apps.webui.models.functions import (
  5. FunctionForm,
  6. FunctionModel,
  7. FunctionResponse,
  8. Functions,
  9. )
  10. from open_webui.apps.webui.utils import load_function_module_by_id
  11. from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
  12. from open_webui.constants import ERROR_MESSAGES
  13. from fastapi import APIRouter, Depends, HTTPException, Request, status
  14. from open_webui.utils.utils import get_admin_user, get_verified_user
  15. router = APIRouter()
  16. ############################
  17. # GetFunctions
  18. ############################
  19. @router.get("/", response_model=list[FunctionResponse])
  20. async def get_functions(user=Depends(get_verified_user)):
  21. return Functions.get_functions()
  22. ############################
  23. # ExportFunctions
  24. ############################
  25. @router.get("/export", response_model=list[FunctionModel])
  26. async def get_functions(user=Depends(get_admin_user)):
  27. return Functions.get_functions()
  28. ############################
  29. # CreateNewFunction
  30. ############################
  31. @router.post("/create", response_model=Optional[FunctionResponse])
  32. async def create_new_function(
  33. request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
  34. ):
  35. if not form_data.id.isidentifier():
  36. raise HTTPException(
  37. status_code=status.HTTP_400_BAD_REQUEST,
  38. detail="Only alphanumeric characters and underscores are allowed in the id",
  39. )
  40. form_data.id = form_data.id.lower()
  41. function = Functions.get_function_by_id(form_data.id)
  42. if function is None:
  43. function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
  44. try:
  45. with open(function_path, "w") as function_file:
  46. function_file.write(form_data.content)
  47. function_module, function_type, frontmatter = load_function_module_by_id(
  48. form_data.id
  49. )
  50. form_data.meta.manifest = frontmatter
  51. FUNCTIONS = request.app.state.FUNCTIONS
  52. FUNCTIONS[form_data.id] = function_module
  53. function = Functions.insert_new_function(user.id, function_type, form_data)
  54. function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
  55. function_cache_dir.mkdir(parents=True, exist_ok=True)
  56. if function:
  57. return function
  58. else:
  59. raise HTTPException(
  60. status_code=status.HTTP_400_BAD_REQUEST,
  61. detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
  62. )
  63. except Exception as e:
  64. print(e)
  65. raise HTTPException(
  66. status_code=status.HTTP_400_BAD_REQUEST,
  67. detail=ERROR_MESSAGES.DEFAULT(e),
  68. )
  69. else:
  70. raise HTTPException(
  71. status_code=status.HTTP_400_BAD_REQUEST,
  72. detail=ERROR_MESSAGES.ID_TAKEN,
  73. )
  74. ############################
  75. # GetFunctionById
  76. ############################
  77. @router.get("/id/{id}", response_model=Optional[FunctionModel])
  78. async def get_function_by_id(id: str, user=Depends(get_admin_user)):
  79. function = Functions.get_function_by_id(id)
  80. if function:
  81. return function
  82. else:
  83. raise HTTPException(
  84. status_code=status.HTTP_401_UNAUTHORIZED,
  85. detail=ERROR_MESSAGES.NOT_FOUND,
  86. )
  87. ############################
  88. # ToggleFunctionById
  89. ############################
  90. @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
  91. async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
  92. function = Functions.get_function_by_id(id)
  93. if function:
  94. function = Functions.update_function_by_id(
  95. id, {"is_active": not function.is_active}
  96. )
  97. if function:
  98. return function
  99. else:
  100. raise HTTPException(
  101. status_code=status.HTTP_400_BAD_REQUEST,
  102. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  103. )
  104. else:
  105. raise HTTPException(
  106. status_code=status.HTTP_401_UNAUTHORIZED,
  107. detail=ERROR_MESSAGES.NOT_FOUND,
  108. )
  109. ############################
  110. # ToggleGlobalById
  111. ############################
  112. @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
  113. async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
  114. function = Functions.get_function_by_id(id)
  115. if function:
  116. function = Functions.update_function_by_id(
  117. id, {"is_global": not function.is_global}
  118. )
  119. if function:
  120. return function
  121. else:
  122. raise HTTPException(
  123. status_code=status.HTTP_400_BAD_REQUEST,
  124. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  125. )
  126. else:
  127. raise HTTPException(
  128. status_code=status.HTTP_401_UNAUTHORIZED,
  129. detail=ERROR_MESSAGES.NOT_FOUND,
  130. )
  131. ############################
  132. # UpdateFunctionById
  133. ############################
  134. @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
  135. async def update_function_by_id(
  136. request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
  137. ):
  138. function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
  139. try:
  140. with open(function_path, "w") as function_file:
  141. function_file.write(form_data.content)
  142. function_module, function_type, frontmatter = load_function_module_by_id(id)
  143. form_data.meta.manifest = frontmatter
  144. FUNCTIONS = request.app.state.FUNCTIONS
  145. FUNCTIONS[id] = function_module
  146. updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
  147. print(updated)
  148. function = Functions.update_function_by_id(id, updated)
  149. if function:
  150. return function
  151. else:
  152. raise HTTPException(
  153. status_code=status.HTTP_400_BAD_REQUEST,
  154. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  155. )
  156. except Exception as e:
  157. raise HTTPException(
  158. status_code=status.HTTP_400_BAD_REQUEST,
  159. detail=ERROR_MESSAGES.DEFAULT(e),
  160. )
  161. ############################
  162. # DeleteFunctionById
  163. ############################
  164. @router.delete("/id/{id}/delete", response_model=bool)
  165. async def delete_function_by_id(
  166. request: Request, id: str, user=Depends(get_admin_user)
  167. ):
  168. result = Functions.delete_function_by_id(id)
  169. if result:
  170. FUNCTIONS = request.app.state.FUNCTIONS
  171. if id in FUNCTIONS:
  172. del FUNCTIONS[id]
  173. # delete the function file
  174. function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
  175. try:
  176. os.remove(function_path)
  177. except Exception:
  178. pass
  179. return result
  180. ############################
  181. # GetFunctionValves
  182. ############################
  183. @router.get("/id/{id}/valves", response_model=Optional[dict])
  184. async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
  185. function = Functions.get_function_by_id(id)
  186. if function:
  187. try:
  188. valves = Functions.get_function_valves_by_id(id)
  189. return valves
  190. except Exception as e:
  191. raise HTTPException(
  192. status_code=status.HTTP_400_BAD_REQUEST,
  193. detail=ERROR_MESSAGES.DEFAULT(e),
  194. )
  195. else:
  196. raise HTTPException(
  197. status_code=status.HTTP_401_UNAUTHORIZED,
  198. detail=ERROR_MESSAGES.NOT_FOUND,
  199. )
  200. ############################
  201. # GetFunctionValvesSpec
  202. ############################
  203. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  204. async def get_function_valves_spec_by_id(
  205. request: Request, id: str, user=Depends(get_admin_user)
  206. ):
  207. function = Functions.get_function_by_id(id)
  208. if function:
  209. if id in request.app.state.FUNCTIONS:
  210. function_module = request.app.state.FUNCTIONS[id]
  211. else:
  212. function_module, function_type, frontmatter = load_function_module_by_id(id)
  213. request.app.state.FUNCTIONS[id] = function_module
  214. if hasattr(function_module, "Valves"):
  215. Valves = function_module.Valves
  216. return Valves.schema()
  217. return None
  218. else:
  219. raise HTTPException(
  220. status_code=status.HTTP_401_UNAUTHORIZED,
  221. detail=ERROR_MESSAGES.NOT_FOUND,
  222. )
  223. ############################
  224. # UpdateFunctionValves
  225. ############################
  226. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  227. async def update_function_valves_by_id(
  228. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  229. ):
  230. function = Functions.get_function_by_id(id)
  231. if function:
  232. if id in request.app.state.FUNCTIONS:
  233. function_module = request.app.state.FUNCTIONS[id]
  234. else:
  235. function_module, function_type, frontmatter = load_function_module_by_id(id)
  236. request.app.state.FUNCTIONS[id] = function_module
  237. if hasattr(function_module, "Valves"):
  238. Valves = function_module.Valves
  239. try:
  240. form_data = {k: v for k, v in form_data.items() if v is not None}
  241. valves = Valves(**form_data)
  242. Functions.update_function_valves_by_id(id, valves.model_dump())
  243. return valves.model_dump()
  244. except Exception as e:
  245. print(e)
  246. raise HTTPException(
  247. status_code=status.HTTP_400_BAD_REQUEST,
  248. detail=ERROR_MESSAGES.DEFAULT(e),
  249. )
  250. else:
  251. raise HTTPException(
  252. status_code=status.HTTP_401_UNAUTHORIZED,
  253. detail=ERROR_MESSAGES.NOT_FOUND,
  254. )
  255. else:
  256. raise HTTPException(
  257. status_code=status.HTTP_401_UNAUTHORIZED,
  258. detail=ERROR_MESSAGES.NOT_FOUND,
  259. )
  260. ############################
  261. # FunctionUserValves
  262. ############################
  263. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  264. async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  265. function = Functions.get_function_by_id(id)
  266. if function:
  267. try:
  268. user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
  269. return user_valves
  270. except Exception as e:
  271. raise HTTPException(
  272. status_code=status.HTTP_400_BAD_REQUEST,
  273. detail=ERROR_MESSAGES.DEFAULT(e),
  274. )
  275. else:
  276. raise HTTPException(
  277. status_code=status.HTTP_401_UNAUTHORIZED,
  278. detail=ERROR_MESSAGES.NOT_FOUND,
  279. )
  280. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  281. async def get_function_user_valves_spec_by_id(
  282. request: Request, id: str, user=Depends(get_verified_user)
  283. ):
  284. function = Functions.get_function_by_id(id)
  285. if function:
  286. if id in request.app.state.FUNCTIONS:
  287. function_module = request.app.state.FUNCTIONS[id]
  288. else:
  289. function_module, function_type, frontmatter = load_function_module_by_id(id)
  290. request.app.state.FUNCTIONS[id] = function_module
  291. if hasattr(function_module, "UserValves"):
  292. UserValves = function_module.UserValves
  293. return UserValves.schema()
  294. return None
  295. else:
  296. raise HTTPException(
  297. status_code=status.HTTP_401_UNAUTHORIZED,
  298. detail=ERROR_MESSAGES.NOT_FOUND,
  299. )
  300. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  301. async def update_function_user_valves_by_id(
  302. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  303. ):
  304. function = Functions.get_function_by_id(id)
  305. if function:
  306. if id in request.app.state.FUNCTIONS:
  307. function_module = request.app.state.FUNCTIONS[id]
  308. else:
  309. function_module, function_type, frontmatter = load_function_module_by_id(id)
  310. request.app.state.FUNCTIONS[id] = function_module
  311. if hasattr(function_module, "UserValves"):
  312. UserValves = function_module.UserValves
  313. try:
  314. form_data = {k: v for k, v in form_data.items() if v is not None}
  315. user_valves = UserValves(**form_data)
  316. Functions.update_user_valves_by_id_and_user_id(
  317. id, user.id, user_valves.model_dump()
  318. )
  319. return user_valves.model_dump()
  320. except Exception as e:
  321. print(e)
  322. raise HTTPException(
  323. status_code=status.HTTP_400_BAD_REQUEST,
  324. detail=ERROR_MESSAGES.DEFAULT(e),
  325. )
  326. else:
  327. raise HTTPException(
  328. status_code=status.HTTP_401_UNAUTHORIZED,
  329. detail=ERROR_MESSAGES.NOT_FOUND,
  330. )
  331. else:
  332. raise HTTPException(
  333. status_code=status.HTTP_401_UNAUTHORIZED,
  334. detail=ERROR_MESSAGES.NOT_FOUND,
  335. )