memories.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
import logging
import time
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String
from sqlalchemy.orm import Session

from apps.webui.internal.db import Base
from apps.webui.models.chats import Chats
  9. ####################
  10. # Memory DB Schema
  11. ####################
class Memory(Base):
    """SQLAlchemy ORM mapping for the ``memory`` table."""

    __tablename__ = "memory"

    # Primary key; insert_new_memory fills it with str(uuid.uuid4()).
    id = Column(String, primary_key=True)
    # Owner of the memory. No foreign key or index is declared here.
    # NOTE(review): this and the columns below are nullable (SQLAlchemy
    # default), although MemoryModel requires non-null values.
    user_id = Column(String)
    # Memory text.
    content = Column(String)
    # Timestamps as whole seconds since the Unix epoch (set from
    # int(time.time()) by the table helpers).
    updated_at = Column(BigInteger)
    created_at = Column(BigInteger)
  19. class MemoryModel(BaseModel):
  20. id: str
  21. user_id: str
  22. content: str
  23. updated_at: int # timestamp in epoch
  24. created_at: int # timestamp in epoch
  25. model_config = ConfigDict(from_attributes=True)
  26. ####################
  27. # Forms
  28. ####################
  29. class MemoriesTable:
  30. def insert_new_memory(
  31. self,
  32. db: Session,
  33. user_id: str,
  34. content: str,
  35. ) -> Optional[MemoryModel]:
  36. id = str(uuid.uuid4())
  37. memory = MemoryModel(
  38. **{
  39. "id": id,
  40. "user_id": user_id,
  41. "content": content,
  42. "created_at": int(time.time()),
  43. "updated_at": int(time.time()),
  44. }
  45. )
  46. result = Memory(**memory.dict())
  47. db.add(result)
  48. db.commit()
  49. db.refresh(result)
  50. if result:
  51. return MemoryModel.model_validate(result)
  52. else:
  53. return None
  54. def update_memory_by_id(
  55. self,
  56. db: Session,
  57. id: str,
  58. content: str,
  59. ) -> Optional[MemoryModel]:
  60. try:
  61. db.query(Memory).filter_by(id=id).update(
  62. {"content": content, "updated_at": int(time.time())}
  63. )
  64. return self.get_memory_by_id(db, id)
  65. except:
  66. return None
  67. def get_memories(self, db: Session) -> List[MemoryModel]:
  68. try:
  69. memories = db.query(Memory).all()
  70. return [MemoryModel.model_validate(memory) for memory in memories]
  71. except:
  72. return None
  73. def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]:
  74. try:
  75. memories = db.query(Memory).filter_by(user_id=user_id).all()
  76. return [MemoryModel.model_validate(memory) for memory in memories]
  77. except:
  78. return None
  79. def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]:
  80. try:
  81. memory = db.get(Memory, id)
  82. return MemoryModel.model_validate(memory)
  83. except:
  84. return None
  85. def delete_memory_by_id(self, db: Session, id: str) -> bool:
  86. try:
  87. db.query(Memory).filter_by(id=id).delete()
  88. return True
  89. except:
  90. return False
  91. def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
  92. try:
  93. db.query(Memory).filter_by(user_id=user_id).delete()
  94. return True
  95. except:
  96. return False
  97. def delete_memory_by_id_and_user_id(
  98. self, db: Session, id: str, user_id: str
  99. ) -> bool:
  100. try:
  101. db.query(Memory).filter_by(id=id, user_id=user_id).delete()
  102. return True
  103. except:
  104. return False
  105. Memories = MemoriesTable()