memories.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import time
  2. import uuid
  3. from typing import Optional
  4. from open_webui.apps.webui.internal.db import Base, get_db
  5. from pydantic import BaseModel, ConfigDict
  6. from sqlalchemy import BigInteger, Column, String, Text
####################
# Memory DB Schema
####################


class Memory(Base):
    """SQLAlchemy ORM mapping for the ``memory`` table.

    One row per stored memory: a piece of free-text ``content`` associated
    with a ``user_id``, plus creation and last-update timestamps.
    """

    __tablename__ = "memory"

    # String primary key; MemoriesTable.insert_new_memory fills it with str(uuid.uuid4()).
    id = Column(String, primary_key=True)
    # Id of the user this memory is associated with (MemoriesTable filters on it).
    user_id = Column(String)
    # Free-form memory text.
    content = Column(Text)
    # Integer epoch timestamps; MemoriesTable writes int(time.time()), i.e. seconds.
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
  17. class MemoryModel(BaseModel):
  18. id: str
  19. user_id: str
  20. content: str
  21. updated_at: int # timestamp in epoch
  22. created_at: int # timestamp in epoch
  23. model_config = ConfigDict(from_attributes=True)
  24. ####################
  25. # Forms
  26. ####################
  27. class MemoriesTable:
  28. def insert_new_memory(
  29. self,
  30. user_id: str,
  31. content: str,
  32. ) -> Optional[MemoryModel]:
  33. with get_db() as db:
  34. id = str(uuid.uuid4())
  35. memory = MemoryModel(
  36. **{
  37. "id": id,
  38. "user_id": user_id,
  39. "content": content,
  40. "created_at": int(time.time()),
  41. "updated_at": int(time.time()),
  42. }
  43. )
  44. result = Memory(**memory.model_dump())
  45. db.add(result)
  46. db.commit()
  47. db.refresh(result)
  48. if result:
  49. return MemoryModel.model_validate(result)
  50. else:
  51. return None
  52. def update_memory_by_id(
  53. self,
  54. id: str,
  55. content: str,
  56. ) -> Optional[MemoryModel]:
  57. with get_db() as db:
  58. try:
  59. db.query(Memory).filter_by(id=id).update(
  60. {"content": content, "updated_at": int(time.time())}
  61. )
  62. db.commit()
  63. return self.get_memory_by_id(id)
  64. except Exception:
  65. return None
  66. def get_memories(self) -> list[MemoryModel]:
  67. with get_db() as db:
  68. try:
  69. memories = db.query(Memory).all()
  70. return [MemoryModel.model_validate(memory) for memory in memories]
  71. except Exception:
  72. return None
  73. def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
  74. with get_db() as db:
  75. try:
  76. memories = db.query(Memory).filter_by(user_id=user_id).all()
  77. return [MemoryModel.model_validate(memory) for memory in memories]
  78. except Exception:
  79. return None
  80. def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
  81. with get_db() as db:
  82. try:
  83. memory = db.get(Memory, id)
  84. return MemoryModel.model_validate(memory)
  85. except Exception:
  86. return None
  87. def delete_memory_by_id(self, id: str) -> bool:
  88. with get_db() as db:
  89. try:
  90. db.query(Memory).filter_by(id=id).delete()
  91. db.commit()
  92. return True
  93. except Exception:
  94. return False
  95. def delete_memories_by_user_id(self, user_id: str) -> bool:
  96. with get_db() as db:
  97. try:
  98. db.query(Memory).filter_by(user_id=user_id).delete()
  99. db.commit()
  100. return True
  101. except Exception:
  102. return False
  103. def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  104. with get_db() as db:
  105. try:
  106. db.query(Memory).filter_by(id=id, user_id=user_id).delete()
  107. db.commit()
  108. return True
  109. except Exception:
  110. return False
  111. Memories = MemoriesTable()