memories.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from fastapi import APIRouter, Depends, HTTPException, Request
  2. from pydantic import BaseModel
  3. import logging
  4. from typing import Optional
  5. from open_webui.models.memories import Memories, MemoryModel
  6. from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
  7. from open_webui.utils.auth import get_verified_user
  8. from open_webui.env import SRC_LOG_LEVELS
  9. log = logging.getLogger(__name__)
  10. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  11. router = APIRouter()
  12. @router.get("/ef")
  13. async def get_embeddings(request: Request):
  14. return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
  15. ############################
  16. # GetMemories
  17. ############################
  18. @router.get("/", response_model=list[MemoryModel])
  19. async def get_memories(user=Depends(get_verified_user)):
  20. return Memories.get_memories_by_user_id(user.id)
  21. ############################
  22. # AddMemory
  23. ############################
  24. class AddMemoryForm(BaseModel):
  25. content: str
  26. class MemoryUpdateModel(BaseModel):
  27. content: Optional[str] = None
  28. @router.post("/add", response_model=Optional[MemoryModel])
  29. async def add_memory(
  30. request: Request,
  31. form_data: AddMemoryForm,
  32. user=Depends(get_verified_user),
  33. ):
  34. memory = Memories.insert_new_memory(user.id, form_data.content)
  35. VECTOR_DB_CLIENT.upsert(
  36. collection_name=f"user-memory-{user.id}",
  37. items=[
  38. {
  39. "id": memory.id,
  40. "text": memory.content,
  41. "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
  42. "metadata": {"created_at": memory.created_at},
  43. }
  44. ],
  45. )
  46. return memory
  47. ############################
  48. # QueryMemory
  49. ############################
  50. class QueryMemoryForm(BaseModel):
  51. content: str
  52. k: Optional[int] = 1
  53. @router.post("/query")
  54. async def query_memory(
  55. request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
  56. ):
  57. results = VECTOR_DB_CLIENT.search(
  58. collection_name=f"user-memory-{user.id}",
  59. vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
  60. limit=form_data.k,
  61. )
  62. return results
  63. ############################
  64. # ResetMemoryFromVectorDB
  65. ############################
  66. @router.post("/reset", response_model=bool)
  67. async def reset_memory_from_vector_db(
  68. request: Request, user=Depends(get_verified_user)
  69. ):
  70. VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
  71. memories = Memories.get_memories_by_user_id(user.id)
  72. VECTOR_DB_CLIENT.upsert(
  73. collection_name=f"user-memory-{user.id}",
  74. items=[
  75. {
  76. "id": memory.id,
  77. "text": memory.content,
  78. "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
  79. "metadata": {
  80. "created_at": memory.created_at,
  81. "updated_at": memory.updated_at,
  82. },
  83. }
  84. for memory in memories
  85. ],
  86. )
  87. return True
  88. ############################
  89. # DeleteMemoriesByUserId
  90. ############################
  91. @router.delete("/delete/user", response_model=bool)
  92. async def delete_memory_by_user_id(user=Depends(get_verified_user)):
  93. result = Memories.delete_memories_by_user_id(user.id)
  94. if result:
  95. try:
  96. VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
  97. except Exception as e:
  98. log.error(e)
  99. return True
  100. return False
  101. ############################
  102. # UpdateMemoryById
  103. ############################
  104. @router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
  105. async def update_memory_by_id(
  106. memory_id: str,
  107. request: Request,
  108. form_data: MemoryUpdateModel,
  109. user=Depends(get_verified_user),
  110. ):
  111. memory = Memories.update_memory_by_id(memory_id, form_data.content)
  112. if memory is None:
  113. raise HTTPException(status_code=404, detail="Memory not found")
  114. if form_data.content is not None:
  115. VECTOR_DB_CLIENT.upsert(
  116. collection_name=f"user-memory-{user.id}",
  117. items=[
  118. {
  119. "id": memory.id,
  120. "text": memory.content,
  121. "vector": request.app.state.EMBEDDING_FUNCTION(
  122. memory.content, user
  123. ),
  124. "metadata": {
  125. "created_at": memory.created_at,
  126. "updated_at": memory.updated_at,
  127. },
  128. }
  129. ],
  130. )
  131. return memory
  132. ############################
  133. # DeleteMemoryById
  134. ############################
  135. @router.delete("/{memory_id}", response_model=bool)
  136. async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
  137. result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
  138. if result:
  139. VECTOR_DB_CLIENT.delete(
  140. collection_name=f"user-memory-{user.id}", ids=[memory_id]
  141. )
  142. return True
  143. return False