memories.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. from fastapi import Response, Request
  2. from fastapi import Depends, FastAPI, HTTPException, status
  3. from datetime import datetime, timedelta
  4. from typing import List, Union, Optional
  5. from fastapi import APIRouter
  6. from pydantic import BaseModel
  7. import logging
  8. from apps.webui.models.memories import Memories, MemoryModel
  9. from utils.utils import get_verified_user
  10. from constants import ERROR_MESSAGES
  11. from config import SRC_LOG_LEVELS, CHROMA_CLIENT
  12. log = logging.getLogger(__name__)
  13. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  14. router = APIRouter()
  15. @router.get("/ef")
  16. async def get_embeddings(request: Request):
  17. return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
  18. ############################
  19. # GetMemories
  20. ############################
  21. @router.get("/", response_model=List[MemoryModel])
  22. async def get_memories(user=Depends(get_verified_user)):
  23. return Memories.get_memories_by_user_id(user.id)
  24. ############################
  25. # AddMemory
  26. ############################
  27. class AddMemoryForm(BaseModel):
  28. content: str
  29. class MemoryUpdateModel(BaseModel):
  30. content: Optional[str] = None
  31. @router.post("/add", response_model=Optional[MemoryModel])
  32. async def add_memory(
  33. request: Request,
  34. form_data: AddMemoryForm,
  35. user=Depends(get_verified_user),
  36. ):
  37. memory = Memories.insert_new_memory(user.id, form_data.content)
  38. memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
  39. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  40. collection.upsert(
  41. documents=[memory.content],
  42. ids=[memory.id],
  43. embeddings=[memory_embedding],
  44. metadatas=[{"created_at": memory.created_at}],
  45. )
  46. return memory
  47. @router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
  48. async def update_memory_by_id(
  49. memory_id: str,
  50. request: Request,
  51. form_data: MemoryUpdateModel,
  52. user=Depends(get_verified_user),
  53. ):
  54. memory = Memories.update_memory_by_id(memory_id, form_data.content)
  55. if memory is None:
  56. raise HTTPException(status_code=404, detail="Memory not found")
  57. if form_data.content is not None:
  58. memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
  59. collection = CHROMA_CLIENT.get_or_create_collection(
  60. name=f"user-memory-{user.id}"
  61. )
  62. collection.upsert(
  63. documents=[form_data.content],
  64. ids=[memory.id],
  65. embeddings=[memory_embedding],
  66. metadatas=[
  67. {"created_at": memory.created_at, "updated_at": memory.updated_at}
  68. ],
  69. )
  70. return memory
  71. ############################
  72. # QueryMemory
  73. ############################
  74. class QueryMemoryForm(BaseModel):
  75. content: str
  76. k: Optional[int] = 1
  77. @router.post("/query")
  78. async def query_memory(
  79. request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
  80. ):
  81. query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
  82. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  83. results = collection.query(
  84. query_embeddings=[query_embedding],
  85. n_results=form_data.k, # how many results to return
  86. )
  87. return results
  88. ############################
  89. # ResetMemoryFromVectorDB
  90. ############################
  91. @router.get("/reset", response_model=bool)
  92. async def reset_memory_from_vector_db(
  93. request: Request, user=Depends(get_verified_user)
  94. ):
  95. CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
  96. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  97. memories = Memories.get_memories_by_user_id(user.id)
  98. for memory in memories:
  99. memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
  100. collection.upsert(
  101. documents=[memory.content],
  102. ids=[memory.id],
  103. embeddings=[memory_embedding],
  104. )
  105. return True
  106. ############################
  107. # DeleteMemoriesByUserId
  108. ############################
  109. @router.delete("/user", response_model=bool)
  110. async def delete_memory_by_user_id(user=Depends(get_verified_user)):
  111. result = Memories.delete_memories_by_user_id(user.id)
  112. if result:
  113. try:
  114. CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
  115. except Exception as e:
  116. log.error(e)
  117. return True
  118. return False
  119. ############################
  120. # DeleteMemoryById
  121. ############################
  122. @router.delete("/{memory_id}", response_model=bool)
  123. async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
  124. result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
  125. if result:
  126. collection = CHROMA_CLIENT.get_or_create_collection(
  127. name=f"user-memory-{user.id}"
  128. )
  129. collection.delete(ids=[memory_id])
  130. return True
  131. return False