memories.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from fastapi import Response, Request
  2. from fastapi import Depends, FastAPI, HTTPException, status
  3. from datetime import datetime, timedelta
  4. from typing import List, Union, Optional
  5. from fastapi import APIRouter
  6. from pydantic import BaseModel
  7. import logging
  8. from apps.web.models.memories import Memories, MemoryModel
  9. from utils.utils import get_verified_user
  10. from constants import ERROR_MESSAGES
  11. from config import SRC_LOG_LEVELS, CHROMA_CLIENT
  12. log = logging.getLogger(__name__)
  13. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  14. router = APIRouter()
  15. @router.get("/ef")
  16. async def get_embeddings(request: Request):
  17. return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
  18. ############################
  19. # GetMemories
  20. ############################
  21. @router.get("/", response_model=List[MemoryModel])
  22. async def get_memories(user=Depends(get_verified_user)):
  23. return Memories.get_memories_by_user_id(user.id)
  24. ############################
  25. # AddMemory
  26. ############################
  27. class AddMemoryForm(BaseModel):
  28. content: str
  29. @router.post("/add", response_model=Optional[MemoryModel])
  30. async def add_memory(
  31. request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
  32. ):
  33. memory = Memories.insert_new_memory(user.id, form_data.content)
  34. memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
  35. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  36. collection.upsert(
  37. documents=[memory.content],
  38. ids=[memory.id],
  39. embeddings=[memory_embedding],
  40. metadatas=[{"created_at": memory.created_at}],
  41. )
  42. return memory
  43. ############################
  44. # QueryMemory
  45. ############################
  46. class QueryMemoryForm(BaseModel):
  47. content: str
  48. @router.post("/query")
  49. async def query_memory(
  50. request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
  51. ):
  52. query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
  53. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  54. results = collection.query(
  55. query_embeddings=[query_embedding],
  56. n_results=1, # how many results to return
  57. )
  58. return results
  59. ############################
  60. # ResetMemoryFromVectorDB
  61. ############################
  62. @router.get("/reset", response_model=bool)
  63. async def reset_memory_from_vector_db(
  64. request: Request, user=Depends(get_verified_user)
  65. ):
  66. CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
  67. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  68. memories = Memories.get_memories_by_user_id(user.id)
  69. for memory in memories:
  70. memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
  71. collection.upsert(
  72. documents=[memory.content],
  73. ids=[memory.id],
  74. embeddings=[memory_embedding],
  75. )
  76. return True
  77. ############################
  78. # DeleteUserById
  79. ############################
  80. @router.delete("/{memory_id}", response_model=bool)
  81. async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
  82. result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
  83. if result:
  84. collection = CHROMA_CLIENT.get_or_create_collection(
  85. name=f"user-memory-{user.id}"
  86. )
  87. collection.delete_document(memory_id)
  88. return True
  89. return False