functions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. from fastapi import Depends, FastAPI, HTTPException, status, Request
  2. from datetime import datetime, timedelta
  3. from typing import List, Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.webui.internal.db import get_db
  8. from apps.webui.models.functions import (
  9. Functions,
  10. FunctionForm,
  11. FunctionModel,
  12. FunctionResponse,
  13. )
  14. from apps.webui.utils import load_function_module_by_id
  15. from utils.utils import get_verified_user, get_admin_user
  16. from constants import ERROR_MESSAGES
  17. from importlib import util
  18. import os
  19. from pathlib import Path
  20. from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
# Router that collects every function-management endpoint defined below.
# NOTE(review): presumably mounted by the web UI app under a functions prefix;
# confirm the prefix in the application setup.
router = APIRouter()
  22. ############################
  23. # GetFunctions
  24. ############################
  25. @router.get("/", response_model=List[FunctionResponse])
  26. async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
  27. return Functions.get_functions(db)
  28. ############################
  29. # ExportFunctions
  30. ############################
  31. @router.get("/export", response_model=List[FunctionModel])
  32. async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
  33. return Functions.get_functions(db)
  34. ############################
  35. # CreateNewFunction
  36. ############################
  37. @router.post("/create", response_model=Optional[FunctionResponse])
  38. async def create_new_function(
  39. request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
  40. ):
  41. if not form_data.id.isidentifier():
  42. raise HTTPException(
  43. status_code=status.HTTP_400_BAD_REQUEST,
  44. detail="Only alphanumeric characters and underscores are allowed in the id",
  45. )
  46. form_data.id = form_data.id.lower()
  47. function = Functions.get_function_by_id(db, form_data.id)
  48. if function == None:
  49. function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
  50. try:
  51. with open(function_path, "w") as function_file:
  52. function_file.write(form_data.content)
  53. function_module, function_type, frontmatter = load_function_module_by_id(
  54. form_data.id
  55. )
  56. form_data.meta.manifest = frontmatter
  57. FUNCTIONS = request.app.state.FUNCTIONS
  58. FUNCTIONS[form_data.id] = function_module
  59. function = Functions.insert_new_function(db, user.id, function_type, form_data)
  60. function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
  61. function_cache_dir.mkdir(parents=True, exist_ok=True)
  62. if function:
  63. return function
  64. else:
  65. raise HTTPException(
  66. status_code=status.HTTP_400_BAD_REQUEST,
  67. detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
  68. )
  69. except Exception as e:
  70. print(e)
  71. raise HTTPException(
  72. status_code=status.HTTP_400_BAD_REQUEST,
  73. detail=ERROR_MESSAGES.DEFAULT(e),
  74. )
  75. else:
  76. raise HTTPException(
  77. status_code=status.HTTP_400_BAD_REQUEST,
  78. detail=ERROR_MESSAGES.ID_TAKEN,
  79. )
  80. ############################
  81. # GetFunctionById
  82. ############################
  83. @router.get("/id/{id}", response_model=Optional[FunctionModel])
  84. async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
  85. function = Functions.get_function_by_id(db, id)
  86. if function:
  87. return function
  88. else:
  89. raise HTTPException(
  90. status_code=status.HTTP_401_UNAUTHORIZED,
  91. detail=ERROR_MESSAGES.NOT_FOUND,
  92. )
  93. ############################
  94. # ToggleFunctionById
  95. ############################
  96. @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
  97. async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
  98. function = Functions.get_function_by_id(id)
  99. if function:
  100. function = Functions.update_function_by_id(
  101. id, {"is_active": not function.is_active}
  102. )
  103. if function:
  104. return function
  105. else:
  106. raise HTTPException(
  107. status_code=status.HTTP_400_BAD_REQUEST,
  108. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  109. )
  110. else:
  111. raise HTTPException(
  112. status_code=status.HTTP_401_UNAUTHORIZED,
  113. detail=ERROR_MESSAGES.NOT_FOUND,
  114. )
  115. ############################
  116. # UpdateFunctionById
  117. ############################
  118. @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
  119. async def update_function_by_id(
  120. request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
  121. ):
  122. function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
  123. try:
  124. with open(function_path, "w") as function_file:
  125. function_file.write(form_data.content)
  126. function_module, function_type, frontmatter = load_function_module_by_id(id)
  127. form_data.meta.manifest = frontmatter
  128. FUNCTIONS = request.app.state.FUNCTIONS
  129. FUNCTIONS[id] = function_module
  130. updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
  131. print(updated)
  132. function = Functions.update_function_by_id(db, id, updated)
  133. if function:
  134. return function
  135. else:
  136. raise HTTPException(
  137. status_code=status.HTTP_400_BAD_REQUEST,
  138. detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
  139. )
  140. except Exception as e:
  141. raise HTTPException(
  142. status_code=status.HTTP_400_BAD_REQUEST,
  143. detail=ERROR_MESSAGES.DEFAULT(e),
  144. )
  145. ############################
  146. # DeleteFunctionById
  147. ############################
  148. @router.delete("/id/{id}/delete", response_model=bool)
  149. async def delete_function_by_id(
  150. request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
  151. ):
  152. result = Functions.delete_function_by_id(db, id)
  153. if result:
  154. FUNCTIONS = request.app.state.FUNCTIONS
  155. if id in FUNCTIONS:
  156. del FUNCTIONS[id]
  157. # delete the function file
  158. function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
  159. os.remove(function_path)
  160. return result
  161. ############################
  162. # GetFunctionValves
  163. ############################
  164. @router.get("/id/{id}/valves", response_model=Optional[dict])
  165. async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
  166. function = Functions.get_function_by_id(id)
  167. if function:
  168. try:
  169. valves = Functions.get_function_valves_by_id(id)
  170. return valves
  171. except Exception as e:
  172. raise HTTPException(
  173. status_code=status.HTTP_400_BAD_REQUEST,
  174. detail=ERROR_MESSAGES.DEFAULT(e),
  175. )
  176. else:
  177. raise HTTPException(
  178. status_code=status.HTTP_401_UNAUTHORIZED,
  179. detail=ERROR_MESSAGES.NOT_FOUND,
  180. )
  181. ############################
  182. # GetFunctionValvesSpec
  183. ############################
  184. @router.get("/id/{id}/valves/spec", response_model=Optional[dict])
  185. async def get_function_valves_spec_by_id(
  186. request: Request, id: str, user=Depends(get_admin_user)
  187. ):
  188. function = Functions.get_function_by_id(id)
  189. if function:
  190. if id in request.app.state.FUNCTIONS:
  191. function_module = request.app.state.FUNCTIONS[id]
  192. else:
  193. function_module, function_type, frontmatter = load_function_module_by_id(id)
  194. request.app.state.FUNCTIONS[id] = function_module
  195. if hasattr(function_module, "Valves"):
  196. Valves = function_module.Valves
  197. return Valves.schema()
  198. return None
  199. else:
  200. raise HTTPException(
  201. status_code=status.HTTP_401_UNAUTHORIZED,
  202. detail=ERROR_MESSAGES.NOT_FOUND,
  203. )
  204. ############################
  205. # UpdateFunctionValves
  206. ############################
  207. @router.post("/id/{id}/valves/update", response_model=Optional[dict])
  208. async def update_function_valves_by_id(
  209. request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
  210. ):
  211. function = Functions.get_function_by_id(id)
  212. if function:
  213. if id in request.app.state.FUNCTIONS:
  214. function_module = request.app.state.FUNCTIONS[id]
  215. else:
  216. function_module, function_type, frontmatter = load_function_module_by_id(id)
  217. request.app.state.FUNCTIONS[id] = function_module
  218. if hasattr(function_module, "Valves"):
  219. Valves = function_module.Valves
  220. try:
  221. form_data = {k: v for k, v in form_data.items() if v is not None}
  222. valves = Valves(**form_data)
  223. Functions.update_function_valves_by_id(id, valves.model_dump())
  224. return valves.model_dump()
  225. except Exception as e:
  226. print(e)
  227. raise HTTPException(
  228. status_code=status.HTTP_400_BAD_REQUEST,
  229. detail=ERROR_MESSAGES.DEFAULT(e),
  230. )
  231. else:
  232. raise HTTPException(
  233. status_code=status.HTTP_401_UNAUTHORIZED,
  234. detail=ERROR_MESSAGES.NOT_FOUND,
  235. )
  236. else:
  237. raise HTTPException(
  238. status_code=status.HTTP_401_UNAUTHORIZED,
  239. detail=ERROR_MESSAGES.NOT_FOUND,
  240. )
  241. ############################
  242. # FunctionUserValves
  243. ############################
  244. @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  245. async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
  246. function = Functions.get_function_by_id(id)
  247. if function:
  248. try:
  249. user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
  250. return user_valves
  251. except Exception as e:
  252. raise HTTPException(
  253. status_code=status.HTTP_400_BAD_REQUEST,
  254. detail=ERROR_MESSAGES.DEFAULT(e),
  255. )
  256. else:
  257. raise HTTPException(
  258. status_code=status.HTTP_401_UNAUTHORIZED,
  259. detail=ERROR_MESSAGES.NOT_FOUND,
  260. )
  261. @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  262. async def get_function_user_valves_spec_by_id(
  263. request: Request, id: str, user=Depends(get_verified_user)
  264. ):
  265. function = Functions.get_function_by_id(id)
  266. if function:
  267. if id in request.app.state.FUNCTIONS:
  268. function_module = request.app.state.FUNCTIONS[id]
  269. else:
  270. function_module, function_type, frontmatter = load_function_module_by_id(id)
  271. request.app.state.FUNCTIONS[id] = function_module
  272. if hasattr(function_module, "UserValves"):
  273. UserValves = function_module.UserValves
  274. return UserValves.schema()
  275. return None
  276. else:
  277. raise HTTPException(
  278. status_code=status.HTTP_401_UNAUTHORIZED,
  279. detail=ERROR_MESSAGES.NOT_FOUND,
  280. )
  281. @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  282. async def update_function_user_valves_by_id(
  283. request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  284. ):
  285. function = Functions.get_function_by_id(id)
  286. if function:
  287. if id in request.app.state.FUNCTIONS:
  288. function_module = request.app.state.FUNCTIONS[id]
  289. else:
  290. function_module, function_type, frontmatter = load_function_module_by_id(id)
  291. request.app.state.FUNCTIONS[id] = function_module
  292. if hasattr(function_module, "UserValves"):
  293. UserValves = function_module.UserValves
  294. try:
  295. form_data = {k: v for k, v in form_data.items() if v is not None}
  296. user_valves = UserValves(**form_data)
  297. Functions.update_user_valves_by_id_and_user_id(
  298. id, user.id, user_valves.model_dump()
  299. )
  300. return user_valves.model_dump()
  301. except Exception as e:
  302. print(e)
  303. raise HTTPException(
  304. status_code=status.HTTP_400_BAD_REQUEST,
  305. detail=ERROR_MESSAGES.DEFAULT(e),
  306. )
  307. else:
  308. raise HTTPException(
  309. status_code=status.HTTP_401_UNAUTHORIZED,
  310. detail=ERROR_MESSAGES.NOT_FOUND,
  311. )
  312. else:
  313. raise HTTPException(
  314. status_code=status.HTTP_401_UNAUTHORIZED,
  315. detail=ERROR_MESSAGES.NOT_FOUND,
  316. )