memories.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  import logging
  import time
  import uuid
  from typing import List, Union, Optional

  from pydantic import BaseModel, ConfigDict
  from sqlalchemy import Column, String, BigInteger
  from sqlalchemy.orm import Session

  from apps.webui.internal.db import Base, get_session
  from apps.webui.models.chats import Chats
  9. ####################
  10. # Memory DB Schema
  11. ####################
  class Memory(Base):
      """SQLAlchemy mapping for the ``memory`` table.

      One row per stored memory. The columns below are declared in the same
      order as before, because declaration order determines the column order
      of the generated ``CREATE TABLE``.
      """

      __tablename__ = "memory"

      # Primary key; MemoriesTable.insert_new_memory fills it with a uuid4 string.
      id = Column(String, primary_key=True)
      # Id of the owning user. No foreign key or index is declared here.
      user_id = Column(String)
      # Free-form memory text.
      content = Column(String)
      # Integer timestamps; MemoriesTable writes them with int(time.time()).
      updated_at = Column(BigInteger)
      created_at = Column(BigInteger)
  class MemoryModel(BaseModel):
      """Pydantic view of a single stored memory.

      ``from_attributes=True`` lets ``MemoryModel.model_validate`` read the
      values directly from a ``Memory`` ORM row.
      """

      model_config = ConfigDict(from_attributes=True)

      # Field order is kept as-is: it is the order model_dump() emits keys in.
      id: str
      user_id: str
      content: str
      updated_at: int  # timestamp in epoch
      created_at: int  # timestamp in epoch
  26. ####################
  27. # Memories Table (data-access helpers)
  28. ####################
  class MemoriesTable:
      """Data-access helpers for the ``memory`` table.

      Every method opens its own session through ``get_session()`` and commits
      its own writes explicitly. Except for ``insert_new_memory``, database
      errors are treated as best-effort failures: the exception is logged and
      the method returns ``None`` (reads/updates) or ``False`` (deletes).
      """

      # Logger for the database errors that the methods below swallow.
      _logger = logging.getLogger(__name__)

      def insert_new_memory(
          self,
          user_id: str,
          content: str,
      ) -> Optional[MemoryModel]:
          """Create and persist a new memory owned by ``user_id``.

          Args:
              user_id: Id of the owning user.
              content: Memory text.

          Returns:
              The stored memory, re-read from the database after commit.

          Raises:
              Any exception raised by the session (e.g. on ``commit``). Unlike
              the other methods, this one does not swallow errors.
          """
          # One timestamp for both fields, so a fresh row always has
          # created_at == updated_at (two time.time() calls can straddle a second).
          now = int(time.time())
          memory = MemoryModel(
              id=str(uuid.uuid4()),
              user_id=user_id,
              content=content,
              created_at=now,
              updated_at=now,
          )
          with get_session() as db:
              row = Memory(**memory.model_dump())
              db.add(row)
              db.commit()
              db.refresh(row)
              # An ORM instance is always truthy, so no None branch is needed.
              return MemoryModel.model_validate(row)

      def update_memory_by_id(
          self,
          id: str,
          content: str,
      ) -> Optional[MemoryModel]:
          """Replace the content of memory ``id`` and bump its ``updated_at``.

          Returns:
              The updated memory, or ``None`` if no memory with that id exists
              or the database operation failed (the failure is logged).
          """
          try:
              with get_session() as db:
                  db.query(Memory).filter_by(id=id).update(
                      {"content": content, "updated_at": int(time.time())}
                  )
                  db.commit()
          except Exception:
              self._logger.exception("Failed to update memory %s", id)
              return None
          # Re-read only after the write session is closed, instead of nesting
          # a second session inside it.
          return self.get_memory_by_id(id)

      def get_memories(self) -> Optional[List[MemoryModel]]:
          """Return every stored memory, or ``None`` if the query failed."""
          try:
              with get_session() as db:
                  rows = db.query(Memory).all()
                  return [MemoryModel.model_validate(row) for row in rows]
          except Exception:
              self._logger.exception("Failed to list memories")
              return None

      def get_memories_by_user_id(self, user_id: str) -> Optional[List[MemoryModel]]:
          """Return all memories owned by ``user_id``, or ``None`` if the query failed."""
          try:
              with get_session() as db:
                  rows = db.query(Memory).filter_by(user_id=user_id).all()
                  return [MemoryModel.model_validate(row) for row in rows]
          except Exception:
              self._logger.exception("Failed to list memories for user %s", user_id)
              return None

      def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
          """Return memory ``id``, or ``None`` if it does not exist or the lookup failed."""
          try:
              with get_session() as db:
                  row = db.get(Memory, id)
                  # A missing row is a normal outcome, not an error worth logging.
                  if row is None:
                      return None
                  return MemoryModel.model_validate(row)
          except Exception:
              self._logger.exception("Failed to load memory %s", id)
              return None

      def delete_memory_by_id(self, id: str) -> bool:
          """Delete memory ``id``.

          Returns:
              ``True`` once the delete is committed (also when no row matched),
              ``False`` if the database operation failed (the failure is logged).
          """
          try:
              with get_session() as db:
                  db.query(Memory).filter_by(id=id).delete()
                  # Commit explicitly, like the insert/update helpers. Otherwise
                  # persistence depends on get_session() committing on exit.
                  db.commit()
              return True
          except Exception:
              self._logger.exception("Failed to delete memory %s", id)
              return False

      def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
          """Delete every memory owned by ``user_id``.

          Args:
              db: Unused. It is accepted only for backward compatibility; a new
                  session is always opened via ``get_session()``.
              user_id: Id of the user whose memories are removed.

          Returns:
              ``True`` once the delete is committed (also when no row matched),
              ``False`` if the database operation failed (the failure is logged).
          """
          try:
              with get_session() as session:
                  session.query(Memory).filter_by(user_id=user_id).delete()
                  session.commit()
              return True
          except Exception:
              self._logger.exception("Failed to delete memories for user %s", user_id)
              return False

      def delete_memory_by_id_and_user_id(
          self, db: Session, id: str, user_id: str
      ) -> bool:
          """Delete memory ``id`` only if it is owned by ``user_id``.

          Args:
              db: Unused. It is accepted only for backward compatibility; a new
                  session is always opened via ``get_session()``.
              id: Id of the memory to delete.
              user_id: Required owner; rows owned by anyone else are left alone.

          Returns:
              ``True`` once the delete is committed (also when no row matched),
              ``False`` if the database operation failed (the failure is logged).
          """
          try:
              with get_session() as session:
                  session.query(Memory).filter_by(id=id, user_id=user_id).delete()
                  session.commit()
              return True
          except Exception:
              self._logger.exception(
                  "Failed to delete memory %s for user %s", id, user_id
              )
              return False
  # Shared module-level instance of the memory-table helpers.
  Memories: MemoriesTable = MemoriesTable()