functions.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. from pydantic import BaseModel, ConfigDict
  2. from typing import Union, Optional
  3. import time
  4. import logging
  5. from sqlalchemy import Column, String, Text, BigInteger, Boolean
  6. from apps.webui.internal.db import JSONField, Base, get_db
  7. from apps.webui.models.users import Users
  8. import json
  9. import copy
  10. from env import SRC_LOG_LEVELS
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Functions DB Schema
####################


class Function(Base):
    """SQLAlchemy ORM mapping for the ``function`` table.

    Rows of this table are converted to ``FunctionModel`` via
    ``FunctionModel.model_validate`` by the ``FunctionsTable`` helpers.
    """

    __tablename__ = "function"

    # Primary key; insert_new_function takes it from FunctionForm.id.
    id = Column(String, primary_key=True)
    # Owning user's id, set from insert_new_function's user_id argument.
    user_id = Column(String)
    name = Column(Text)
    # Function kind; the FunctionsTable helpers query "filter" and "action".
    type = Column(Text)
    # Function body text (presumably source code -- TODO confirm with loader).
    content = Column(Text)
    # JSON metadata, validated into FunctionMeta (description, manifest).
    meta = Column(JSONField)
    # JSON valve settings at function level (not per user); None when unset.
    valves = Column(JSONField)
    # Filtered on by active_only and the global getters; cleared in bulk by
    # deactivate_all_functions.
    is_active = Column(Boolean)
    # Combined with is_active to select the global filter/action functions.
    is_global = Column(Boolean)
    # Unix epoch seconds (written as int(time.time())).
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
  29. class FunctionMeta(BaseModel):
  30. description: Optional[str] = None
  31. manifest: Optional[dict] = {}
  32. class FunctionModel(BaseModel):
  33. id: str
  34. user_id: str
  35. name: str
  36. type: str
  37. content: str
  38. meta: FunctionMeta
  39. is_active: bool = False
  40. is_global: bool = False
  41. updated_at: int # timestamp in epoch
  42. created_at: int # timestamp in epoch
  43. model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################


class FunctionResponse(BaseModel):
    # Same fields as FunctionModel except ``content``, and with no defaults for
    # is_active / is_global. Presumably the API response shape -- confirm
    # against the routers that return it.
    id: str
    user_id: str
    type: str
    name: str
    meta: FunctionMeta
    is_active: bool
    is_global: bool
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
class FunctionForm(BaseModel):
    # Input for insert_new_function, which adds user_id, type and the
    # timestamps before building a FunctionModel.
    id: str
    name: str
    content: str
    meta: FunctionMeta
  62. class FunctionValves(BaseModel):
  63. valves: Optional[dict] = None
  64. class FunctionsTable:
  65. def insert_new_function(
  66. self, user_id: str, type: str, form_data: FunctionForm
  67. ) -> Optional[FunctionModel]:
  68. function = FunctionModel(
  69. **{
  70. **form_data.model_dump(),
  71. "user_id": user_id,
  72. "type": type,
  73. "updated_at": int(time.time()),
  74. "created_at": int(time.time()),
  75. }
  76. )
  77. try:
  78. with get_db() as db:
  79. result = Function(**function.model_dump())
  80. db.add(result)
  81. db.commit()
  82. db.refresh(result)
  83. if result:
  84. return FunctionModel.model_validate(result)
  85. else:
  86. return None
  87. except Exception as e:
  88. print(f"Error creating tool: {e}")
  89. return None
  90. def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
  91. try:
  92. with get_db() as db:
  93. function = db.get(Function, id)
  94. return FunctionModel.model_validate(function)
  95. except Exception:
  96. return None
  97. def get_functions(self, active_only=False) -> list[FunctionModel]:
  98. with get_db() as db:
  99. if active_only:
  100. return [
  101. FunctionModel.model_validate(function)
  102. for function in db.query(Function).filter_by(is_active=True).all()
  103. ]
  104. else:
  105. return [
  106. FunctionModel.model_validate(function)
  107. for function in db.query(Function).all()
  108. ]
  109. def get_functions_by_type(
  110. self, type: str, active_only=False
  111. ) -> list[FunctionModel]:
  112. with get_db() as db:
  113. if active_only:
  114. return [
  115. FunctionModel.model_validate(function)
  116. for function in db.query(Function)
  117. .filter_by(type=type, is_active=True)
  118. .all()
  119. ]
  120. else:
  121. return [
  122. FunctionModel.model_validate(function)
  123. for function in db.query(Function).filter_by(type=type).all()
  124. ]
  125. def get_global_filter_functions(self) -> list[FunctionModel]:
  126. with get_db() as db:
  127. return [
  128. FunctionModel.model_validate(function)
  129. for function in db.query(Function)
  130. .filter_by(type="filter", is_active=True, is_global=True)
  131. .all()
  132. ]
  133. def get_global_action_functions(self) -> list[FunctionModel]:
  134. with get_db() as db:
  135. return [
  136. FunctionModel.model_validate(function)
  137. for function in db.query(Function)
  138. .filter_by(type="action", is_active=True, is_global=True)
  139. .all()
  140. ]
  141. def get_function_valves_by_id(self, id: str) -> Optional[dict]:
  142. with get_db() as db:
  143. try:
  144. function = db.get(Function, id)
  145. return function.valves if function.valves else {}
  146. except Exception as e:
  147. print(f"An error occurred: {e}")
  148. return None
  149. def update_function_valves_by_id(
  150. self, id: str, valves: dict
  151. ) -> Optional[FunctionValves]:
  152. with get_db() as db:
  153. try:
  154. function = db.get(Function, id)
  155. function.valves = valves
  156. function.updated_at = int(time.time())
  157. db.commit()
  158. db.refresh(function)
  159. return self.get_function_by_id(id)
  160. except Exception:
  161. return None
  162. def get_user_valves_by_id_and_user_id(
  163. self, id: str, user_id: str
  164. ) -> Optional[dict]:
  165. try:
  166. user = Users.get_user_by_id(user_id)
  167. user_settings = user.settings.model_dump() if user.settings else {}
  168. # Check if user has "functions" and "valves" settings
  169. if "functions" not in user_settings:
  170. user_settings["functions"] = {}
  171. if "valves" not in user_settings["functions"]:
  172. user_settings["functions"]["valves"] = {}
  173. return user_settings["functions"]["valves"].get(id, {})
  174. except Exception as e:
  175. print(f"An error occurred: {e}")
  176. return None
  177. def update_user_valves_by_id_and_user_id(
  178. self, id: str, user_id: str, valves: dict
  179. ) -> Optional[dict]:
  180. try:
  181. user = Users.get_user_by_id(user_id)
  182. user_settings = user.settings.model_dump() if user.settings else {}
  183. # Check if user has "functions" and "valves" settings
  184. if "functions" not in user_settings:
  185. user_settings["functions"] = {}
  186. if "valves" not in user_settings["functions"]:
  187. user_settings["functions"]["valves"] = {}
  188. user_settings["functions"]["valves"][id] = valves
  189. # Update the user settings in the database
  190. Users.update_user_by_id(user_id, {"settings": user_settings})
  191. return user_settings["functions"]["valves"][id]
  192. except Exception as e:
  193. print(f"An error occurred: {e}")
  194. return None
  195. def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
  196. with get_db() as db:
  197. try:
  198. db.query(Function).filter_by(id=id).update(
  199. {
  200. **updated,
  201. "updated_at": int(time.time()),
  202. }
  203. )
  204. db.commit()
  205. return self.get_function_by_id(id)
  206. except Exception:
  207. return None
  208. def deactivate_all_functions(self) -> Optional[bool]:
  209. with get_db() as db:
  210. try:
  211. db.query(Function).update(
  212. {
  213. "is_active": False,
  214. "updated_at": int(time.time()),
  215. }
  216. )
  217. db.commit()
  218. return True
  219. except Exception:
  220. return None
  221. def delete_function_by_id(self, id: str) -> bool:
  222. with get_db() as db:
  223. try:
  224. db.query(Function).filter_by(id=id).delete()
  225. db.commit()
  226. return True
  227. except Exception:
  228. return False
  229. Functions = FunctionsTable()