groups.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import json
  2. import logging
  3. import time
  4. from typing import Optional
  5. import uuid
  6. from open_webui.internal.db import Base, get_db
  7. from open_webui.env import SRC_LOG_LEVELS
  8. from open_webui.models.files import FileMetadataResponse
  9. from pydantic import BaseModel, ConfigDict
  10. from sqlalchemy import BigInteger, Column, String, Text, JSON, func
  11. log = logging.getLogger(__name__)
  12. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  13. ####################
  14. # UserGroup DB Schema
  15. ####################
  16. class Group(Base):
  17. __tablename__ = "group"
  18. id = Column(Text, unique=True, primary_key=True)
  19. user_id = Column(Text)
  20. name = Column(Text)
  21. description = Column(Text)
  22. data = Column(JSON, nullable=True)
  23. meta = Column(JSON, nullable=True)
  24. permissions = Column(JSON, nullable=True)
  25. user_ids = Column(JSON, nullable=True)
  26. created_at = Column(BigInteger)
  27. updated_at = Column(BigInteger)
  28. class GroupModel(BaseModel):
  29. model_config = ConfigDict(from_attributes=True)
  30. id: str
  31. user_id: str
  32. name: str
  33. description: str
  34. data: Optional[dict] = None
  35. meta: Optional[dict] = None
  36. permissions: Optional[dict] = None
  37. user_ids: list[str] = []
  38. created_at: int # timestamp in epoch
  39. updated_at: int # timestamp in epoch
  40. ####################
  41. # Forms
  42. ####################
  43. class GroupResponse(BaseModel):
  44. id: str
  45. user_id: str
  46. name: str
  47. description: str
  48. permissions: Optional[dict] = None
  49. data: Optional[dict] = None
  50. meta: Optional[dict] = None
  51. user_ids: list[str] = []
  52. created_at: int # timestamp in epoch
  53. updated_at: int # timestamp in epoch
  54. class GroupForm(BaseModel):
  55. name: str
  56. description: str
  57. permissions: Optional[dict] = None
  58. class GroupUpdateForm(GroupForm):
  59. user_ids: Optional[list[str]] = None
  60. class GroupTable:
  61. def insert_new_group(
  62. self, user_id: str, form_data: GroupForm
  63. ) -> Optional[GroupModel]:
  64. with get_db() as db:
  65. group = GroupModel(
  66. **{
  67. **form_data.model_dump(exclude_none=True),
  68. "id": str(uuid.uuid4()),
  69. "user_id": user_id,
  70. "created_at": int(time.time()),
  71. "updated_at": int(time.time()),
  72. }
  73. )
  74. try:
  75. result = Group(**group.model_dump())
  76. db.add(result)
  77. db.commit()
  78. db.refresh(result)
  79. if result:
  80. return GroupModel.model_validate(result)
  81. else:
  82. return None
  83. except Exception:
  84. return None
  85. def get_groups(self) -> list[GroupModel]:
  86. with get_db() as db:
  87. return [
  88. GroupModel.model_validate(group)
  89. for group in db.query(Group).order_by(Group.updated_at.desc()).all()
  90. ]
  91. def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
  92. with get_db() as db:
  93. return [
  94. GroupModel.model_validate(group)
  95. for group in db.query(Group)
  96. .filter(
  97. func.json_array_length(Group.user_ids) > 0
  98. ) # Ensure array exists
  99. .filter(
  100. Group.user_ids.cast(String).like(f'%"{user_id}"%')
  101. ) # String-based check
  102. .order_by(Group.updated_at.desc())
  103. .all()
  104. ]
  105. def get_group_by_id(self, id: str) -> Optional[GroupModel]:
  106. try:
  107. with get_db() as db:
  108. group = db.query(Group).filter_by(id=id).first()
  109. return GroupModel.model_validate(group) if group else None
  110. except Exception:
  111. return None
  112. def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
  113. group = self.get_group_by_id(id)
  114. if group:
  115. return group.user_ids
  116. else:
  117. return None
  118. def update_group_by_id(
  119. self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
  120. ) -> Optional[GroupModel]:
  121. try:
  122. with get_db() as db:
  123. db.query(Group).filter_by(id=id).update(
  124. {
  125. **form_data.model_dump(exclude_none=True),
  126. "updated_at": int(time.time()),
  127. }
  128. )
  129. db.commit()
  130. return self.get_group_by_id(id=id)
  131. except Exception as e:
  132. log.exception(e)
  133. return None
  134. def delete_group_by_id(self, id: str) -> bool:
  135. try:
  136. with get_db() as db:
  137. db.query(Group).filter_by(id=id).delete()
  138. db.commit()
  139. return True
  140. except Exception:
  141. return False
  142. def delete_all_groups(self) -> bool:
  143. with get_db() as db:
  144. try:
  145. db.query(Group).delete()
  146. db.commit()
  147. return True
  148. except Exception:
  149. return False
  150. Groups = GroupTable()