瀏覽代碼

feat: memory backend

Timothy J. Baek 11 月之前
父節點
當前提交
288d8a3e32

+ 53 - 0
backend/apps/web/internal/migrations/008_add_memory.py

@@ -0,0 +1,53 @@
+"""Peewee migrations -- 002_add_local_sharing.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    @migrator.create_model
+    class Memory(pw.Model):
+        id = pw.CharField(max_length=255, unique=True)
+        user_id = pw.CharField(max_length=255)
+        content = pw.TextField(null=False)
+        updated_at = pw.BigIntegerField(null=False)
+        created_at = pw.BigIntegerField(null=False)
+
+        class Meta:
+            table_name = "memory"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_model("memory")

+ 5 - 0
backend/apps/web/main.py

@@ -9,6 +9,7 @@ from apps.web.routers import (
     modelfiles,
     modelfiles,
     prompts,
     prompts,
     configs,
     configs,
+    memories,
     utils,
     utils,
 )
 )
 from config import (
 from config import (
@@ -41,6 +42,7 @@ app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 
 
+
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
     allow_origins=origins,
     allow_origins=origins,
@@ -52,9 +54,12 @@ app.add_middleware(
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
+
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
+app.include_router(memories.router, prefix="/memories", tags=["memories"])
+
 
 
 app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])

+ 109 - 0
backend/apps/web/models/memories.py

@@ -0,0 +1,109 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+
+from apps.web.internal.db import DB
+from apps.web.models.chats import Chats
+
+import time
+import uuid
+
+####################
+# Memory DB Schema
+####################
+
+
+class Memory(Model):
+    id = CharField(unique=True)
+    user_id = CharField()
+    content = TextField()
+    updated_at = BigIntegerField()
+    created_at = BigIntegerField()
+
+    class Meta:
+        database = DB
+
+
+class MemoryModel(BaseModel):
+    id: str
+    user_id: str
+    content: str
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class MemoriesTable:
+    def __init__(self, db):
+        self.db = db
+        self.db.create_tables([Memory])
+
+    def insert_new_memory(
+        self,
+        user_id: str,
+        content: str,
+    ) -> Optional[MemoryModel]:
+        id = str(uuid.uuid4())
+
+        memory = MemoryModel(
+            **{
+                "id": id,
+                "user_id": user_id,
+                "content": content,
+                "created_at": int(time.time()),
+                "updated_at": int(time.time()),
+            }
+        )
+        result = Memory.create(**memory.model_dump())
+        if result:
+            return memory
+        else:
+            return None
+
+    def get_memories(self) -> List[MemoryModel]:
+        try:
+            memories = Memory.select()
+            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
+        except:
+            return None
+
+    def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
+        try:
+            memories = Memory.select().where(Memory.user_id == user_id)
+            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
+        except:
+            return None
+
+    def get_memory_by_id(self, id) -> Optional[MemoryModel]:
+        try:
+            memory = Memory.get(Memory.id == id)
+            return MemoryModel(**model_to_dict(memory))
+        except:
+            return None
+
+    def delete_memory_by_id(self, id: str) -> bool:
+        try:
+            query = Memory.delete().where(Memory.id == id)
+            query.execute()  # Remove the rows, return number of rows removed.
+
+            return True
+
+        except:
+            return False
+
+    def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+        try:
+            query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
+            query.execute()
+
+            return True
+        except:
+            return False
+
+
+Memories = MemoriesTable(DB)

+ 117 - 0
backend/apps/web/routers/memories.py

@@ -0,0 +1,117 @@
+from fastapi import Response, Request
+from fastapi import Depends, FastAPI, HTTPException, status
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import logging
+
+from apps.web.models.memories import Memories, MemoryModel
+
+from utils.utils import get_verified_user
+from constants import ERROR_MESSAGES
+
+from config import SRC_LOG_LEVELS, CHROMA_CLIENT
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+router = APIRouter()
+
+
+@router.get("/ef")
+async def get_embeddings(request: Request):
+    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+
+
+############################
+# GetMemories
+############################
+
+
+@router.get("/", response_model=List[MemoryModel])
+async def get_memories(user=Depends(get_verified_user)):
+    return Memories.get_memories_by_user_id(user.id)
+
+
+############################
+# AddMemory
+############################
+
+
+class AddMemoryForm(BaseModel):
+    content: str
+
+
+@router.post("/add", response_model=Optional[MemoryModel])
+async def add_memory(
+    request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
+):
+    memory = Memories.insert_new_memory(user.id, form_data.content)
+    memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
+
+    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+    collection.upsert(
+        documents=[memory.content],
+        ids=[memory.id],
+        embeddings=[memory_embedding],
+        metadatas=[{"created_at": memory.created_at}],
+    )
+
+    return memory
+
+
+############################
+# QueryMemory
+############################
+
+
+class QueryMemoryForm(BaseModel):
+    content: str
+
+
+@router.post("/query", response_model=Optional[MemoryModel])
+async def add_memory(
+    request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
+):
+    query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
+    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+
+    results = collection.query(
+        query_embeddings=[query_embedding],
+        n_results=1,  # how many results to return
+    )
+
+    return results
+
+
+############################
+# ResetMemoryFromVectorDB
+############################
+@router.get("/reset", response_model=bool)
+async def reset_memory_from_vector_db(
+    request: Request, user=Depends(get_verified_user)
+):
+    CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
+    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
+
+    memories = Memories.get_memories_by_user_id(user.id)
+    for memory in memories:
+        memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
+        collection.upsert(
+            documents=[memory.content],
+            ids=[memory.id],
+            embeddings=[memory_embedding],
+        )
+    return True
+
+
+############################
+# DeleteUserById
+############################
+
+
+@router.delete("/{memory_id}", response_model=bool)
+async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
+    return Memories.delete_memory_by_id_and_user_id(memory_id, user.id)

+ 12 - 2
backend/main.py

@@ -238,9 +238,15 @@ async def check_url(request: Request, call_next):
     return response
     return response
 
 
 
 
-app.mount("/api/v1", webui_app)
-app.mount("/litellm/api", litellm_app)
+@app.middleware("http")
+async def update_embedding_function(request: Request, call_next):
+    response = await call_next(request)
+    if "/embedding/update" in request.url.path:
+        webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
+    return response
 
 
+
+app.mount("/litellm/api", litellm_app)
 app.mount("/ollama", ollama_app)
 app.mount("/ollama", ollama_app)
 app.mount("/openai/api", openai_app)
 app.mount("/openai/api", openai_app)
 
 
@@ -248,6 +254,10 @@ app.mount("/images/api/v1", images_app)
 app.mount("/audio/api/v1", audio_app)
 app.mount("/audio/api/v1", audio_app)
 app.mount("/rag/api/v1", rag_app)
 app.mount("/rag/api/v1", rag_app)
 
 
+app.mount("/api/v1", webui_app)
+
+webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
+
 
 
 @app.get("/api/config")
 @app.get("/api/config")
 async def get_app_config():
 async def get_app_config():