prompts.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
import logging
import time
from typing import Optional

from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, BigInteger, Column, String, Text

from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.groups import Groups
  7. ####################
  8. # Prompts DB Schema
  9. ####################
class Prompt(Base):
    """SQLAlchemy model for the ``prompt`` table: one row per prompt, keyed by ``command``."""

    __tablename__ = "prompt"

    # Primary key; prompts are looked up, updated and deleted by this value.
    command = Column(String, primary_key=True)
    # ID of the user who owns (created) the prompt.
    user_id = Column(String)
    title = Column(Text)
    content = Column(Text)
    # Whole seconds since the Unix epoch; set on insert and refreshed on update.
    timestamp = Column(BigInteger)
    access_control = Column(JSON, nullable=True)  # Controls data access levels.
    # Defines access control rules for this entry.
    # - `None`: Public access, available to all users with the "user" role
    #   (read only; write access stays with the owner).
    # - `{}`: Private access, restricted exclusively to the owner.
    # - Custom permissions: Specific access control for reading and writing;
    #   each permission can grant access to specific groups and/or users:
    # {
    #     "read": {
    #         "group_ids": ["group_id1", "group_id2"],
    #         "user_ids": ["user_id1", "user_id2"]
    #     },
    #     "write": {
    #         "group_ids": ["group_id1", "group_id2"],
    #         "user_ids": ["user_id1", "user_id2"]
    #     }
    # }
  33. class PromptModel(BaseModel):
  34. command: str
  35. user_id: str
  36. title: str
  37. content: str
  38. timestamp: int # timestamp in epoch
  39. access_control: Optional[dict] = None
  40. model_config = ConfigDict(from_attributes=True)
  41. ####################
  42. # Forms
  43. ####################
# Request payload for creating or updating a prompt. It carries no
# ``access_control`` field, so rules cannot be set through this form.
class PromptForm(BaseModel):
    command: str  # identifier of the prompt; primary key of the prompt table
    title: str
    content: str
  48. class PromptsTable:
  49. def insert_new_prompt(
  50. self, user_id: str, form_data: PromptForm
  51. ) -> Optional[PromptModel]:
  52. prompt = PromptModel(
  53. **{
  54. "user_id": user_id,
  55. "command": form_data.command,
  56. "title": form_data.title,
  57. "content": form_data.content,
  58. "timestamp": int(time.time()),
  59. }
  60. )
  61. try:
  62. with get_db() as db:
  63. result = Prompt(**prompt.dict())
  64. db.add(result)
  65. db.commit()
  66. db.refresh(result)
  67. if result:
  68. return PromptModel.model_validate(result)
  69. else:
  70. return None
  71. except Exception:
  72. return None
  73. def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
  74. try:
  75. with get_db() as db:
  76. prompt = db.query(Prompt).filter_by(command=command).first()
  77. return PromptModel.model_validate(prompt)
  78. except Exception:
  79. return None
  80. def get_prompts(self) -> list[PromptModel]:
  81. with get_db() as db:
  82. return [
  83. PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
  84. ]
  85. def get_prompts_by_user_id(
  86. self, user_id: str, permission: str = "write"
  87. ) -> list[PromptModel]:
  88. prompts = self.get_prompts()
  89. groups = Groups.get_groups_by_member_id(user_id)
  90. group_ids = [group.id for group in groups]
  91. if permission == "write":
  92. return [
  93. prompt
  94. for prompt in prompts
  95. if prompt.user_id == user_id
  96. or (
  97. prompt.access_control
  98. and (
  99. any(
  100. group_id
  101. in prompt.access_control.get(permission, {}).get(
  102. "group_ids", []
  103. )
  104. for group_id in group_ids
  105. )
  106. or (
  107. user_id
  108. in prompt.access_control.get(permission, {}).get(
  109. "user_ids", []
  110. )
  111. )
  112. )
  113. )
  114. ]
  115. elif permission == "read":
  116. return [
  117. prompt
  118. for prompt in prompts
  119. if prompt.user_id == user_id
  120. or prompt.access_control is None
  121. or (
  122. prompt.access_control
  123. and (
  124. any(
  125. prompt.access_control.get(permission, {}).get(
  126. "group_ids", []
  127. )
  128. in group_id
  129. for group_id in group_ids
  130. )
  131. or (
  132. user_id
  133. in prompt.access_control.get(permission, {}).get(
  134. "user_ids", []
  135. )
  136. )
  137. )
  138. )
  139. ]
  140. def update_prompt_by_command(
  141. self, command: str, form_data: PromptForm
  142. ) -> Optional[PromptModel]:
  143. try:
  144. with get_db() as db:
  145. prompt = db.query(Prompt).filter_by(command=command).first()
  146. prompt.title = form_data.title
  147. prompt.content = form_data.content
  148. prompt.timestamp = int(time.time())
  149. db.commit()
  150. return PromptModel.model_validate(prompt)
  151. except Exception:
  152. return None
  153. def delete_prompt_by_command(self, command: str) -> bool:
  154. try:
  155. with get_db() as db:
  156. db.query(Prompt).filter_by(command=command).delete()
  157. db.commit()
  158. return True
  159. except Exception:
  160. return False
# Module-level PromptsTable instance, created once when this module is imported.
Prompts = PromptsTable()