functions.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. from fastapi import Depends, FastAPI, HTTPException, status, Request
  2. from datetime import datetime, timedelta
  3. from typing import List, Union, Optional
  4. from fastapi import APIRouter
  5. from pydantic import BaseModel
  6. import json
  7. from apps.webui.models.functions import (
  8. Functions,
  9. FunctionForm,
  10. FunctionModel,
  11. FunctionResponse,
  12. )
  13. from apps.webui.utils import load_function_module_by_id
  14. from utils.utils import get_verified_user, get_admin_user
  15. from constants import ERROR_MESSAGES
  16. from importlib import util
  17. import os
  18. from pathlib import Path
  19. from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
  router = APIRouter()


  ############################
  # GetFunctions
  ############################


  @router.get("/", response_model=List[FunctionResponse])
  async def get_functions(user=Depends(get_verified_user)):
      """List every stored function, serialized with ``FunctionResponse``.

      Access is gated by the ``get_verified_user`` dependency; the resolved
      ``user`` value itself is not otherwise used here.
      """
      all_functions = Functions.get_functions()
      return all_functions
  ############################
  # ExportFunctions
  ############################


  @router.get("/export", response_model=List[FunctionModel])
  async def export_functions(user=Depends(get_admin_user)):
      """Return every stored function for export (admin only).

      Serialized with the full ``FunctionModel`` rather than the slimmer
      ``FunctionResponse`` used by the plain listing endpoint.

      NOTE: this handler used to be named ``get_functions`` as well, which
      silently rebound the module-level name of the listing handler and gave
      both routes the same FastAPI route name (derived from the function name).
      """
      return Functions.get_functions()
  ############################
  # CreateNewFunction
  ############################


  @router.post("/create", response_model=Optional[FunctionResponse])
  async def create_new_function(
      request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
  ):
      """Create a new function from submitted source code (admin only).

      The id is lower-cased. The source is written to ``FUNCTIONS_DIR/<id>.py``
      and loaded with ``load_function_module_by_id``. The module is registered in
      ``request.app.state.FUNCTIONS`` and the record is persisted with
      ``Functions.insert_new_function``. A cache directory
      ``CACHE_DIR/functions/<id>`` is also created.

      Raises:
          HTTPException(400): the id is not a valid Python identifier, the id
              is already taken, writing/loading/persisting failed, or no record
              was returned by ``insert_new_function``.
      """
      # isidentifier() also rejects "/" and ".", so the id is safe to use as a
      # file name below.
      if not form_data.id.isidentifier():
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Only alphanumeric characters and underscores are allowed in the id",
          )

      form_data.id = form_data.id.lower()

      if Functions.get_function_by_id(form_data.id) is not None:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.ID_TAKEN,
          )

      function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
      try:
          # Python decodes source files as UTF-8 by default, so write them as
          # UTF-8 regardless of the platform's locale encoding.
          with open(function_path, "w", encoding="utf-8") as function_file:
              function_file.write(form_data.content)
          function_module, function_type = load_function_module_by_id(form_data.id)
          request.app.state.FUNCTIONS[form_data.id] = function_module

          function = Functions.insert_new_function(user.id, function_type, form_data)

          function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
          function_cache_dir.mkdir(parents=True, exist_ok=True)
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          ) from e

      # Checked outside the try block so this error is returned as-is instead
      # of being caught above and re-wrapped into ERROR_MESSAGES.DEFAULT(...).
      if function:
          return function
      raise HTTPException(
          status_code=status.HTTP_400_BAD_REQUEST,
          detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
      )
  ############################
  # GetFunctionById
  ############################


  @router.get("/id/{id}", response_model=Optional[FunctionModel])
  async def get_function_by_id(id: str, user=Depends(get_admin_user)):
      """Return the full stored record of function ``id`` (admin only).

      Raises:
          HTTPException(401) carrying ERROR_MESSAGES.NOT_FOUND when no function
          matches ``id``. This router reports "not found" with status 401.
      """
      found = Functions.get_function_by_id(id)
      if not found:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )
      return found
  ############################
  # FunctionUserValves
  ############################


  @router.get("/id/{id}/valves/user", response_model=Optional[dict])
  async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
      """Return the calling user's stored valve values for function ``id``.

      Raises:
          HTTPException(401): no function with this id exists.
          HTTPException(400): reading the stored valves raised an error.
      """
      if not Functions.get_function_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      try:
          return Functions.get_user_valves_by_id_and_user_id(id, user.id)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  @router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
  async def get_function_user_valves_spec_by_id(
      request: Request, id: str, user=Depends(get_verified_user)
  ):
      """Return the JSON schema of function ``id``'s ``UserValves`` model.

      The function module comes from ``request.app.state.FUNCTIONS``. If it is
      not there yet, it is loaded with ``load_function_module_by_id`` and then
      cached. Returns ``None`` when the module defines no ``UserValves``.

      Raises:
          HTTPException(401): no function with this id exists.
      """
      # Look the id up in the functions table. This previously called
      # Functions.get_tool_by_id, a copy-paste from the tools router; every
      # other handler in this module uses get_function_by_id.
      if not Functions.get_function_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      functions_cache = request.app.state.FUNCTIONS
      if id in functions_cache:
          function_module = functions_cache[id]
      else:
          function_module, _ = load_function_module_by_id(id)
          functions_cache[id] = function_module

      if not hasattr(function_module, "UserValves"):
          return None
      # Pydantic v2 API, matching the model_dump() calls elsewhere in this
      # router; .schema() is the deprecated v1 spelling of the same call.
      return function_module.UserValves.model_json_schema()
  @router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
  async def update_function_user_valves_by_id(
      request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
  ):
      """Validate and store the calling user's valve values for function ``id``.

      ``form_data`` is validated by instantiating the module's ``UserValves``
      class. The validated values (``model_dump()``) are persisted with
      ``Functions.update_user_valves_by_id_and_user_id`` and returned.

      Raises:
          HTTPException(401): no function with this id exists, or its module
              defines no ``UserValves``.
          HTTPException(400): validation or persistence raised an error.
      """
      # Look the id up in the functions table. This previously called
      # Functions.get_tool_by_id, a copy-paste from the tools router.
      if not Functions.get_function_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      functions_cache = request.app.state.FUNCTIONS
      if id in functions_cache:
          function_module = functions_cache[id]
      else:
          function_module, _ = load_function_module_by_id(id)
          functions_cache[id] = function_module

      if not hasattr(function_module, "UserValves"):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      try:
          user_valves = function_module.UserValves(**form_data)
          valves_data = user_valves.model_dump()
          Functions.update_user_valves_by_id_and_user_id(id, user.id, valves_data)
      except Exception as e:
          print(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          ) from e
      return valves_data
  ############################
  # UpdateFunctionById
  ############################


  @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
  async def update_toolkit_by_id(
      request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
  ):
      """Replace the source code and stored fields of function ``id`` (admin only).

      Despite its name, this handler operates on functions. The new source
      overwrites ``FUNCTIONS_DIR/<id>.py`` and is reloaded into
      ``request.app.state.FUNCTIONS``. Every form field except ``id`` is then
      persisted, together with the detected function type.

      Raises:
          HTTPException(401): no function with this id exists.
          HTTPException(400): writing, loading or persisting failed, or no record
              was returned by ``update_function_by_id``.
      """
      # Refuse unknown ids up front. Otherwise an arbitrary <id>.py would be
      # written and loaded even though there is no record to update.
      if not Functions.get_function_by_id(id):
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=ERROR_MESSAGES.NOT_FOUND,
          )

      function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
      try:
          # Python decodes source files as UTF-8 by default.
          with open(function_path, "w", encoding="utf-8") as function_file:
              function_file.write(form_data.content)
          function_module, function_type = load_function_module_by_id(id)
          request.app.state.FUNCTIONS[id] = function_module

          # The path id is authoritative; form_data.id is ignored on update.
          updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
          function = Functions.update_function_by_id(id, updated)
      except Exception as e:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          ) from e

      # Checked outside the try block so this error is not re-wrapped above.
      if function:
          return function
      raise HTTPException(
          status_code=status.HTTP_400_BAD_REQUEST,
          detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
      )
  ############################
  # DeleteFunctionById
  ############################


  @router.delete("/id/{id}/delete", response_model=bool)
  async def delete_function_by_id(
      request: Request, id: str, user=Depends(get_admin_user)
  ):
      """Delete function ``id`` (admin only) and return the deletion result.

      If ``Functions.delete_function_by_id`` reports success, the loaded module
      is also dropped from ``request.app.state.FUNCTIONS``, and
      ``FUNCTIONS_DIR/<id>.py`` is deleted.
      """
      result = Functions.delete_function_by_id(id)
      if result:
          # The module may never have been loaded into the in-memory registry.
          request.app.state.FUNCTIONS.pop(id, None)

          # The record is already deleted here, so a missing source file must
          # not turn a successful delete into a 500 error.
          function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
          try:
              os.remove(function_path)
          except FileNotFoundError:
              pass
      return result