messages.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import json
  2. import time
  3. import uuid
  4. from typing import Optional
  5. from open_webui.internal.db import Base, get_db
  6. from open_webui.models.tags import TagModel, Tag, Tags
  7. from pydantic import BaseModel, ConfigDict
  8. from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
  9. from sqlalchemy import or_, func, select, and_, text
  10. from sqlalchemy.sql import exists
  11. ####################
  12. # Message DB Schema
  13. ####################
class Message(Base):
    """SQLAlchemy mapping for rows of the ``message`` table."""

    __tablename__ = "message"

    # Primary key; MessageTable.insert_new_message stores a uuid4 string here.
    id = Column(Text, primary_key=True)

    # Value of the ``user_id`` argument given when the row was inserted.
    user_id = Column(Text)
    # Explicitly nullable, so a row may exist without a channel id.
    channel_id = Column(Text, nullable=True)

    content = Column(Text)
    # Optional JSON payloads copied from the submitted form.
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)

    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
# Pydantic counterpart of the ``Message`` ORM class; this is the type that the
# MessageTable methods hand back to callers.
class MessageModel(BaseModel):
    # from_attributes lets model_validate() read values from attribute access,
    # which is how MessageTable converts ORM rows into this model.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: Optional[str] = None

    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None

    # Both timestamps are written from time.time_ns() by MessageTable, i.e.
    # they are epoch times expressed in nanoseconds.
    created_at: int  # timestamp in epoch (nanoseconds)
    updated_at: int  # timestamp in epoch (nanoseconds)
  34. ####################
  35. # Forms
  36. ####################
# Caller-supplied fields of a message; consumed by
# MessageTable.insert_new_message and MessageTable.update_message_by_id.
class MessageForm(BaseModel):
    content: str
    # Optional free-form dictionaries stored as-is in the JSON columns.
    data: Optional[dict] = None
    meta: Optional[dict] = None
  41. class MessageTable:
  42. def insert_new_message(
  43. self, form_data: MessageForm, channel_id: str, user_id: str
  44. ) -> Optional[MessageModel]:
  45. with get_db() as db:
  46. id = str(uuid.uuid4())
  47. ts = int(time.time_ns())
  48. message = MessageModel(
  49. **{
  50. "id": id,
  51. "user_id": user_id,
  52. "channel_id": channel_id,
  53. "content": form_data.content,
  54. "data": form_data.data,
  55. "meta": form_data.meta,
  56. "created_at": ts,
  57. "updated_at": ts,
  58. }
  59. )
  60. result = Message(**message.model_dump())
  61. db.add(result)
  62. db.commit()
  63. db.refresh(result)
  64. return MessageModel.model_validate(result) if result else None
  65. def get_message_by_id(self, id: str) -> Optional[MessageModel]:
  66. with get_db() as db:
  67. message = db.get(Message, id)
  68. return MessageModel.model_validate(message) if message else None
  69. def get_messages_by_channel_id(
  70. self, channel_id: str, skip: int = 0, limit: int = 50
  71. ) -> list[MessageModel]:
  72. with get_db() as db:
  73. all_messages = (
  74. db.query(Message)
  75. .filter_by(channel_id=channel_id)
  76. .order_by(Message.created_at.desc())
  77. .offset(skip)
  78. .limit(limit)
  79. .all()
  80. )
  81. return [MessageModel.model_validate(message) for message in all_messages]
  82. def get_messages_by_user_id(
  83. self, user_id: str, skip: int = 0, limit: int = 50
  84. ) -> list[MessageModel]:
  85. with get_db() as db:
  86. all_messages = (
  87. db.query(Message)
  88. .filter_by(user_id=user_id)
  89. .order_by(Message.created_at.desc())
  90. .offset(skip)
  91. .limit(limit)
  92. .all()
  93. )
  94. return [MessageModel.model_validate(message) for message in all_messages]
  95. def update_message_by_id(
  96. self, id: str, form_data: MessageForm
  97. ) -> Optional[MessageModel]:
  98. with get_db() as db:
  99. message = db.get(Message, id)
  100. message.content = form_data.content
  101. message.data = form_data.data
  102. message.meta = form_data.meta
  103. message.updated_at = int(time.time_ns())
  104. db.commit()
  105. db.refresh(message)
  106. return MessageModel.model_validate(message) if message else None
  107. def delete_message_by_id(self, id: str) -> bool:
  108. with get_db() as db:
  109. db.query(Message).filter_by(id=id).delete()
  110. db.commit()
  111. return True
# Module-level MessageTable instance, created once at import time.
Messages = MessageTable()