functions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
import logging
import os
from pathlib import Path
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status

from open_webui.models.functions import (
    FunctionForm,
    FunctionModel,
    FunctionResponse,
    Functions,
)
from open_webui.utils.plugin import load_function_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from open_webui.utils.auth import get_admin_user, get_verified_user
# Router on which every endpoint in this module registers itself.
router = APIRouter()
  16. ############################
  17. # GetFunctions
  18. ############################
  19. @router.get("/", response_model=list[FunctionResponse])
  20. async def get_functions(user=Depends(get_verified_user)):
  21. return Functions.get_functions()
  22. ############################
  23. # ExportFunctions
  24. ############################
  25. @router.get("/export", response_model=list[FunctionModel])
  26. async def get_functions(user=Depends(get_admin_user)):
  27. return Functions.get_functions()
  28. ############################
  29. # CreateNewFunction
  30. ############################
  31. @router.post("/create", response_model=Optional[FunctionResponse])
  32. async def create_new_function(
  33. request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
  34. ):
  35. if not form_data.id.isidentifier():
  36. raise HTTPException(
  37. status_code=status.HTTP_400_BAD_REQUEST,
  38. detail="Only alphanumeric characters and underscores are allowed in the id",
  39. )
  40. form_data.id = form_data.id.lower()
  41. function = Functions.get_function_by_id(form_data.id)
  42. if function is None:
  43. try:
  44. form_data.content = replace_imports(form_data.content)
  45. function_module, function_type, frontmatter = load_function_module_by_id(
  46. form_data.id,
  47. content=form_data.content,
  48. )
  49. form_data.meta.manifest = frontmatter
  50. FUNCTIONS = request.app.state.FUNCTIONS
  51. FUNCTIONS[form_data.id] = function_module
  52. function = Functions.insert_new_function(user.id, function_type, form_data)
  53. function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
  54. function_cache_dir.mkdir(parents=True, exist_ok=True)
  55. if function:
  56. return function
  57. else:
  58. raise HTTPException(
  59. status_code=status.HTTP_400_BAD_REQUEST,
  60. detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
  61. )
  62. except Exception as e:
  63. print(e)
  64. raise HTTPException(
  65. status_code=status.HTTP_400_BAD_REQUEST,
  66. detail=ERROR_MESSAGES.DEFAULT(e),
  67. )
  68. else:
  69. raise HTTPException(
  70. status_code=status.HTTP_400_BAD_REQUEST,
  71. detail=ERROR_MESSAGES.ID_TAKEN,
  72. )
  73. ############################
  74. # GetFunctionById
  75. ############################
  76. @router.get("/id/{id}", response_model=Optional[FunctionModel])
  77. async def get_function_by_id(id: str, user=Depends(get_admin_user)):
  78. function = Functions.get_function_by_id(id)
  79. if function:
  80. return function
  81. else:
  82. raise HTTPException(
  83. status_code=status.HTTP_401_UNAUTHORIZED,
  84. detail=ERROR_MESSAGES.NOT_FOUND,
  85. )
  86. ############################
  87. # ToggleFunctionById
  88. ############################
  89. @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
  90. async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
  91. function = Functions.get_function_by_id(id)
  92. if function:
  93. function = Functions.update_function_by_id(
  94. id, {"is_active": not function.is_active}
  95. )
  96. if function:
  97. return function
  98. else:
  99. raise HTTPException(
  100. status_code=status.HTTP_400_BAD_REQUEST,
  101. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  102. )
  103. else:
  104. raise HTTPException(
  105. status_code=status.HTTP_401_UNAUTHORIZED,
  106. detail=ERROR_MESSAGES.NOT_FOUND,
  107. )
  108. ############################
  109. # ToggleGlobalById
  110. ############################
  111. @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
  112. async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
  113. function = Functions.get_function_by_id(id)
  114. if function:
  115. function = Functions.update_function_by_id(
  116. id, {"is_global": not function.is_global}
  117. )
  118. if function:
  119. return function
  120. else:
  121. raise HTTPException(
  122. status_code=status.HTTP_400_BAD_REQUEST,
  123. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  124. )
  125. else:
  126. raise HTTPException(
  127. status_code=status.HTTP_401_UNAUTHORIZED,
  128. detail=ERROR_MESSAGES.NOT_FOUND,
  129. )
  130. ############################
  131. # UpdateFunctionById
  132. ############################
  133. @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
  134. async def update_function_by_id(
  135. request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
  136. ):
  137. try:
  138. form_data.content = replace_imports(form_data.content)
  139. function_module, function_type, frontmatter = load_function_module_by_id(
  140. id, content=form_data.content
  141. )
  142. form_data.meta.manifest = frontmatter
  143. FUNCTIONS = request.app.state.FUNCTIONS
  144. FUNCTIONS[id] = function_module
  145. updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
  146. print(updated)
  147. function = Functions.update_function_by_id(id, updated)
  148. if function:
  149. return function
  150. else:
  151. raise HTTPException(
  152. status_code=status.HTTP_400_BAD_REQUEST,
  153. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  154. )
  155. except Exception as e:
  156. raise HTTPException(
  157. status_code=status.HTTP_400_BAD_REQUEST,
  158. detail=ERROR_MESSAGES.DEFAULT(e),
  159. )
  160. ############################
  161. # DeleteFunctionById
  162. ############################
  163. @router.delete("/id/{id}/delete", response_model=bool)
  164. async def delete_function_by_id(
  165. request: Request, id: str, user=Depends(get_admin_user)
  166. ):
  167. result = Functions.delete_function_by_id(id)
  168. if result:
  169. FUNCTIONS = request.app.state.FUNCTIONS
  170. if id in FUNCTIONS:
  171. del FUNCTIONS[id]
  172. return result
  173. ############################
  174. # GetFunctionValves
  175. ############################
  176. @router.get("/id/{id}/valves", response_model=Optional[dict])
  177. async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
  178. function = Functions.get_function_by_id(id)
  179. if function:
  180. try:
  181. valves = Functions.get_function_valves_by_id(id)
  182. return valves
  183. except Exception as e:
  184. raise HTTPException(
  185. status_code=status.HTTP_400_BAD_REQUEST,
  186. detail=ERROR_MESSAGES.DEFAULT(e),
  187. )
  188. else:
  189. raise HTTPException(
  190. status_code=status.HTTP_401_UNAUTHORIZED,
  191. detail=ERROR_MESSAGES.NOT_FOUND,
  192. )
  193. ############################
  194. # GetFunctionValvesSpec
  195. ############################
  196. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  197. async def get_function_valves_spec_by_id(
  198. request: Request, id: str, user=Depends(get_admin_user)
  199. ):
  200. function = Functions.get_function_by_id(id)
  201. if function:
  202. if id in request.app.state.FUNCTIONS:
  203. function_module = request.app.state.FUNCTIONS[id]
  204. else:
  205. function_module, function_type, frontmatter = load_function_module_by_id(id)
  206. request.app.state.FUNCTIONS[id] = function_module
  207. if hasattr(function_module, "Valves"):
  208. Valves = function_module.Valves
  209. return Valves.schema()
  210. return None
  211. else:
  212. raise HTTPException(
  213. status_code=status.HTTP_401_UNAUTHORIZED,
  214. detail=ERROR_MESSAGES.NOT_FOUND,
  215. )
  216. ############################
  217. # UpdateFunctionValves
  218. ############################
  219. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  220. async def update_function_valves_by_id(
  221. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  222. ):
  223. function = Functions.get_function_by_id(id)
  224. if function:
  225. if id in request.app.state.FUNCTIONS:
  226. function_module = request.app.state.FUNCTIONS[id]
  227. else:
  228. function_module, function_type, frontmatter = load_function_module_by_id(id)
  229. request.app.state.FUNCTIONS[id] = function_module
  230. if hasattr(function_module, "Valves"):
  231. Valves = function_module.Valves
  232. try:
  233. form_data = {k: v for k, v in form_data.items() if v is not None}
  234. valves = Valves(**form_data)
  235. Functions.update_function_valves_by_id(id, valves.model_dump())
  236. return valves.model_dump()
  237. except Exception as e:
  238. print(e)
  239. raise HTTPException(
  240. status_code=status.HTTP_400_BAD_REQUEST,
  241. detail=ERROR_MESSAGES.DEFAULT(e),
  242. )
  243. else:
  244. raise HTTPException(
  245. status_code=status.HTTP_401_UNAUTHORIZED,
  246. detail=ERROR_MESSAGES.NOT_FOUND,
  247. )
  248. else:
  249. raise HTTPException(
  250. status_code=status.HTTP_401_UNAUTHORIZED,
  251. detail=ERROR_MESSAGES.NOT_FOUND,
  252. )
  253. ############################
  254. # FunctionUserValves
  255. ############################
  256. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  257. async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  258. function = Functions.get_function_by_id(id)
  259. if function:
  260. try:
  261. user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
  262. return user_valves
  263. except Exception as e:
  264. raise HTTPException(
  265. status_code=status.HTTP_400_BAD_REQUEST,
  266. detail=ERROR_MESSAGES.DEFAULT(e),
  267. )
  268. else:
  269. raise HTTPException(
  270. status_code=status.HTTP_401_UNAUTHORIZED,
  271. detail=ERROR_MESSAGES.NOT_FOUND,
  272. )
  273. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  274. async def get_function_user_valves_spec_by_id(
  275. request: Request, id: str, user=Depends(get_verified_user)
  276. ):
  277. function = Functions.get_function_by_id(id)
  278. if function:
  279. if id in request.app.state.FUNCTIONS:
  280. function_module = request.app.state.FUNCTIONS[id]
  281. else:
  282. function_module, function_type, frontmatter = load_function_module_by_id(id)
  283. request.app.state.FUNCTIONS[id] = function_module
  284. if hasattr(function_module, "UserValves"):
  285. UserValves = function_module.UserValves
  286. return UserValves.schema()
  287. return None
  288. else:
  289. raise HTTPException(
  290. status_code=status.HTTP_401_UNAUTHORIZED,
  291. detail=ERROR_MESSAGES.NOT_FOUND,
  292. )
  293. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  294. async def update_function_user_valves_by_id(
  295. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  296. ):
  297. function = Functions.get_function_by_id(id)
  298. if function:
  299. if id in request.app.state.FUNCTIONS:
  300. function_module = request.app.state.FUNCTIONS[id]
  301. else:
  302. function_module, function_type, frontmatter = load_function_module_by_id(id)
  303. request.app.state.FUNCTIONS[id] = function_module
  304. if hasattr(function_module, "UserValves"):
  305. UserValves = function_module.UserValves
  306. try:
  307. form_data = {k: v for k, v in form_data.items() if v is not None}
  308. user_valves = UserValves(**form_data)
  309. Functions.update_user_valves_by_id_and_user_id(
  310. id, user.id, user_valves.model_dump()
  311. )
  312. return user_valves.model_dump()
  313. except Exception as e:
  314. print(e)
  315. raise HTTPException(
  316. status_code=status.HTTP_400_BAD_REQUEST,
  317. detail=ERROR_MESSAGES.DEFAULT(e),
  318. )
  319. else:
  320. raise HTTPException(
  321. status_code=status.HTTP_401_UNAUTHORIZED,
  322. detail=ERROR_MESSAGES.NOT_FOUND,
  323. )
  324. else:
  325. raise HTTPException(
  326. status_code=status.HTTP_401_UNAUTHORIZED,
  327. detail=ERROR_MESSAGES.NOT_FOUND,
  328. )