tools.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from pydantic import BaseModel
  2. from peewee import *
  3. from playhouse.shortcuts import model_to_dict
  4. from typing import List, Union, Optional
  5. import time
  6. import logging
  7. from apps.webui.internal.db import DB, JSONField
  8. from apps.webui.models.users import Users
  9. import json
  10. import copy
  11. from config import SRC_LOG_LEVELS
  12. log = logging.getLogger(__name__)
  13. log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################


class Tool(Model):
    """Peewee model for a stored tool; ``ToolsTable`` is the data-access layer for it."""

    # Identifier supplied by the client in ToolForm.id; unique per row.
    # NOTE(review): not declared primary_key=True -- confirm how peewee reconciles
    # this with its implicit auto-generated "id" primary key.
    id = CharField(unique=True)
    user_id = CharField()  # id of the user that inserted the tool
    name = TextField()
    content = TextField()  # raw content submitted via ToolForm.content
    specs = JSONField()  # list of spec dicts (mirrors ToolModel.specs)
    meta = JSONField()  # serialized ToolMeta
    valves = JSONField()  # tool-level valves, read/written by ToolsTable's *_tool_valves_by_id
    updated_at = BigIntegerField()  # epoch seconds, set from int(time.time())
    created_at = BigIntegerField()  # epoch seconds, set from int(time.time())

    class Meta:
        # Bind the model to the database object imported from apps.webui.internal.db.
        database = DB
  29. class ToolMeta(BaseModel):
  30. description: Optional[str] = None
  31. manifest: Optional[dict] = {}
class ToolModel(BaseModel):
    """Pydantic representation of a full ``Tool`` row.

    Mirrors every ``Tool`` column except ``valves``. ``ToolsTable`` builds it
    from ``model_to_dict(tool)``; the extra ``valves`` key is dropped by
    pydantic's default handling of unknown fields (confirm if model config
    changes).
    """

    id: str
    user_id: str  # id of the user that inserted the tool
    name: str
    content: str
    specs: List[dict]  # spec dictionaries passed to ToolsTable.insert_new_tool
    meta: ToolMeta
    updated_at: int  # timestamp in epoch seconds
    created_at: int  # timestamp in epoch seconds
####################
# Forms
####################


class ToolResponse(BaseModel):
    """Summary view of a tool without its ``content`` or ``specs``.

    Presumably the shape returned by API endpoints that list tools --
    confirm against callers.
    """

    id: str
    user_id: str  # id of the user that inserted the tool
    name: str
    meta: ToolMeta
    updated_at: int  # timestamp in epoch seconds
    created_at: int  # timestamp in epoch seconds
class ToolForm(BaseModel):
    """Client-supplied payload for creating a tool.

    ``ToolsTable.insert_new_tool`` merges this with ``specs``, ``user_id`` and
    creation/update timestamps to build a ``ToolModel``.
    """

    id: str  # becomes Tool.id
    name: str
    content: str
    meta: ToolMeta
  56. class ToolValves(BaseModel):
  57. valves: Optional[dict] = None
  58. class ToolsTable:
  59. def __init__(self, db):
  60. self.db = db
  61. self.db.create_tables([Tool])
  62. def insert_new_tool(
  63. self, user_id: str, form_data: ToolForm, specs: List[dict]
  64. ) -> Optional[ToolModel]:
  65. tool = ToolModel(
  66. **{
  67. **form_data.model_dump(),
  68. "specs": specs,
  69. "user_id": user_id,
  70. "updated_at": int(time.time()),
  71. "created_at": int(time.time()),
  72. }
  73. )
  74. try:
  75. result = Tool.create(**tool.model_dump())
  76. if result:
  77. return tool
  78. else:
  79. return None
  80. except Exception as e:
  81. print(f"Error creating tool: {e}")
  82. return None
  83. def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
  84. try:
  85. tool = Tool.get(Tool.id == id)
  86. return ToolModel(**model_to_dict(tool))
  87. except:
  88. return None
  89. def get_tools(self) -> List[ToolModel]:
  90. return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
  91. def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
  92. try:
  93. tool = Tool.get(Tool.id == id)
  94. return tool.valves if tool.valves else {}
  95. except Exception as e:
  96. print(f"An error occurred: {e}")
  97. return None
  98. def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
  99. try:
  100. query = Tool.update(
  101. **{"valves": valves},
  102. updated_at=int(time.time()),
  103. ).where(Tool.id == id)
  104. query.execute()
  105. tool = Tool.get(Tool.id == id)
  106. return ToolValves(**model_to_dict(tool))
  107. except:
  108. return None
  109. def get_user_valves_by_id_and_user_id(
  110. self, id: str, user_id: str
  111. ) -> Optional[dict]:
  112. try:
  113. user = Users.get_user_by_id(user_id)
  114. user_settings = user.settings.model_dump() if user.settings else {}
  115. # Check if user has "tools" and "valves" settings
  116. if "tools" not in user_settings:
  117. user_settings["tools"] = {}
  118. if "valves" not in user_settings["tools"]:
  119. user_settings["tools"]["valves"] = {}
  120. return user_settings["tools"]["valves"].get(id, {})
  121. except Exception as e:
  122. print(f"An error occurred: {e}")
  123. return None
  124. def update_user_valves_by_id_and_user_id(
  125. self, id: str, user_id: str, valves: dict
  126. ) -> Optional[dict]:
  127. try:
  128. user = Users.get_user_by_id(user_id)
  129. user_settings = user.settings.model_dump() if user.settings else {}
  130. # Check if user has "tools" and "valves" settings
  131. if "tools" not in user_settings:
  132. user_settings["tools"] = {}
  133. if "valves" not in user_settings["tools"]:
  134. user_settings["tools"]["valves"] = {}
  135. user_settings["tools"]["valves"][id] = valves
  136. # Update the user settings in the database
  137. Users.update_user_by_id(user_id, {"settings": user_settings})
  138. return user_settings["tools"]["valves"][id]
  139. except Exception as e:
  140. print(f"An error occurred: {e}")
  141. return None
  142. def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
  143. try:
  144. query = Tool.update(
  145. **updated,
  146. updated_at=int(time.time()),
  147. ).where(Tool.id == id)
  148. query.execute()
  149. tool = Tool.get(Tool.id == id)
  150. return ToolModel(**model_to_dict(tool))
  151. except:
  152. return None
  153. def delete_tool_by_id(self, id: str) -> bool:
  154. try:
  155. query = Tool.delete().where((Tool.id == id))
  156. query.execute() # Remove the rows, return number of rows removed.
  157. return True
  158. except:
  159. return False
  160. Tools = ToolsTable(DB)