Explorar o código

feat(sqlalchemy): Replace peewee with sqlalchemy

Jonathan Rohde hai 10 meses
pai
achega
df09d0830a
Modificáronse 47 ficheiros con 2577 adicións e 1000 borrados
  1. 2 2
      .github/workflows/integration-test.yml
  2. 114 0
      backend/alembic.ini
  3. 5 2
      backend/apps/ollama/main.py
  4. 3 1
      backend/apps/openai/main.py
  5. 3 1
      backend/apps/socket/main.py
  6. 40 26
      backend/apps/webui/internal/db.py
  7. 0 72
      backend/apps/webui/internal/wrappers.py
  8. 3 3
      backend/apps/webui/main.py
  9. 40 37
      backend/apps/webui/models/auths.py
  10. 139 162
      backend/apps/webui/models/chats.py
  11. 48 56
      backend/apps/webui/models/documents.py
  12. 30 32
      backend/apps/webui/models/files.py
  13. 36 38
      backend/apps/webui/models/functions.py
  14. 43 44
      backend/apps/webui/models/memories.py
  15. 38 41
      backend/apps/webui/models/models.py
  16. 38 48
      backend/apps/webui/models/prompts.py
  17. 109 89
      backend/apps/webui/models/tags.py
  18. 37 41
      backend/apps/webui/models/tools.py
  19. 92 83
      backend/apps/webui/models/users.py
  20. 40 26
      backend/apps/webui/routers/auths.py
  21. 88 58
      backend/apps/webui/routers/chats.py
  22. 27 13
      backend/apps/webui/routers/documents.py
  23. 14 11
      backend/apps/webui/routers/files.py
  24. 14 13
      backend/apps/webui/routers/functions.py
  25. 18 11
      backend/apps/webui/routers/memories.py
  26. 22 13
      backend/apps/webui/routers/models.py
  27. 21 11
      backend/apps/webui/routers/prompts.py
  28. 22 13
      backend/apps/webui/routers/tools.py
  29. 44 25
      backend/apps/webui/routers/users.py
  30. 4 4
      backend/apps/webui/routers/utils.py
  31. 45 11
      backend/main.py
  32. 4 0
      backend/migrations/README
  33. 93 0
      backend/migrations/env.py
  34. 27 0
      backend/migrations/script.py.mako
  35. 188 0
      backend/migrations/versions/22b5ab2667b8_init.py
  36. 10 3
      backend/requirements.txt
  37. 0 0
      backend/test/__init__.py
  38. 209 0
      backend/test/apps/webui/routers/test_auths.py
  39. 239 0
      backend/test/apps/webui/routers/test_chats.py
  40. 106 0
      backend/test/apps/webui/routers/test_documents.py
  41. 60 0
      backend/test/apps/webui/routers/test_models.py
  42. 82 0
      backend/test/apps/webui/routers/test_prompts.py
  43. 170 0
      backend/test/apps/webui/routers/test_users.py
  44. 155 0
      backend/test/util/abstract_integration_test.py
  45. 45 0
      backend/test/util/mock_user.py
  46. 9 6
      backend/utils/utils.py
  47. 1 4
      src/lib/apis/models/index.ts

+ 2 - 2
.github/workflows/integration-test.yml

@@ -171,7 +171,7 @@ jobs:
           fi
 
           # Check that service will reconnect to postgres when connection will be closed
-          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
+          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
           if [[ "$status_code" -ne 200 ]] ; then
             echo "Server has failed before postgres reconnect check"
             exit 1
@@ -183,7 +183,7 @@ jobs:
             cur = conn.cursor(); \
             cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
 
-          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
+          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
           if [[ "$status_code" -ne 200 ]] ; then
             echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
             exit 1

+ 114 - 0
backend/alembic.ini

@@ -0,0 +1,114 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = migrations
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python>=3.9 or backports.zoneinfo library.
+# Any required deps can installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to ZoneInfo()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the
+# "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to migrations/versions.  When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are:
+#
+# version_path_separator = :
+# version_path_separator = ;
+# version_path_separator = space
+version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
+
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+sqlalchemy.url = REPLACE_WITH_DATABASE_URL
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts.  See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+# hooks = ruff
+# ruff.type = exec
+# ruff.executable = %(here)s/.venv/bin/ruff
+# ruff.options = --fix REVISION_SCRIPT_FILENAME
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S

+ 5 - 2
backend/apps/ollama/main.py

@@ -31,6 +31,7 @@ from typing import Optional, List, Union
 
 from starlette.background import BackgroundTask
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
@@ -711,6 +712,7 @@ async def generate_chat_completion(
     form_data: GenerateChatCompletionForm,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
 
     log.debug(
@@ -724,7 +726,7 @@ async def generate_chat_completion(
     }
 
     model_id = form_data.model
-    model_info = Models.get_model_by_id(model_id)
+    model_info = Models.get_model_by_id(db, model_id)
 
     if model_info:
         if model_info.base_model_id:
@@ -883,6 +885,7 @@ async def generate_openai_chat_completion(
     form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
 
@@ -891,7 +894,7 @@ async def generate_openai_chat_completion(
     }
 
     model_id = form_data.model
-    model_info = Models.get_model_by_id(model_id)
+    model_info = Models.get_model_by_id(db, model_id)
 
     if model_info:
         if model_info.base_model_id:

+ 3 - 1
backend/apps/openai/main.py

@@ -11,6 +11,7 @@ import logging
 from pydantic import BaseModel
 from starlette.background import BackgroundTask
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
@@ -353,12 +354,13 @@ async def generate_chat_completion(
     form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
     idx = 0
     payload = {**form_data}
 
     model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
+    model_info = Models.get_model_by_id(db, model_id)
 
     if model_info:
         if model_info.base_model_id:

+ 3 - 1
backend/apps/socket/main.py

@@ -24,7 +24,9 @@ async def connect(sid, environ, auth):
         data = decode_token(auth["token"])
 
         if data is not None and "id" in data:
-            user = Users.get_user_by_id(data["id"])
+            from apps.webui.internal.db import SessionLocal
+
+            user = Users.get_user_by_id(SessionLocal(), data["id"])
 
         if user:
             SESSION_POOL[sid] = user.id

+ 40 - 26
backend/apps/webui/internal/db.py

@@ -1,18 +1,34 @@
 import os
 import logging
 import json
+from typing import Optional, Any
+from typing_extensions import Self
 
-from peewee import *
-from peewee_migrate import Router
+from sqlalchemy import create_engine, types, Dialect
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.sql.type_api import _T
 
-from apps.webui.internal.wrappers import register_connection
 from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
 
-class JSONField(TextField):
+class JSONField(types.TypeDecorator):
+    impl = types.Text
+    cache_ok = True
+
+    def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
+        return json.dumps(value)
+
+    def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
+        if value is not None:
+            return json.loads(value)
+
+    def copy(self, **kw: Any) -> Self:
+        return JSONField(self.impl.length)
+
     def db_value(self, value):
         return json.dumps(value)
 
@@ -29,26 +45,24 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
 else:
     pass
 
+SQLALCHEMY_DATABASE_URL = DATABASE_URL
+if "sqlite" in SQLALCHEMY_DATABASE_URL:
+    engine = create_engine(
+        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
+    )
+else:
+    engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+Base = declarative_base()
 
-# The `register_connection` function encapsulates the logic for setting up
-# the database connection based on the connection string, while `connect`
-# is a Peewee-specific method to manage the connection state and avoid errors
-# when a connection is already open.
-try:
-    DB = register_connection(DATABASE_URL)
-    log.info(f"Connected to a {DB.__class__.__name__} database.")
-except Exception as e:
-    log.error(f"Failed to initialize the database connection: {e}")
-    raise
-
-router = Router(
-    DB,
-    migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
-    logger=log,
-)
-router.run()
-try:
-    DB.connect(reuse_if_open=True)
-except OperationalError as e:
-    log.info(f"Failed to connect to database again due to: {e}")
-    pass
+
+def get_db():
+    db = SessionLocal()
+    try:
+        yield db
+        db.commit()
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()

+ 0 - 72
backend/apps/webui/internal/wrappers.py

@@ -1,72 +0,0 @@
-from contextvars import ContextVar
-from peewee import *
-from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
-
-import logging
-from playhouse.db_url import connect, parse
-from playhouse.shortcuts import ReconnectMixin
-
-from config import SRC_LOG_LEVELS
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["DB"])
-
-db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
-db_state = ContextVar("db_state", default=db_state_default.copy())
-
-
-class PeeweeConnectionState(object):
-    def __init__(self, **kwargs):
-        super().__setattr__("_state", db_state)
-        super().__init__(**kwargs)
-
-    def __setattr__(self, name, value):
-        self._state.get()[name] = value
-
-    def __getattr__(self, name):
-        value = self._state.get()[name]
-        return value
-
-
-class CustomReconnectMixin(ReconnectMixin):
-    reconnect_errors = (
-        # psycopg2
-        (OperationalError, "termin"),
-        (InterfaceError, "closed"),
-        # peewee
-        (PeeWeeInterfaceError, "closed"),
-    )
-
-
-class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
-    pass
-
-
-def register_connection(db_url):
-    db = connect(db_url)
-    if isinstance(db, PostgresqlDatabase):
-        # Enable autoconnect for SQLite databases, managed by Peewee
-        db.autoconnect = True
-        db.reuse_if_open = True
-        log.info("Connected to PostgreSQL database")
-
-        # Get the connection details
-        connection = parse(db_url)
-
-        # Use our custom database class that supports reconnection
-        db = ReconnectingPostgresqlDatabase(
-            connection["database"],
-            user=connection["user"],
-            password=connection["password"],
-            host=connection["host"],
-            port=connection["port"],
-        )
-        db.connect(reuse_if_open=True)
-    elif isinstance(db, SqliteDatabase):
-        # Enable autoconnect for SQLite databases, managed by Peewee
-        db.autoconnect = True
-        db.reuse_if_open = True
-        log.info("Connected to SQLite database")
-    else:
-        raise ValueError("Unsupported database connection")
-    return db

+ 3 - 3
backend/apps/webui/main.py

@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.middleware.sessions import SessionMiddleware
-
+from sqlalchemy.orm import Session
 from apps.webui.routers import (
     auths,
     users,
@@ -114,8 +114,8 @@ async def get_status():
     }
 
 
-async def get_pipe_models():
-    pipes = Functions.get_functions_by_type("pipe", active_only=True)
+async def get_pipe_models(db: Session):
+    pipes = Functions.get_functions_by_type(db, "pipe", active_only=True)
     pipe_models = []
 
     for pipe in pipes:

+ 40 - 37
backend/apps/webui/models/auths.py

@@ -1,14 +1,14 @@
 from pydantic import BaseModel
-from typing import List, Union, Optional
-import time
+from typing import Optional
 import uuid
 import logging
-from peewee import *
+from sqlalchemy import String, Column, Boolean
+from sqlalchemy.orm import Session
 
 from apps.webui.models.users import UserModel, Users
 from utils.utils import verify_password
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import Base
 
 from config import SRC_LOG_LEVELS
 
@@ -20,14 +20,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 
 
-class Auth(Model):
-    id = CharField(unique=True)
-    email = CharField()
-    password = TextField()
-    active = BooleanField()
+class Auth(Base):
+    __tablename__ = "auth"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    email = Column(String)
+    password = Column(String)
+    active = Column(Boolean)
 
 
 class AuthModel(BaseModel):
@@ -94,12 +93,10 @@ class AddUserForm(SignupForm):
 
 
 class AuthsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Auth])
 
     def insert_new_auth(
         self,
+        db: Session,
         email: str,
         password: str,
         name: str,
@@ -114,24 +111,30 @@ class AuthsTable:
         auth = AuthModel(
             **{"id": id, "email": email, "password": password, "active": True}
         )
-        result = Auth.create(**auth.model_dump())
+        result = Auth(**auth.model_dump())
+        db.add(result)
 
         user = Users.insert_new_user(
-            id, name, email, profile_image_url, role, oauth_sub
+            db, id, name, email, profile_image_url, role, oauth_sub
         )
 
+        db.commit()
+        db.refresh(result)
+
         if result and user:
             return user
         else:
             return None
 
-    def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
+    def authenticate_user(
+        self, db: Session, email: str, password: str
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
         try:
-            auth = Auth.get(Auth.email == email, Auth.active == True)
+            auth = db.query(Auth).filter_by(email=email, active=True).first()
             if auth:
                 if verify_password(password, auth.password):
-                    user = Users.get_user_by_id(auth.id)
+                    user = Users.get_user_by_id(db, auth.id)
                     return user
                 else:
                     return None
@@ -140,55 +143,55 @@ class AuthsTable:
         except:
             return None
 
-    def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
+    def authenticate_user_by_api_key(
+        self, db: Session, api_key: str
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_api_key: {api_key}")
         # if no api_key, return None
         if not api_key:
             return None
 
         try:
-            user = Users.get_user_by_api_key(api_key)
+            user = Users.get_user_by_api_key(db, api_key)
             return user if user else None
         except:
             return False
 
-    def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
+    def authenticate_user_by_trusted_header(
+        self, db: Session, email: str
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
         try:
-            auth = Auth.get(Auth.email == email, Auth.active == True)
+            auth = db.query(Auth).filter(email=email, active=True).first()
             if auth:
                 user = Users.get_user_by_id(auth.id)
                 return user
         except:
             return None
 
-    def update_user_password_by_id(self, id: str, new_password: str) -> bool:
+    def update_user_password_by_id(
+        self, db: Session, id: str, new_password: str
+    ) -> bool:
         try:
-            query = Auth.update(password=new_password).where(Auth.id == id)
-            result = query.execute()
-
+            result = db.query(Auth).filter_by(id=id).update({"password": new_password})
             return True if result == 1 else False
         except:
             return False
 
-    def update_email_by_id(self, id: str, email: str) -> bool:
+    def update_email_by_id(self, db: Session, id: str, email: str) -> bool:
         try:
-            query = Auth.update(email=email).where(Auth.id == id)
-            result = query.execute()
-
+            result = db.query(Auth).filter_by(id=id).update({"email": email})
             return True if result == 1 else False
         except:
             return False
 
-    def delete_auth_by_id(self, id: str) -> bool:
+    def delete_auth_by_id(self, db: Session, id: str) -> bool:
         try:
             # Delete User
-            result = Users.delete_user_by_id(id)
+            result = Users.delete_user_by_id(db, id)
 
             if result:
-                # Delete Auth
-                query = Auth.delete().where(Auth.id == id)
-                query.execute()  # Remove the rows, return number of rows removed.
+                db.query(Auth).filter_by(id=id).delete()
 
                 return True
             else:
@@ -197,4 +200,4 @@ class AuthsTable:
             return False
 
 
-Auths = AuthsTable(DB)
+Auths = AuthsTable()

+ 139 - 162
backend/apps/webui/models/chats.py

@@ -1,36 +1,39 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
-from peewee import *
-from playhouse.shortcuts import model_to_dict
 
 import json
 import uuid
 import time
 
-from apps.webui.internal.db import DB
+from sqlalchemy import Column, String, BigInteger, Boolean
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base
+
 
 ####################
 # Chat DB Schema
 ####################
 
 
-class Chat(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    title = TextField()
-    chat = TextField()  # Save Chat JSON as Text
+class Chat(Base):
+    __tablename__ = "chat"
 
-    created_at = BigIntegerField()
-    updated_at = BigIntegerField()
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    title = Column(String)
+    chat = Column(String)  # Save Chat JSON as Text
 
-    share_id = CharField(null=True, unique=True)
-    archived = BooleanField(default=False)
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
 
-    class Meta:
-        database = DB
+    share_id = Column(String, unique=True, nullable=True)
+    archived = Column(Boolean, default=False)
 
 
 class ChatModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     id: str
     user_id: str
     title: str
@@ -75,11 +78,10 @@ class ChatTitleIdResponse(BaseModel):
 
 
 class ChatTable:
-    def __init__(self, db):
-        self.db = db
-        db.create_tables([Chat])
 
-    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
+    def insert_new_chat(
+        self, db: Session, user_id: str, form_data: ChatForm
+    ) -> Optional[ChatModel]:
         id = str(uuid.uuid4())
         chat = ChatModel(
             **{
@@ -94,29 +96,36 @@ class ChatTable:
             }
         )
 
-        result = Chat.create(**chat.model_dump())
-        return chat if result else None
+        result = Chat(**chat.model_dump())
+        db.add(result)
+        db.commit()
+        db.refresh(result)
+        return ChatModel.model_validate(result) if result else None
 
-    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
+    def update_chat_by_id(
+        self, db: Session, id: str, chat: dict
+    ) -> Optional[ChatModel]:
         try:
-            query = Chat.update(
-                chat=json.dumps(chat),
-                title=chat["title"] if "title" in chat else "New Chat",
-                updated_at=int(time.time()),
-            ).where(Chat.id == id)
-            query.execute()
-
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            db.query(Chat).filter_by(id=id).update(
+                {
+                    "chat": json.dumps(chat),
+                    "title": chat["title"] if "title" in chat else "New Chat",
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            return self.get_chat_by_id(db, id)
         except:
             return None
 
-    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
+    def insert_shared_chat_by_chat_id(
+        self, db: Session, chat_id: str
+    ) -> Optional[ChatModel]:
         # Get the existing chat to share
-        chat = Chat.get(Chat.id == chat_id)
+        chat = db.get(Chat, chat_id)
         # Check if the chat is already shared
         if chat.share_id:
-            return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+            return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared")
         # Create a new chat with the same data, but with a new ID
         shared_chat = ChatModel(
             **{
@@ -128,228 +137,196 @@ class ChatTable:
                 "updated_at": int(time.time()),
             }
         )
-        shared_result = Chat.create(**shared_chat.model_dump())
+        shared_result = Chat(**shared_chat.model_dump())
+        db.add(shared_result)
+        db.commit()
+        db.refresh(shared_result)
         # Update the original chat with the share_id
         result = (
-            Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
+            db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
         )
 
         return shared_chat if (shared_result and result) else None
 
-    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
+    def update_shared_chat_by_chat_id(
+        self, db: Session, chat_id: str
+    ) -> Optional[ChatModel]:
         try:
             print("update_shared_chat_by_id")
-            chat = Chat.get(Chat.id == chat_id)
+            chat = db.get(Chat, chat_id)
             print(chat)
 
-            query = Chat.update(
-                title=chat.title,
-                chat=chat.chat,
-            ).where(Chat.id == chat.share_id)
+            db.query(Chat).filter_by(id=chat.share_id).update(
+                {"title": chat.title, "chat": chat.chat}
+            )
 
-            query.execute()
-
-            chat = Chat.get(Chat.id == chat.share_id)
-            return ChatModel(**model_to_dict(chat))
+            return self.get_chat_by_id(db, chat.share_id)
         except:
             return None
 
-    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
+    def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool:
         try:
-            query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
             return True
         except:
             return False
 
     def update_chat_share_id_by_id(
-        self, id: str, share_id: Optional[str]
+        self, db: Session, id: str, share_id: Optional[str]
     ) -> Optional[ChatModel]:
         try:
-            query = Chat.update(
-                share_id=share_id,
-            ).where(Chat.id == id)
-            query.execute()
+            db.query(Chat).filter_by(id=id).update({"share_id": share_id})
 
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            return self.get_chat_by_id(db, id)
         except:
             return None
 
-    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
+    def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
         try:
-            chat = self.get_chat_by_id(id)
-            query = Chat.update(
-                archived=(not chat.archived),
-            ).where(Chat.id == id)
+            chat = self.get_chat_by_id(db, id)
+            db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
 
-            query.execute()
-
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            return self.get_chat_by_id(db, id)
         except:
             return None
 
-    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
+    def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
-            chats = self.get_chats_by_user_id(user_id)
-            for chat in chats:
-                query = Chat.update(
-                    archived=True,
-                ).where(Chat.id == chat.id)
-
-                query.execute()
+            db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
 
             return True
         except:
             return False
 
     def get_archived_chat_list_by_user_id(
-        self, user_id: str, skip: int = 0, limit: int = 50
+        self, db: Session, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == True)
-            .where(Chat.user_id == user_id)
+        all_chats = (
+            db.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
             .order_by(Chat.updated_at.desc())
-            # .limit(limit)
-            # .offset(skip)
-        ]
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_list_by_user_id(
         self,
+        db: Session,
         user_id: str,
         include_archived: bool = False,
         skip: int = 0,
         limit: int = 50,
     ) -> List[ChatModel]:
-        if include_archived:
-            return [
-                ChatModel(**model_to_dict(chat))
-                for chat in Chat.select()
-                .where(Chat.user_id == user_id)
-                .order_by(Chat.updated_at.desc())
-                # .limit(limit)
-                # .offset(skip)
-            ]
-        else:
-            return [
-                ChatModel(**model_to_dict(chat))
-                for chat in Chat.select()
-                .where(Chat.archived == False)
-                .where(Chat.user_id == user_id)
-                .order_by(Chat.updated_at.desc())
-                # .limit(limit)
-                # .offset(skip)
-            ]
+        query = db.query(Chat).filter_by(user_id=user_id)
+        if not include_archived:
+            query = query.filter_by(archived=False)
+        all_chats = (
+            query.order_by(Chat.updated_at.desc())
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_list_by_chat_ids(
-        self, chat_ids: List[str], skip: int = 0, limit: int = 50
+        self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == False)
-            .where(Chat.id.in_(chat_ids))
+        all_chats = (
+            db.query(Chat)
+            .filter(Chat.id.in_(chat_ids))
+            .filter_by(archived=False)
             .order_by(Chat.updated_at.desc())
-        ]
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
-    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
+    def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
         try:
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            chat = db.get(Chat, id)
+            return ChatModel.model_validate(chat)
         except:
             return None
 
-    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
+    def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]:
         try:
-            chat = Chat.get(Chat.share_id == id)
+            chat = db.query(Chat).filter_by(share_id=id).first()
 
             if chat:
-                chat = Chat.get(Chat.id == id)
-                return ChatModel(**model_to_dict(chat))
+                return self.get_chat_by_id(db, id)
             else:
                 return None
-        except:
+        except Exception as e:
             return None
 
-    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
+    def get_chat_by_id_and_user_id(
+        self, db: Session, id: str, user_id: str
+    ) -> Optional[ChatModel]:
         try:
-            chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
-            return ChatModel(**model_to_dict(chat))
+            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
+            return ChatModel.model_validate(chat)
         except:
             return None
 
-    def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select().order_by(Chat.updated_at.desc())
+    def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]:
+        all_chats = (
+            db.query(Chat)
             # .limit(limit).offset(skip)
-        ]
-
-    def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.user_id == user_id)
             .order_by(Chat.updated_at.desc())
-            # .limit(limit).offset(skip)
-        ]
-
-    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == True)
-            .where(Chat.user_id == user_id)
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]:
+        all_chats = (
+            db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def get_archived_chats_by_user_id(
+        self, db: Session, user_id: str
+    ) -> List[ChatModel]:
+        all_chats = (
+            db.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
             .order_by(Chat.updated_at.desc())
-        ]
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
-    def delete_chat_by_id(self, id: str) -> bool:
+    def delete_chat_by_id(self, db: Session, id: str) -> bool:
         try:
-            query = Chat.delete().where((Chat.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
+            db.query(Chat).filter_by(id=id).delete()
 
-            return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(db, id)
         except:
             return False
 
-    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+    def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool:
         try:
-            query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
-            query.execute()  # Remove the rows, return number of rows removed.
+            db.query(Chat).filter_by(id=id, user_id=user_id).delete()
 
-            return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(db, id)
         except:
             return False
 
-    def delete_chats_by_user_id(self, user_id: str) -> bool:
+    def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
 
-            self.delete_shared_chats_by_user_id(user_id)
-
-            query = Chat.delete().where(Chat.user_id == user_id)
-            query.execute()  # Remove the rows, return number of rows removed.
+            self.delete_shared_chats_by_user_id(db, user_id)
 
+            db.query(Chat).filter_by(user_id=user_id).delete()
             return True
         except:
             return False
 
-    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
+    def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
-            shared_chat_ids = [
-                f"shared-{chat.id}"
-                for chat in Chat.select().where(Chat.user_id == user_id)
-            ]
+            chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
+            shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 
-            query = Chat.delete().where(Chat.user_id << shared_chat_ids)
-            query.execute()  # Remove the rows, return number of rows removed.
+            db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
 
             return True
         except:
             return False
 
 
-Chats = ChatTable(DB)
+Chats = ChatTable()

+ 48 - 56
backend/apps/webui/models/documents.py

@@ -1,14 +1,12 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
-from typing import List, Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 import time
 import logging
 
-from utils.utils import decode_token
-from utils.misc import get_gravatar_url
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import Base
 
 import json
 
@@ -22,20 +20,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 
 
-class Document(Model):
-    collection_name = CharField(unique=True)
-    name = CharField(unique=True)
-    title = TextField()
-    filename = TextField()
-    content = TextField(null=True)
-    user_id = CharField()
-    timestamp = BigIntegerField()
+class Document(Base):
+    __tablename__ = "document"
 
-    class Meta:
-        database = DB
+    collection_name = Column(String, primary_key=True)
+    name = Column(String, unique=True)
+    title = Column(String)
+    filename = Column(String)
+    content = Column(String, nullable=True)
+    user_id = Column(String)
+    timestamp = Column(BigInteger)
 
 
 class DocumentModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     collection_name: str
     name: str
     title: str
@@ -72,12 +71,9 @@ class DocumentForm(DocumentUpdateForm):
 
 
 class DocumentsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Document])
 
     def insert_new_doc(
-        self, user_id: str, form_data: DocumentForm
+        self, db: Session, user_id: str, form_data: DocumentForm
     ) -> Optional[DocumentModel]:
         document = DocumentModel(
             **{
@@ -88,73 +84,69 @@ class DocumentsTable:
         )
 
         try:
-            result = Document.create(**document.model_dump())
+            result = Document(**document.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return document
+                return DocumentModel.model_validate(result)
             else:
                 return None
         except:
             return None
 
-    def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
+    def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]:
         try:
-            document = Document.get(Document.name == name)
-            return DocumentModel(**model_to_dict(document))
+            document = db.query(Document).filter_by(name=name).first()
+            return DocumentModel.model_validate(document) if document else None
         except:
             return None
 
-    def get_docs(self) -> List[DocumentModel]:
-        return [
-            DocumentModel(**model_to_dict(doc))
-            for doc in Document.select()
-            # .limit(limit).offset(skip)
-        ]
+    def get_docs(self, db: Session) -> List[DocumentModel]:
+        return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
 
     def update_doc_by_name(
-        self, name: str, form_data: DocumentUpdateForm
+        self, db: Session, name: str, form_data: DocumentUpdateForm
     ) -> Optional[DocumentModel]:
         try:
-            query = Document.update(
-                title=form_data.title,
-                name=form_data.name,
-                timestamp=int(time.time()),
-            ).where(Document.name == name)
-            query.execute()
-
-            doc = Document.get(Document.name == form_data.name)
-            return DocumentModel(**model_to_dict(doc))
+            db.query(Document).filter_by(name=name).update(
+                {
+                    "title": form_data.title,
+                    "name": form_data.name,
+                    "timestamp": int(time.time()),
+                }
+            )
+            return self.get_doc_by_name(db, form_data.name)
         except Exception as e:
             log.exception(e)
             return None
 
     def update_doc_content_by_name(
-        self, name: str, updated: dict
+        self, db: Session, name: str, updated: dict
     ) -> Optional[DocumentModel]:
         try:
-            doc = self.get_doc_by_name(name)
+            doc = self.get_doc_by_name(db, name)
             doc_content = json.loads(doc.content if doc.content else "{}")
             doc_content = {**doc_content, **updated}
 
-            query = Document.update(
-                content=json.dumps(doc_content),
-                timestamp=int(time.time()),
-            ).where(Document.name == name)
-            query.execute()
+            db.query(Document).filter_by(name=name).update(
+                {
+                    "content": json.dumps(doc_content),
+                    "timestamp": int(time.time()),
+                }
+            )
 
-            doc = Document.get(Document.name == name)
-            return DocumentModel(**model_to_dict(doc))
+            return self.get_doc_by_name(db, name)
         except Exception as e:
             log.exception(e)
             return None
 
-    def delete_doc_by_name(self, name: str) -> bool:
+    def delete_doc_by_name(self, db: Session, name: str) -> bool:
         try:
-            query = Document.delete().where((Document.name == name))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Document).filter_by(name=name).delete()
             return True
         except:
             return False
 
 
-Documents = DocumentsTable(DB)
+Documents = DocumentsTable()

+ 30 - 32
backend/apps/webui/models/files.py

@@ -1,10 +1,12 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 import time
 import logging
-from apps.webui.internal.db import DB, JSONField
+
+from sqlalchemy import Column, String, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import JSONField, Base
 
 import json
 
@@ -18,15 +20,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 
 
-class File(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    filename = TextField()
-    meta = JSONField()
-    created_at = BigIntegerField()
+class File(Base):
+    __tablename__ = "file"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    filename = Column(String)
+    meta = Column(JSONField)
+    created_at = Column(BigInteger)
 
 
 class FileModel(BaseModel):
@@ -36,6 +37,7 @@ class FileModel(BaseModel):
     meta: dict
     created_at: int  # timestamp in epoch
 
+    model_config = ConfigDict(from_attributes=True)
 
 ####################
 # Forms
@@ -57,11 +59,8 @@ class FileForm(BaseModel):
 
 
 class FilesTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([File])
 
-    def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
+    def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         file = FileModel(
             **{
                 **form_data.model_dump(),
@@ -71,42 +70,41 @@ class FilesTable:
         )
 
         try:
-            result = File.create(**file.model_dump())
+            result = File(**file.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return file
+                return FileModel.model_validate(result)
             else:
                 return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
-    def get_file_by_id(self, id: str) -> Optional[FileModel]:
+    def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]:
         try:
-            file = File.get(File.id == id)
-            return FileModel(**model_to_dict(file))
+            file = db.get(File, id)
+            return FileModel.model_validate(file)
         except:
             return None
 
-    def get_files(self) -> List[FileModel]:
-        return [FileModel(**model_to_dict(file)) for file in File.select()]
+    def get_files(self, db: Session) -> List[FileModel]:
+        return [FileModel.model_validate(file) for file in db.query(File).all()]
 
-    def delete_file_by_id(self, id: str) -> bool:
+    def delete_file_by_id(self, db: Session, id: str) -> bool:
         try:
-            query = File.delete().where((File.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(File).filter_by(id=id).delete()
             return True
         except:
             return False
 
-    def delete_all_files(self) -> bool:
+    def delete_all_files(self, db: Session) -> bool:
         try:
-            query = File.delete()
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(File).delete()
             return True
         except:
             return False
 
 
-Files = FilesTable(DB)
+Files = FilesTable()

+ 36 - 38
backend/apps/webui/models/functions.py

@@ -1,10 +1,12 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 import time
 import logging
-from apps.webui.internal.db import DB, JSONField
+
+from sqlalchemy import Column, String, Text, BigInteger, Boolean
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import JSONField, Base
 from apps.webui.models.users import Users
 
 import json
@@ -21,20 +23,19 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 
 
-class Function(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    name = TextField()
-    type = TextField()
-    content = TextField()
-    meta = JSONField()
-    valves = JSONField()
-    is_active = BooleanField(default=False)
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+class Function(Base):
+    __tablename__ = "function"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    name = Column(Text)
+    type = Column(Text)
+    content = Column(Text)
+    meta = Column(JSONField)
+    valves = Column(JSONField)
+    is_active = Column(Boolean)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 class FunctionMeta(BaseModel):
@@ -53,6 +54,8 @@ class FunctionModel(BaseModel):
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 ####################
 # Forms
@@ -82,12 +85,9 @@ class FunctionValves(BaseModel):
 
 
 class FunctionsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Function])
 
     def insert_new_function(
-        self, user_id: str, type: str, form_data: FunctionForm
+        self, db: Session, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
         function = FunctionModel(
             **{
@@ -100,19 +100,22 @@ class FunctionsTable:
         )
 
         try:
-            result = Function.create(**function.model_dump())
+            result = Function(**function.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return function
+                return FunctionModel.model_validate(result)
             else:
                 return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
-    def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
+    def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]:
         try:
-            function = Function.get(Function.id == id)
-            return FunctionModel(**model_to_dict(function))
+            function = db.get(Function, id)
+            return FunctionModel.model_validate(function)
         except:
             return None
 
@@ -211,14 +214,11 @@ class FunctionsTable:
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
-            query = Function.update(
+            db.query(Function).filter_by(id=id).update({
                 **updated,
-                updated_at=int(time.time()),
-            ).where(Function.id == id)
-            query.execute()
-
-            function = Function.get(Function.id == id)
-            return FunctionModel(**model_to_dict(function))
+                "updated_at": int(time.time()),
+            })
+            return self.get_function_by_id(db, id)
         except:
             return None
 
@@ -235,14 +235,12 @@ class FunctionsTable:
         except:
             return None
 
-    def delete_function_by_id(self, id: str) -> bool:
+    def delete_function_by_id(self, db: Session, id: str) -> bool:
         try:
-            query = Function.delete().where((Function.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Function).filter_by(id=id).delete()
             return True
         except:
             return False
 
 
-Functions = FunctionsTable(DB)
+Functions = FunctionsTable()

+ 43 - 44
backend/apps/webui/models/memories.py

@@ -1,9 +1,10 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 
-from apps.webui.internal.db import DB
+from sqlalchemy import Column, String, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base
 from apps.webui.models.chats import Chats
 
 import time
@@ -14,15 +15,14 @@ import uuid
 ####################
 
 
-class Memory(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    content = TextField()
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+class Memory(Base):
+    __tablename__ = "memory"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    content = Column(String)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 class MemoryModel(BaseModel):
@@ -32,6 +32,8 @@ class MemoryModel(BaseModel):
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 ####################
 # Forms
@@ -39,12 +41,10 @@ class MemoryModel(BaseModel):
 
 
 class MemoriesTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Memory])
 
     def insert_new_memory(
         self,
+        db: Session,
         user_id: str,
         content: str,
     ) -> Optional[MemoryModel]:
@@ -59,74 +59,73 @@ class MemoriesTable:
                 "updated_at": int(time.time()),
             }
         )
-        result = Memory.create(**memory.model_dump())
+        result = Memory(**memory.dict())
+        db.add(result)
+        db.commit()
+        db.refresh(result)
         if result:
-            return memory
+            return MemoryModel.model_validate(result)
         else:
             return None
 
     def update_memory_by_id(
         self,
+        db: Session,
         id: str,
         content: str,
     ) -> Optional[MemoryModel]:
         try:
-            memory = Memory.get(Memory.id == id)
-            memory.content = content
-            memory.updated_at = int(time.time())
-            memory.save()
-            return MemoryModel(**model_to_dict(memory))
+            db.query(Memory).filter_by(id=id).update(
+                {"content": content, "updated_at": int(time.time())}
+            )
+            return self.get_memory_by_id(db, id)
         except:
             return None
 
-    def get_memories(self) -> List[MemoryModel]:
+    def get_memories(self, db: Session) -> List[MemoryModel]:
         try:
-            memories = Memory.select()
-            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
+            memories = db.query(Memory).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
             return None
 
-    def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
+    def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]:
         try:
-            memories = Memory.select().where(Memory.user_id == user_id)
-            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
+            memories = db.query(Memory).filter_by(user_id=user_id).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
             return None
 
-    def get_memory_by_id(self, id) -> Optional[MemoryModel]:
+    def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]:
         try:
-            memory = Memory.get(Memory.id == id)
-            return MemoryModel(**model_to_dict(memory))
+            memory = db.get(Memory, id)
+            return MemoryModel.model_validate(memory)
         except:
             return None
 
-    def delete_memory_by_id(self, id: str) -> bool:
+    def delete_memory_by_id(self, db: Session, id: str) -> bool:
         try:
-            query = Memory.delete().where(Memory.id == id)
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Memory).filter_by(id=id).delete()
             return True
 
         except:
             return False
 
-    def delete_memories_by_user_id(self, user_id: str) -> bool:
+    def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
-            query = Memory.delete().where(Memory.user_id == user_id)
-            query.execute()
-
+            db.query(Memory).filter_by(user_id=user_id).delete()
             return True
         except:
             return False
 
-    def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+    def delete_memory_by_id_and_user_id(
+        self, db: Session, id: str, user_id: str
+    ) -> bool:
         try:
-            query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
-            query.execute()
-
+            db.query(Memory).filter_by(id=id, user_id=user_id).delete()
             return True
         except:
             return False
 
 
-Memories = MemoriesTable(DB)
+Memories = MemoriesTable()

+ 38 - 41
backend/apps/webui/models/models.py

@@ -2,13 +2,11 @@ import json
 import logging
 from typing import Optional
 
-import peewee as pw
-from peewee import *
-
-from playhouse.shortcuts import model_to_dict
 from pydantic import BaseModel, ConfigDict
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import DB, JSONField
+from apps.webui.internal.db import Base, JSONField
 
 from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
@@ -46,41 +44,42 @@ class ModelMeta(BaseModel):
     pass
 
 
-class Model(pw.Model):
-    id = pw.TextField(unique=True)
+class Model(Base):
+    __tablename__ = "model"
+
+    id = Column(String, primary_key=True)
     """
         The model's id as used in the API. If set to an existing model, it will override the model.
     """
-    user_id = pw.TextField()
+    user_id = Column(String)
 
-    base_model_id = pw.TextField(null=True)
+    base_model_id = Column(String, nullable=True)
     """
         An optional pointer to the actual model that should be used when proxying requests.
     """
 
-    name = pw.TextField()
+    name = Column(String)
     """
         The human-readable display name of the model.
     """
 
-    params = JSONField()
+    params = Column(JSONField)
     """
         Holds a JSON encoded blob of parameters, see `ModelParams`.
     """
 
-    meta = JSONField()
+    meta = Column(JSONField)
     """
         Holds a JSON encoded blob of metadata, see `ModelMeta`.
     """
 
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
-
-    class Meta:
-        database = DB
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 class ModelModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     id: str
     user_id: str
     base_model_id: Optional[str] = None
@@ -115,15 +114,9 @@ class ModelForm(BaseModel):
 
 
 class ModelsTable:
-    def __init__(
-        self,
-        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
-    ):
-        self.db = db
-        self.db.create_tables([Model])
 
     def insert_new_model(
-        self, form_data: ModelForm, user_id: str
+        self, db: Session, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
         model = ModelModel(
             **{
@@ -134,46 +127,50 @@ class ModelsTable:
             }
         )
         try:
-            result = Model.create(**model.model_dump())
+            result = Model(**model.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
 
             if result:
-                return model
+                return ModelModel.model_validate(result)
             else:
                 return None
         except Exception as e:
             print(e)
             return None
 
-    def get_all_models(self) -> List[ModelModel]:
-        return [ModelModel(**model_to_dict(model)) for model in Model.select()]
+    def get_all_models(self, db: Session) -> List[ModelModel]:
+        return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
-    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
+    def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]:
         try:
-            model = Model.get(Model.id == id)
-            return ModelModel(**model_to_dict(model))
+            model = db.get(Model, id)
+            return ModelModel.model_validate(model)
         except:
             return None
 
-    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
+    def update_model_by_id(
+        self, db: Session, id: str, model: ModelForm
+    ) -> Optional[ModelModel]:
         try:
             # update only the fields that are present in the model
-            query = Model.update(**model.model_dump()).where(Model.id == id)
-            query.execute()
-
-            model = Model.get(Model.id == id)
-            return ModelModel(**model_to_dict(model))
+            model = db.query(Model).get(id)
+            model.update(**model.model_dump())
+            db.commit()
+            db.refresh(model)
+            return ModelModel.model_validate(model)
         except Exception as e:
             print(e)
 
             return None
 
-    def delete_model_by_id(self, id: str) -> bool:
+    def delete_model_by_id(self, db: Session, id: str) -> bool:
         try:
-            query = Model.delete().where(Model.id == id)
-            query.execute()
+            db.query(Model).filter_by(id=id).delete()
             return True
         except:
             return False
 
 
-Models = ModelsTable(DB)
+Models = ModelsTable()

+ 38 - 48
backend/apps/webui/models/prompts.py

@@ -1,13 +1,11 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
-from typing import List, Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 import time
 
-from utils.utils import decode_token
-from utils.misc import get_gravatar_url
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import Base
 
 import json
 
@@ -16,15 +14,14 @@ import json
 ####################
 
 
-class Prompt(Model):
-    command = CharField(unique=True)
-    user_id = CharField()
-    title = TextField()
-    content = TextField()
-    timestamp = BigIntegerField()
+class Prompt(Base):
+    __tablename__ = "prompt"
 
-    class Meta:
-        database = DB
+    command = Column(String, primary_key=True)
+    user_id = Column(String)
+    title = Column(String)
+    content = Column(String)
+    timestamp = Column(BigInteger)
 
 
 class PromptModel(BaseModel):
@@ -34,6 +31,8 @@ class PromptModel(BaseModel):
     content: str
     timestamp: int  # timestamp in epoch
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 ####################
 # Forms
@@ -48,12 +47,8 @@ class PromptForm(BaseModel):
 
 class PromptsTable:
 
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Prompt])
-
     def insert_new_prompt(
-        self, user_id: str, form_data: PromptForm
+        self, db: Session, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
         prompt = PromptModel(
             **{
@@ -66,53 +61,48 @@ class PromptsTable:
         )
 
         try:
-            result = Prompt.create(**prompt.model_dump())
+            result = Prompt(**prompt.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return prompt
+                return PromptModel.model_validate(result)
             else:
                 return None
-        except:
+        except Exception as e:
             return None
 
-    def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
+    def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]:
         try:
-            prompt = Prompt.get(Prompt.command == command)
-            return PromptModel(**model_to_dict(prompt))
+            prompt = db.query(Prompt).filter_by(command=command).first()
+            return PromptModel.model_validate(prompt)
         except:
             return None
 
-    def get_prompts(self) -> List[PromptModel]:
-        return [
-            PromptModel(**model_to_dict(prompt))
-            for prompt in Prompt.select()
-            # .limit(limit).offset(skip)
-        ]
+    def get_prompts(self, db: Session) -> List[PromptModel]:
+        return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
 
     def update_prompt_by_command(
-        self, command: str, form_data: PromptForm
+        self, db: Session, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
         try:
-            query = Prompt.update(
-                title=form_data.title,
-                content=form_data.content,
-                timestamp=int(time.time()),
-            ).where(Prompt.command == command)
-
-            query.execute()
-
-            prompt = Prompt.get(Prompt.command == command)
-            return PromptModel(**model_to_dict(prompt))
+            db.query(Prompt).filter_by(command=command).update(
+                {
+                    "title": form_data.title,
+                    "content": form_data.content,
+                    "timestamp": int(time.time()),
+                }
+            )
+            return self.get_prompt_by_command(db, command)
         except:
             return None
 
-    def delete_prompt_by_command(self, command: str) -> bool:
+    def delete_prompt_by_command(self, db: Session, command: str) -> bool:
         try:
-            query = Prompt.delete().where((Prompt.command == command))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Prompt).filter_by(command=command).delete()
             return True
         except:
             return False
 
 
-Prompts = PromptsTable(DB)
+Prompts = PromptsTable()

+ 109 - 89
backend/apps/webui/models/tags.py

@@ -1,14 +1,15 @@
-from pydantic import BaseModel
-from typing import List, Union, Optional
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 
 import json
 import uuid
 import time
 import logging
 
-from apps.webui.internal.db import DB
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base
 
 from config import SRC_LOG_LEVELS
 
@@ -20,25 +21,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 
 
-class Tag(Model):
-    id = CharField(unique=True)
-    name = CharField()
-    user_id = CharField()
-    data = TextField(null=True)
+class Tag(Base):
+    __tablename__ = "tag"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    name = Column(String)
+    user_id = Column(String)
+    data = Column(String, nullable=True)
 
 
-class ChatIdTag(Model):
-    id = CharField(unique=True)
-    tag_name = CharField()
-    chat_id = CharField()
-    user_id = CharField()
-    timestamp = BigIntegerField()
+class ChatIdTag(Base):
+    __tablename__ = "chatidtag"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    tag_name = Column(String)
+    chat_id = Column(String)
+    user_id = Column(String)
+    timestamp = Column(BigInteger)
 
 
 class TagModel(BaseModel):
@@ -47,6 +46,8 @@ class TagModel(BaseModel):
     user_id: str
     data: Optional[str] = None
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 class ChatIdTagModel(BaseModel):
     id: str
@@ -55,6 +56,8 @@ class ChatIdTagModel(BaseModel):
     user_id: str
     timestamp: int
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 ####################
 # Forms
@@ -75,37 +78,39 @@ class ChatTagsResponse(BaseModel):
 
 
 class TagTable:
-    def __init__(self, db):
-        self.db = db
-        db.create_tables([Tag, ChatIdTag])
 
-    def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
+    def insert_new_tag(
+        self, db: Session, name: str, user_id: str
+    ) -> Optional[TagModel]:
         id = str(uuid.uuid4())
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
-            result = Tag.create(**tag.model_dump())
+            result = Tag(**tag.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return tag
+                return TagModel.model_validate(result)
             else:
                 return None
         except Exception as e:
             return None
 
     def get_tag_by_name_and_user_id(
-        self, name: str, user_id: str
+        self, db: Session, name: str, user_id: str
     ) -> Optional[TagModel]:
         try:
-            tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
-            return TagModel(**model_to_dict(tag))
+            tag = db.query(Tag).filter(name=name, user_id=user_id).first()
+            return TagModel.model_validate(tag)
         except Exception as e:
             return None
 
     def add_tag_to_chat(
-        self, user_id: str, form_data: ChatIdTagForm
+        self, db: Session, user_id: str, form_data: ChatIdTagForm
     ) -> Optional[ChatIdTagModel]:
-        tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
+        tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id)
         if tag == None:
-            tag = self.insert_new_tag(form_data.tag_name, user_id)
+            tag = self.insert_new_tag(db, form_data.tag_name, user_id)
 
         id = str(uuid.uuid4())
         chatIdTag = ChatIdTagModel(
@@ -118,120 +123,135 @@ class TagTable:
             }
         )
         try:
-            result = ChatIdTag.create(**chatIdTag.model_dump())
+            result = ChatIdTag(**chatIdTag.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return chatIdTag
+                return ChatIdTagModel.model_validate(result)
             else:
                 return None
         except:
             return None
 
-    def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
+    def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]:
         tag_names = [
-            ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
-            for chat_id_tag in ChatIdTag.select()
-            .where(ChatIdTag.user_id == user_id)
-            .order_by(ChatIdTag.timestamp.desc())
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                db.query(ChatIdTag)
+                .filter_by(user_id=user_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
         ]
 
         return [
-            TagModel(**model_to_dict(tag))
-            for tag in Tag.select()
-            .where(Tag.user_id == user_id)
-            .where(Tag.name.in_(tag_names))
+            TagModel.model_validate(tag)
+            for tag in (
+                db.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
         ]
 
     def get_tags_by_chat_id_and_user_id(
-        self, chat_id: str, user_id: str
+        self, db: Session, chat_id: str, user_id: str
     ) -> List[TagModel]:
         tag_names = [
-            ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
-            for chat_id_tag in ChatIdTag.select()
-            .where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
-            .order_by(ChatIdTag.timestamp.desc())
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                db.query(ChatIdTag)
+                .filter_by(user_id=user_id, chat_id=chat_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
         ]
 
         return [
-            TagModel(**model_to_dict(tag))
-            for tag in Tag.select()
-            .where(Tag.user_id == user_id)
-            .where(Tag.name.in_(tag_names))
+            TagModel.model_validate(tag)
+            for tag in (
+                db.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
         ]
 
     def get_chat_ids_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
-    ) -> Optional[ChatIdTagModel]:
+        self, db: Session, tag_name: str, user_id: str
+    ) -> List[ChatIdTagModel]:
         return [
-            ChatIdTagModel(**model_to_dict(chat_id_tag))
-            for chat_id_tag in ChatIdTag.select()
-            .where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
-            .order_by(ChatIdTag.timestamp.desc())
+            ChatIdTagModel.model_validate(chat_id_tag)
+            for chat_id_tag in (
+                db.query(ChatIdTag)
+                .filter_by(user_id=user_id, tag_name=tag_name)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
         ]
 
     def count_chat_ids_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
+        self, db: Session, tag_name: str, user_id: str
     ) -> int:
-        return (
-            ChatIdTag.select()
-            .where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
-            .count()
-        )
+        return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
 
-    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
+    def delete_tag_by_tag_name_and_user_id(
+        self, db: Session, tag_name: str, user_id: str
+    ) -> bool:
         try:
-            query = ChatIdTag.delete().where(
-                (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
+            res = (
+                db.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, user_id=user_id)
+                .delete()
             )
-            res = query.execute()  # Remove the rows, return number of rows removed.
             log.debug(f"res: {res}")
 
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                db, tag_name, user_id
+            )
             if tag_count == 0:
                 # Remove tag item from Tag col as well
-                query = Tag.delete().where(
-                    (Tag.name == tag_name) & (Tag.user_id == user_id)
-                )
-                query.execute()  # Remove the rows, return number of rows removed.
-
+                db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
             return True
         except Exception as e:
             log.error(f"delete_tag: {e}")
             return False
 
     def delete_tag_by_tag_name_and_chat_id_and_user_id(
-        self, tag_name: str, chat_id: str, user_id: str
+        self, db: Session, tag_name: str, chat_id: str, user_id: str
     ) -> bool:
         try:
-            query = ChatIdTag.delete().where(
-                (ChatIdTag.tag_name == tag_name)
-                & (ChatIdTag.chat_id == chat_id)
-                & (ChatIdTag.user_id == user_id)
+            res = (
+                db.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
+                .delete()
             )
-            res = query.execute()  # Remove the rows, return number of rows removed.
             log.debug(f"res: {res}")
 
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                db, tag_name, user_id
+            )
             if tag_count == 0:
                 # Remove tag item from Tag col as well
-                query = Tag.delete().where(
-                    (Tag.name == tag_name) & (Tag.user_id == user_id)
-                )
-                query.execute()  # Remove the rows, return number of rows removed.
+                db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
 
             return True
         except Exception as e:
             log.error(f"delete_tag: {e}")
             return False
 
-    def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
-        tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
+    def delete_tags_by_chat_id_and_user_id(
+        self, db: Session, chat_id: str, user_id: str
+    ) -> bool:
+        tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id)
 
         for tag in tags:
             self.delete_tag_by_tag_name_and_chat_id_and_user_id(
-                tag.tag_name, chat_id, user_id
+                db, tag.tag_name, chat_id, user_id
             )
 
         return True
 
 
-Tags = TagTable(DB)
+Tags = TagTable()

+ 37 - 41
backend/apps/webui/models/tools.py

@@ -1,10 +1,11 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
-from typing import List, Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 import time
 import logging
-from apps.webui.internal.db import DB, JSONField
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base, JSONField
 from apps.webui.models.users import Users
 
 import json
@@ -21,19 +22,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 
 
-class Tool(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    name = TextField()
-    content = TextField()
-    specs = JSONField()
-    meta = JSONField()
-    valves = JSONField()
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+class Tool(Base):
+    __tablename__ = "tool"
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    name = Column(String)
+    content = Column(String)
+    specs = Column(JSONField)
+    meta = Column(JSONField)
+    valves = Column(JSONField)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 class ToolMeta(BaseModel):
@@ -51,6 +51,8 @@ class ToolModel(BaseModel):
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 ####################
 # Forms
@@ -78,12 +80,9 @@ class ToolValves(BaseModel):
 
 
 class ToolsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Tool])
 
     def insert_new_tool(
-        self, user_id: str, form_data: ToolForm, specs: List[dict]
+        self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict]
     ) -> Optional[ToolModel]:
         tool = ToolModel(
             **{
@@ -96,24 +95,27 @@ class ToolsTable:
         )
 
         try:
-            result = Tool.create(**tool.model_dump())
+            result = Tool(**tool.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
-                return tool
+                return ToolModel.model_validate(result)
             else:
                 return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
-    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
+    def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]:
         try:
-            tool = Tool.get(Tool.id == id)
-            return ToolModel(**model_to_dict(tool))
+            tool = db.get(Tool, id)
+            return ToolModel.model_validate(tool)
         except:
             return None
 
-    def get_tools(self) -> List[ToolModel]:
-        return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
+    def get_tools(self, db: Session) -> List[ToolModel]:
+        return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
@@ -180,25 +182,19 @@ class ToolsTable:
 
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
-            query = Tool.update(
-                **updated,
-                updated_at=int(time.time()),
-            ).where(Tool.id == id)
-            query.execute()
-
-            tool = Tool.get(Tool.id == id)
-            return ToolModel(**model_to_dict(tool))
+            db.query(Tool).filter_by(id=id).update(
+                {**updated, "updated_at": int(time.time())}
+            )
+            return self.get_tool_by_id(db, id)
         except:
             return None
 
-    def delete_tool_by_id(self, id: str) -> bool:
+    def delete_tool_by_id(self, db: Session, id: str) -> bool:
         try:
-            query = Tool.delete().where((Tool.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Tool).filter_by(id=id).delete()
             return True
         except:
             return False
 
 
-Tools = ToolsTable(DB)
+Tools = ToolsTable()

+ 92 - 83
backend/apps/webui/models/users.py

@@ -1,11 +1,13 @@
-from pydantic import BaseModel, ConfigDict
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict, parse_obj_as
 from typing import List, Union, Optional
 import time
+
+from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy.orm import Session
+
 from utils.misc import get_gravatar_url
 
-from apps.webui.internal.db import DB, JSONField
+from apps.webui.internal.db import Base, JSONField
 from apps.webui.models.chats import Chats
 
 ####################
@@ -13,25 +15,24 @@ from apps.webui.models.chats import Chats
 ####################
 
 
-class User(Model):
-    id = CharField(unique=True)
-    name = CharField()
-    email = CharField()
-    role = CharField()
-    profile_image_url = TextField()
+class User(Base):
+    __tablename__ = "user"
 
-    last_active_at = BigIntegerField()
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+    id = Column(String, primary_key=True)
+    name = Column(String)
+    email = Column(String)
+    role = Column(String)
+    profile_image_url = Column(String)
 
-    api_key = CharField(null=True, unique=True)
-    settings = JSONField(null=True)
-    info = JSONField(null=True)
+    last_active_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
-    oauth_sub = TextField(null=True, unique=True)
+    api_key = Column(String, nullable=True, unique=True)
+    settings = Column(JSONField, nullable=True)
+    info = Column(JSONField, nullable=True)
 
-    class Meta:
-        database = DB
+    oauth_sub = Column(Text, unique=True)
 
 
 class UserSettings(BaseModel):
@@ -41,6 +42,8 @@ class UserSettings(BaseModel):
 
 
 class UserModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     id: str
     name: str
     email: str
@@ -76,12 +79,10 @@ class UserUpdateForm(BaseModel):
 
 
 class UsersTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([User])
 
     def insert_new_user(
         self,
+        db: Session,
         id: str,
         name: str,
         email: str,
@@ -102,30 +103,33 @@ class UsersTable:
                 "oauth_sub": oauth_sub,
             }
         )
-        result = User.create(**user.model_dump())
+        result = User(**user.model_dump())
+        db.add(result)
+        db.commit()
+        db.refresh(result)
         if result:
             return user
         else:
             return None
 
-    def get_user_by_id(self, id: str) -> Optional[UserModel]:
+    def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]:
         try:
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
-        except:
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except Exception as e:
             return None
 
-    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
+    def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]:
         try:
-            user = User.get(User.api_key == api_key)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(api_key=api_key).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
-    def get_user_by_email(self, email: str) -> Optional[UserModel]:
+    def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]:
         try:
-            user = User.get(User.email == email)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(email=email).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
@@ -136,88 +140,94 @@ class UsersTable:
         except:
             return None
 
-    def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
-        return [
-            UserModel(**model_to_dict(user))
-            for user in User.select()
-            # .limit(limit).offset(skip)
-        ]
+    def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]:
+        users = (
+            db.query(User)
+            # .offset(skip).limit(limit)
+            .all()
+        )
+        return [UserModel.model_validate(user) for user in users]
 
-    def get_num_users(self) -> Optional[int]:
-        return User.select().count()
+    def get_num_users(self, db: Session) -> Optional[int]:
+        return db.query(User).count()
 
-    def get_first_user(self) -> UserModel:
+    def get_first_user(self, db: Session) -> UserModel:
         try:
-            user = User.select().order_by(User.created_at).first()
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).order_by(User.created_at).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
-    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
+    def update_user_role_by_id(
+        self, db: Session, id: str, role: str
+    ) -> Optional[UserModel]:
         try:
-            query = User.update(role=role).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update({"role": role})
+            db.commit()
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
     def update_user_profile_image_url_by_id(
-        self, id: str, profile_image_url: str
+        self, db: Session, id: str, profile_image_url: str
     ) -> Optional[UserModel]:
         try:
-            query = User.update(profile_image_url=profile_image_url).where(
-                User.id == id
+            db.query(User).filter_by(id=id).update(
+                {"profile_image_url": profile_image_url}
             )
-            query.execute()
+            db.commit()
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
-    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
+    def update_user_last_active_by_id(
+        self, db: Session, id: str
+    ) -> Optional[UserModel]:
         try:
-            query = User.update(last_active_at=int(time.time())).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
     def update_user_oauth_sub_by_id(
-        self, id: str, oauth_sub: str
+        self, db: Session, id: str, oauth_sub: str
     ) -> Optional[UserModel]:
         try:
-            query = User.update(oauth_sub=oauth_sub).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
             return None
 
-    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
+    def update_user_by_id(
+        self, db: Session, id: str, updated: dict
+    ) -> Optional[UserModel]:
         try:
-            query = User.update(**updated).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update(updated)
+            db.commit()
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
-        except:
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+            # return UserModel(**user.dict())
+        except Exception as e:
             return None
 
-    def delete_user_by_id(self, id: str) -> bool:
+    def delete_user_by_id(self, db: Session, id: str) -> bool:
         try:
             # Delete User Chats
-            result = Chats.delete_chats_by_user_id(id)
+            result = Chats.delete_chats_by_user_id(db, id)
 
             if result:
                 # Delete User
-                query = User.delete().where(User.id == id)
-                query.execute()  # Remove the rows, return number of rows removed.
+                db.query(User).filter_by(id=id).delete()
+                db.commit()
 
                 return True
             else:
@@ -225,21 +235,20 @@ class UsersTable:
         except:
             return False
 
-    def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
+    def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str:
         try:
-            query = User.update(api_key=api_key).where(User.id == id)
-            result = query.execute()
-
+            result = db.query(User).filter_by(id=id).update({"api_key": api_key})
+            db.commit()
             return True if result == 1 else False
         except:
             return False
 
-    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
+    def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]:
         try:
-            user = User.get(User.id == id)
+            user = db.query(User).filter_by(id=id).first()
             return user.api_key
-        except:
+        except Exception as e:
             return None
 
 
-Users = UsersTable(DB)
+Users = UsersTable()

+ 40 - 26
backend/apps/webui/routers/auths.py

@@ -10,6 +10,7 @@ import re
 import uuid
 import csv
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.auths import (
     SigninForm,
     SignupForm,
@@ -78,10 +79,13 @@ async def get_session_user(
 
 @router.post("/update/profile", response_model=UserResponse)
 async def update_profile(
-    form_data: UpdateProfileForm, session_user=Depends(get_current_user)
+    form_data: UpdateProfileForm,
+    session_user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
     if session_user:
         user = Users.update_user_by_id(
+            db,
             session_user.id,
             {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
         )
@@ -100,16 +104,18 @@ async def update_profile(
 
 @router.post("/update/password", response_model=bool)
 async def update_password(
-    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
+    form_data: UpdatePasswordForm,
+    session_user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
     if session_user:
-        user = Auths.authenticate_user(session_user.email, form_data.password)
+        user = Auths.authenticate_user(db, session_user.email, form_data.password)
 
         if user:
             hashed = get_password_hash(form_data.new_password)
-            return Auths.update_user_password_by_id(user.id, hashed)
+            return Auths.update_user_password_by_id(db, user.id, hashed)
         else:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
     else:
@@ -122,7 +128,7 @@ async def update_password(
 
 
 @router.post("/signin", response_model=SigninResponse)
-async def signin(request: Request, response: Response, form_data: SigninForm):
+async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
@@ -133,32 +139,34 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             trusted_name = request.headers.get(
                 WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
             )
-        if not Users.get_user_by_email(trusted_email.lower()):
+        if not Users.get_user_by_email(db, trusted_email.lower()):
             await signup(
                 request,
                 SignupForm(
                     email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
                 ),
+                db,
             )
-        user = Auths.authenticate_user_by_trusted_header(trusted_email)
+        user = Auths.authenticate_user_by_trusted_header(db, trusted_email)
     elif WEBUI_AUTH == False:
         admin_email = "admin@localhost"
         admin_password = "admin"
 
-        if Users.get_user_by_email(admin_email.lower()):
-            user = Auths.authenticate_user(admin_email.lower(), admin_password)
+        if Users.get_user_by_email(db, admin_email.lower()):
+            user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
         else:
-            if Users.get_num_users() != 0:
+            if Users.get_num_users(db) != 0:
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
 
             await signup(
                 request,
                 SignupForm(email=admin_email, password=admin_password, name="User"),
+                db,
             )
 
-            user = Auths.authenticate_user(admin_email.lower(), admin_password)
+            user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
     else:
-        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
+        user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password)
 
     if user:
         token = create_token(
@@ -192,7 +200,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
 
 
 @router.post("/signup", response_model=SigninResponse)
-async def signup(request: Request, response: Response, form_data: SignupForm):
+async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)):
     if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
         raise HTTPException(
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
@@ -203,17 +211,18 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
         )
 
-    if Users.get_user_by_email(form_data.email.lower()):
+    if Users.get_user_by_email(db, form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
     try:
         role = (
             "admin"
-            if Users.get_num_users() == 0
+            if Users.get_num_users(db) == 0
             else request.app.state.config.DEFAULT_USER_ROLE
         )
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
+            db,
             form_data.email.lower(),
             hashed,
             form_data.name,
@@ -267,14 +276,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
 @router.post("/add", response_model=SigninResponse)
-async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
+async def add_user(
+    form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
 
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
         )
 
-    if Users.get_user_by_email(form_data.email.lower()):
+    if Users.get_user_by_email(db, form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
     try:
@@ -282,6 +293,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         print(form_data)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
+            db,
             form_data.email.lower(),
             hashed,
             form_data.name,
@@ -312,7 +324,9 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
 
 
 @router.get("/admin/details")
-async def get_admin_details(request: Request, user=Depends(get_current_user)):
+async def get_admin_details(
+    request: Request, user=Depends(get_current_user), db=Depends(get_db)
+):
     if request.app.state.config.SHOW_ADMIN_DETAILS:
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_name = None
@@ -320,11 +334,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
         print(admin_email, admin_name)
 
         if admin_email:
-            admin = Users.get_user_by_email(admin_email)
+            admin = Users.get_user_by_email(db, admin_email)
             if admin:
                 admin_name = admin.name
         else:
-            admin = Users.get_first_user()
+            admin = Users.get_first_user(db)
             if admin:
                 admin_email = admin.email
                 admin_name = admin.name
@@ -397,9 +411,9 @@ async def update_admin_config(
 
 # create api key
 @router.post("/api_key", response_model=ApiKey)
-async def create_api_key_(user=Depends(get_current_user)):
+async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
     api_key = create_api_key()
-    success = Users.update_user_api_key_by_id(user.id, api_key)
+    success = Users.update_user_api_key_by_id(db, user.id, api_key)
     if success:
         return {
             "api_key": api_key,
@@ -410,15 +424,15 @@ async def create_api_key_(user=Depends(get_current_user)):
 
 # delete api key
 @router.delete("/api_key", response_model=bool)
-async def delete_api_key(user=Depends(get_current_user)):
-    success = Users.update_user_api_key_by_id(user.id, None)
+async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)):
+    success = Users.update_user_api_key_by_id(db, user.id, None)
     return success
 
 
 # get api key
 @router.get("/api_key", response_model=ApiKey)
-async def get_api_key(user=Depends(get_current_user)):
-    api_key = Users.get_user_api_key_by_id(user.id)
+async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)):
+    api_key = Users.get_user_api_key_by_id(db, user.id)
     if api_key:
         return {
             "api_key": api_key,

+ 88 - 58
backend/apps/webui/routers/chats.py

@@ -1,6 +1,8 @@
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
+
+from apps.webui.internal.db import get_db
 from utils.utils import get_current_user, get_admin_user
 from fastapi import APIRouter
 from pydantic import BaseModel
@@ -43,9 +45,9 @@ router = APIRouter()
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/list", response_model=List[ChatTitleIdResponse])
 async def get_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
 ):
-    return Chats.get_chat_list_by_user_id(user.id, skip, limit)
+    return Chats.get_chat_list_by_user_id(db, user.id, skip, limit)
 
 
 ############################
@@ -54,7 +56,9 @@ async def get_session_user_chat_list(
 
 
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
+async def delete_all_user_chats(
+    request: Request, user=Depends(get_current_user), db=Depends(get_db)
+):
 
     if (
         user.role == "user"
@@ -65,7 +69,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
 
-    result = Chats.delete_chats_by_user_id(user.id)
+    result = Chats.delete_chats_by_user_id(db, user.id)
     return result
 
 
@@ -76,10 +80,14 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
 
 @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
 async def get_user_chat_list_by_user_id(
-    user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
+    user_id: str,
+    user=Depends(get_admin_user),
+    skip: int = 0,
+    limit: int = 50,
+    db=Depends(get_db),
 ):
     return Chats.get_chat_list_by_user_id(
-        user_id, include_archived=True, skip=skip, limit=limit
+        db, user_id, include_archived=True, skip=skip, limit=limit
     )
 
 
@@ -89,9 +97,11 @@ async def get_user_chat_list_by_user_id(
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+async def create_new_chat(
+    form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
+):
     try:
-        chat = Chats.insert_new_chat(user.id, form_data)
+        chat = Chats.insert_new_chat(db, user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     except Exception as e:
         log.exception(e)
@@ -106,10 +116,10 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
 
 
 @router.get("/all", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats_by_user_id(user.id)
+        for chat in Chats.get_chats_by_user_id(db, user.id)
     ]
 
 
@@ -119,10 +129,10 @@ async def get_user_chats(user=Depends(get_current_user)):
 
 
 @router.get("/all/archived", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)):
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_archived_chats_by_user_id(user.id)
+        for chat in Chats.get_archived_chats_by_user_id(db, user.id)
     ]
 
 
@@ -132,7 +142,7 @@ async def get_user_chats(user=Depends(get_current_user)):
 
 
 @router.get("/all/db", response_model=List[ChatResponse])
-async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
+async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)):
     if not ENABLE_ADMIN_EXPORT:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -140,7 +150,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
         )
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats()
+        for chat in Chats.get_chats(db)
     ]
 
 
@@ -151,9 +161,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
 
 @router.get("/archived", response_model=List[ChatTitleIdResponse])
 async def get_archived_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
 ):
-    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
+    return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit)
 
 
 ############################
@@ -162,8 +172,8 @@ async def get_archived_session_user_chat_list(
 
 
 @router.post("/archive/all", response_model=bool)
-async def archive_all_chats(user=Depends(get_current_user)):
-    return Chats.archive_all_chats_by_user_id(user.id)
+async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
+    return Chats.archive_all_chats_by_user_id(db, user.id)
 
 
 ############################
@@ -172,16 +182,18 @@ async def archive_all_chats(user=Depends(get_current_user)):
 
 
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
+async def get_shared_chat_by_id(
+    share_id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
     if user.role == "pending":
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
         )
 
     if user.role == "user":
-        chat = Chats.get_chat_by_share_id(share_id)
+        chat = Chats.get_chat_by_share_id(db, share_id)
     elif user.role == "admin":
-        chat = Chats.get_chat_by_id(share_id)
+        chat = Chats.get_chat_by_id(db, share_id)
 
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -204,21 +216,23 @@ class TagNameForm(BaseModel):
 
 @router.post("/tags", response_model=List[ChatTitleIdResponse])
 async def get_user_chat_list_by_tag_name(
-    form_data: TagNameForm, user=Depends(get_current_user)
+    form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db)
 ):
 
     print(form_data)
     chat_ids = [
         chat_id_tag.chat_id
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
-            form_data.name, user.id
+            db, form_data.name, user.id
         )
     ]
 
-    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
+    chats = Chats.get_chat_list_by_chat_ids(
+        db, chat_ids, form_data.skip, form_data.limit
+    )
 
     if len(chats) == 0:
-        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
+        Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id)
 
     return chats
 
@@ -229,9 +243,9 @@ async def get_user_chat_list_by_tag_name(
 
 
 @router.get("/tags/all", response_model=List[TagModel])
-async def get_all_tags(user=Depends(get_current_user)):
+async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
     try:
-        tags = Tags.get_tags_by_user_id(user.id)
+        tags = Tags.get_tags_by_user_id(db, user.id)
         return tags
     except Exception as e:
         log.exception(e)
@@ -246,8 +260,8 @@ async def get_all_tags(user=Depends(get_current_user)):
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
 
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -264,13 +278,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 async def update_chat_by_id(
-    id: str, form_data: ChatForm, user=Depends(get_current_user)
+    id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
 ):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
         updated_chat = {**json.loads(chat.chat), **form_data.chat}
 
-        chat = Chats.update_chat_by_id(id, updated_chat)
+        chat = Chats.update_chat_by_id(db, id, updated_chat)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
@@ -285,10 +299,12 @@ async def update_chat_by_id(
 
 
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
+async def delete_chat_by_id(
+    request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
 
     if user.role == "admin":
-        result = Chats.delete_chat_by_id(id)
+        result = Chats.delete_chat_by_id(db, id)
         return result
     else:
         if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
@@ -297,7 +313,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             )
 
-        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+        result = Chats.delete_chat_by_id_and_user_id(db, id, user.id)
         return result
 
 
@@ -307,8 +323,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
 
 
 @router.get("/{id}/clone", response_model=Optional[ChatResponse])
-async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
 
         chat_body = json.loads(chat.chat)
@@ -319,7 +335,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
             "title": f"Clone of {chat.title}",
         }
 
-        chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
+        chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat}))
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
@@ -333,10 +349,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
-async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def archive_chat_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
-        chat = Chats.toggle_chat_archive_by_id(id)
+        chat = Chats.toggle_chat_archive_by_id(db, id)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
@@ -350,16 +368,16 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.post("/{id}/share", response_model=Optional[ChatResponse])
-async def share_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
         if chat.share_id:
-            shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
+            shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id)
             return ChatResponse(
                 **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
             )
 
-        shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
+        shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id)
         if not shared_chat:
             raise HTTPException(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -382,14 +400,16 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.delete("/{id}/share", response_model=Optional[bool])
-async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def delete_shared_chat_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
         if not chat.share_id:
             return False
 
-        result = Chats.delete_shared_chat_by_chat_id(id)
-        update_result = Chats.update_chat_share_id_by_id(id, None)
+        result = Chats.delete_shared_chat_by_chat_id(db, id)
+        update_result = Chats.update_chat_share_id_by_id(db, id, None)
 
         return result and update_result != None
     else:
@@ -405,8 +425,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.get("/{id}/tags", response_model=List[TagModel])
-async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
+async def get_chat_tags_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
 
     if tags != None:
         return tags
@@ -423,12 +445,15 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
 
 @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 async def add_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+    id: str,
+    form_data: ChatIdTagForm,
+    user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
+    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
 
     if form_data.tag_name not in tags:
-        tag = Tags.add_tag_to_chat(user.id, form_data)
+        tag = Tags.add_tag_to_chat(db, user.id, form_data)
 
         if tag:
             return tag
@@ -450,10 +475,13 @@ async def add_chat_tag_by_id(
 
 @router.delete("/{id}/tags", response_model=Optional[bool])
 async def delete_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+    id: str,
+    form_data: ChatIdTagForm,
+    user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
     result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
-        form_data.tag_name, id, user.id
+        db, form_data.tag_name, id, user.id
     )
 
     if result:
@@ -470,8 +498,10 @@ async def delete_chat_tag_by_id(
 
 
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
-async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
-    result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
+async def delete_all_chat_tags_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id)
 
     if result:
         return result

+ 27 - 13
backend/apps/webui/routers/documents.py

@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import json
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.documents import (
     Documents,
     DocumentForm,
@@ -25,7 +26,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[DocumentResponse])
-async def get_documents(user=Depends(get_current_user)):
+async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
     docs = [
         DocumentResponse(
             **{
@@ -33,7 +34,7 @@ async def get_documents(user=Depends(get_current_user)):
                 "content": json.loads(doc.content if doc.content else "{}"),
             }
         )
-        for doc in Documents.get_docs()
+        for doc in Documents.get_docs(db)
     ]
     return docs
 
@@ -44,10 +45,12 @@ async def get_documents(user=Depends(get_current_user)):
 
 
 @router.post("/create", response_model=Optional[DocumentResponse])
-async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
-    doc = Documents.get_doc_by_name(form_data.name)
+async def create_new_doc(
+    form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    doc = Documents.get_doc_by_name(db, form_data.name)
     if doc == None:
-        doc = Documents.insert_new_doc(user.id, form_data)
+        doc = Documents.insert_new_doc(db, user.id, form_data)
 
         if doc:
             return DocumentResponse(
@@ -74,8 +77,10 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
 
 
 @router.get("/doc", response_model=Optional[DocumentResponse])
-async def get_doc_by_name(name: str, user=Depends(get_current_user)):
-    doc = Documents.get_doc_by_name(name)
+async def get_doc_by_name(
+    name: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    doc = Documents.get_doc_by_name(db, name)
 
     if doc:
         return DocumentResponse(
@@ -106,8 +111,12 @@ class TagDocumentForm(BaseModel):
 
 
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
-async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
-    doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
+async def tag_doc_by_name(
+    form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db)
+):
+    doc = Documents.update_doc_content_by_name(
+        db, form_data.name, {"tags": form_data.tags}
+    )
 
     if doc:
         return DocumentResponse(
@@ -130,9 +139,12 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
 
 @router.post("/doc/update", response_model=Optional[DocumentResponse])
 async def update_doc_by_name(
-    name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
+    name: str,
+    form_data: DocumentUpdateForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    doc = Documents.update_doc_by_name(name, form_data)
+    doc = Documents.update_doc_by_name(db, name, form_data)
     if doc:
         return DocumentResponse(
             **{
@@ -153,6 +165,8 @@ async def update_doc_by_name(
 
 
 @router.delete("/doc/delete", response_model=bool)
-async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
-    result = Documents.delete_doc_by_name(name)
+async def delete_doc_by_name(
+    name: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    result = Documents.delete_doc_by_name(db, name)
     return result

+ 14 - 11
backend/apps/webui/routers/files.py

@@ -20,6 +20,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 from pydantic import BaseModel
 import json
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.files import (
     Files,
     FileForm,
@@ -53,6 +54,7 @@ router = APIRouter()
 def upload_file(
     file: UploadFile = File(...),
     user=Depends(get_verified_user),
+    db=Depends(get_db)
 ):
     log.info(f"file.content_type: {file.content_type}")
     try:
@@ -70,6 +72,7 @@ def upload_file(
             f.close()
 
         file = Files.insert_new_file(
+            db,
             user.id,
             FileForm(
                 **{
@@ -106,8 +109,8 @@ def upload_file(
 
 
 @router.get("/", response_model=List[FileModel])
-async def list_files(user=Depends(get_verified_user)):
-    files = Files.get_files()
+async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
+    files = Files.get_files(db)
     return files
 
 
@@ -117,8 +120,8 @@ async def list_files(user=Depends(get_verified_user)):
 
 
 @router.delete("/all")
-async def delete_all_files(user=Depends(get_admin_user)):
-    result = Files.delete_all_files()
+async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
+    result = Files.delete_all_files(db)
 
     if result:
         folder = f"{UPLOAD_DIR}"
@@ -154,8 +157,8 @@ async def delete_all_files(user=Depends(get_admin_user)):
 
 
 @router.get("/{id}", response_model=Optional[FileModel])
-async def get_file_by_id(id: str, user=Depends(get_verified_user)):
-    file = Files.get_file_by_id(id)
+async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    file = Files.get_file_by_id(db, id)
 
     if file:
         return file
@@ -172,8 +175,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
 
 
 @router.get("/{id}/content", response_model=Optional[FileModel])
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
-    file = Files.get_file_by_id(id)
+async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    file = Files.get_file_by_id(db, id)
 
     if file:
         file_path = Path(file.meta["path"])
@@ -223,11 +226,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
 @router.delete("/{id}")
-async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
-    file = Files.get_file_by_id(id)
+async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    file = Files.get_file_by_id(db, id)
 
     if file:
-        result = Files.delete_file_by_id(id)
+        result = Files.delete_file_by_id(db, id)
         if result:
             return {"message": "File deleted successfully"}
         else:

+ 14 - 13
backend/apps/webui/routers/functions.py

@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import json
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.functions import (
     Functions,
     FunctionForm,
@@ -31,8 +32,8 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[FunctionResponse])
-async def get_functions(user=Depends(get_verified_user)):
-    return Functions.get_functions()
+async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
+    return Functions.get_functions(db)
 
 
 ############################
@@ -41,8 +42,8 @@ async def get_functions(user=Depends(get_verified_user)):
 
 
 @router.get("/export", response_model=List[FunctionModel])
-async def get_functions(user=Depends(get_admin_user)):
-    return Functions.get_functions()
+async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
+    return Functions.get_functions(db)
 
 
 ############################
@@ -52,7 +53,7 @@ async def get_functions(user=Depends(get_admin_user)):
 
 @router.post("/create", response_model=Optional[FunctionResponse])
 async def create_new_function(
-    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
+    request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
 ):
     if not form_data.id.isidentifier():
         raise HTTPException(
@@ -62,7 +63,7 @@ async def create_new_function(
 
     form_data.id = form_data.id.lower()
 
-    function = Functions.get_function_by_id(form_data.id)
+    function = Functions.get_function_by_id(db, form_data.id)
     if function == None:
         function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
         try:
@@ -77,7 +78,7 @@ async def create_new_function(
             FUNCTIONS = request.app.state.FUNCTIONS
             FUNCTIONS[form_data.id] = function_module
 
-            function = Functions.insert_new_function(user.id, function_type, form_data)
+            function = Functions.insert_new_function(db, user.id, function_type, form_data)
 
             function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
             function_cache_dir.mkdir(parents=True, exist_ok=True)
@@ -108,8 +109,8 @@ async def create_new_function(
 
 
 @router.get("/id/{id}", response_model=Optional[FunctionModel])
-async def get_function_by_id(id: str, user=Depends(get_admin_user)):
-    function = Functions.get_function_by_id(id)
+async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
+    function = Functions.get_function_by_id(db, id)
 
     if function:
         return function
@@ -154,7 +155,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
 
 @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
 async def update_function_by_id(
-    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
+    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
 ):
     function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
 
@@ -171,7 +172,7 @@ async def update_function_by_id(
         updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
         print(updated)
 
-        function = Functions.update_function_by_id(id, updated)
+        function = Functions.update_function_by_id(db, id, updated)
 
         if function:
             return function
@@ -195,9 +196,9 @@ async def update_function_by_id(
 
 @router.delete("/id/{id}/delete", response_model=bool)
 async def delete_function_by_id(
-    request: Request, id: str, user=Depends(get_admin_user)
+    request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
 ):
-    result = Functions.delete_function_by_id(id)
+    result = Functions.delete_function_by_id(db, id)
 
     if result:
         FUNCTIONS = request.app.state.FUNCTIONS

+ 18 - 11
backend/apps/webui/routers/memories.py

@@ -7,6 +7,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import logging
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.memories import Memories, MemoryModel
 
 from utils.utils import get_verified_user
@@ -31,8 +32,8 @@ async def get_embeddings(request: Request):
 
 
 @router.get("/", response_model=List[MemoryModel])
-async def get_memories(user=Depends(get_verified_user)):
-    return Memories.get_memories_by_user_id(user.id)
+async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)):
+    return Memories.get_memories_by_user_id(db, user.id)
 
 
 ############################
@@ -50,9 +51,12 @@ class MemoryUpdateModel(BaseModel):
 
 @router.post("/add", response_model=Optional[MemoryModel])
 async def add_memory(
-    request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
+    request: Request,
+    form_data: AddMemoryForm,
+    user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
-    memory = Memories.insert_new_memory(user.id, form_data.content)
+    memory = Memories.insert_new_memory(db, user.id, form_data.content)
     memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
 
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
@@ -72,8 +76,9 @@ async def update_memory_by_id(
     request: Request,
     form_data: MemoryUpdateModel,
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
-    memory = Memories.update_memory_by_id(memory_id, form_data.content)
+    memory = Memories.update_memory_by_id(db, memory_id, form_data.content)
     if memory is None:
         raise HTTPException(status_code=404, detail="Memory not found")
 
@@ -124,12 +129,12 @@ async def query_memory(
 ############################
 @router.get("/reset", response_model=bool)
 async def reset_memory_from_vector_db(
-    request: Request, user=Depends(get_verified_user)
+    request: Request, user=Depends(get_verified_user), db=Depends(get_db)
 ):
     CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
 
-    memories = Memories.get_memories_by_user_id(user.id)
+    memories = Memories.get_memories_by_user_id(db, user.id)
     for memory in memories:
         memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
         collection.upsert(
@@ -146,8 +151,8 @@ async def reset_memory_from_vector_db(
 
 
 @router.delete("/user", response_model=bool)
-async def delete_memory_by_user_id(user=Depends(get_verified_user)):
-    result = Memories.delete_memories_by_user_id(user.id)
+async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)):
+    result = Memories.delete_memories_by_user_id(db, user.id)
 
     if result:
         try:
@@ -165,8 +170,10 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
 
 
 @router.delete("/{memory_id}", response_model=bool)
-async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
-    result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
+async def delete_memory_by_id(
+    memory_id: str, user=Depends(get_verified_user), db=Depends(get_db)
+):
+    result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id)
 
     if result:
         collection = CHROMA_CLIENT.get_or_create_collection(

+ 22 - 13
backend/apps/webui/routers/models.py

@@ -5,6 +5,8 @@ from typing import List, Union, Optional
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
+
+from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
 
 from utils.utils import get_verified_user, get_admin_user
@@ -18,8 +20,8 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ModelResponse])
-async def get_models(user=Depends(get_verified_user)):
-    return Models.get_all_models()
+async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
+    return Models.get_all_models(db)
 
 
 ############################
@@ -29,7 +31,10 @@ async def get_models(user=Depends(get_verified_user)):
 
 @router.post("/add", response_model=Optional[ModelModel])
 async def add_new_model(
-    request: Request, form_data: ModelForm, user=Depends(get_admin_user)
+    request: Request,
+    form_data: ModelForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
     if form_data.id in request.app.state.MODELS:
         raise HTTPException(
@@ -37,7 +42,7 @@ async def add_new_model(
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
         )
     else:
-        model = Models.insert_new_model(form_data, user.id)
+        model = Models.insert_new_model(db, form_data, user.id)
 
         if model:
             return model
@@ -53,9 +58,9 @@ async def add_new_model(
 ############################
 
 
-@router.get("/", response_model=Optional[ModelModel])
-async def get_model_by_id(id: str, user=Depends(get_verified_user)):
-    model = Models.get_model_by_id(id)
+@router.get("/{id}", response_model=Optional[ModelModel])
+async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    model = Models.get_model_by_id(db, id)
 
     if model:
         return model
@@ -73,15 +78,19 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
 
 @router.post("/update", response_model=Optional[ModelModel])
 async def update_model_by_id(
-    request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
+    request: Request,
+    id: str,
+    form_data: ModelForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    model = Models.get_model_by_id(id)
+    model = Models.get_model_by_id(db, id)
     if model:
-        model = Models.update_model_by_id(id, form_data)
+        model = Models.update_model_by_id(db, id, form_data)
         return model
     else:
         if form_data.id in request.app.state.MODELS:
-            model = Models.insert_new_model(form_data, user.id)
+            model = Models.insert_new_model(db, form_data, user.id)
             if model:
                 return model
             else:
@@ -102,6 +111,6 @@ async def update_model_by_id(
 
 
 @router.delete("/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
-    result = Models.delete_model_by_id(id)
+async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
+    result = Models.delete_model_by_id(db, id)
     return result

+ 21 - 11
backend/apps/webui/routers/prompts.py

@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import json
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
 
 from utils.utils import get_current_user, get_admin_user
@@ -19,8 +20,8 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[PromptModel])
-async def get_prompts(user=Depends(get_current_user)):
-    return Prompts.get_prompts()
+async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
+    return Prompts.get_prompts(db)
 
 
 ############################
@@ -29,10 +30,12 @@ async def get_prompts(user=Depends(get_current_user)):
 
 
 @router.post("/create", response_model=Optional[PromptModel])
-async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
-    prompt = Prompts.get_prompt_by_command(form_data.command)
+async def create_new_prompt(
+    form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    prompt = Prompts.get_prompt_by_command(db, form_data.command)
     if prompt == None:
-        prompt = Prompts.insert_new_prompt(user.id, form_data)
+        prompt = Prompts.insert_new_prompt(db, user.id, form_data)
 
         if prompt:
             return prompt
@@ -52,8 +55,10 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
 
 
 @router.get("/command/{command}", response_model=Optional[PromptModel])
-async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
-    prompt = Prompts.get_prompt_by_command(f"/{command}")
+async def get_prompt_by_command(
+    command: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    prompt = Prompts.get_prompt_by_command(db, f"/{command}")
 
     if prompt:
         return prompt
@@ -71,9 +76,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
 
 @router.post("/command/{command}/update", response_model=Optional[PromptModel])
 async def update_prompt_by_command(
-    command: str, form_data: PromptForm, user=Depends(get_admin_user)
+    command: str,
+    form_data: PromptForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
+    prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data)
     if prompt:
         return prompt
     else:
@@ -89,6 +97,8 @@ async def update_prompt_by_command(
 
 
 @router.delete("/command/{command}/delete", response_model=bool)
-async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
-    result = Prompts.delete_prompt_by_command(f"/{command}")
+async def delete_prompt_by_command(
+    command: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    result = Prompts.delete_prompt_by_command(db, f"/{command}")
     return result

+ 22 - 13
backend/apps/webui/routers/tools.py

@@ -6,7 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 import json
 
-
+from apps.webui.internal.db import get_db
 from apps.webui.models.users import Users
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.utils import load_toolkit_module_by_id
@@ -34,7 +34,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ToolResponse])
-async def get_toolkits(user=Depends(get_verified_user)):
+async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
     toolkits = [toolkit for toolkit in Tools.get_tools()]
     return toolkits
 
@@ -45,8 +45,8 @@ async def get_toolkits(user=Depends(get_verified_user)):
 
 
 @router.get("/export", response_model=List[ToolModel])
-async def get_toolkits(user=Depends(get_admin_user)):
-    toolkits = [toolkit for toolkit in Tools.get_tools()]
+async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)):
+    toolkits = [toolkit for toolkit in Tools.get_tools(db)]
     return toolkits
 
 
@@ -57,7 +57,10 @@ async def get_toolkits(user=Depends(get_admin_user)):
 
 @router.post("/create", response_model=Optional[ToolResponse])
 async def create_new_toolkit(
-    request: Request, form_data: ToolForm, user=Depends(get_admin_user)
+    request: Request,
+    form_data: ToolForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
     if not form_data.id.isidentifier():
         raise HTTPException(
@@ -67,7 +70,7 @@ async def create_new_toolkit(
 
     form_data.id = form_data.id.lower()
 
-    toolkit = Tools.get_tool_by_id(form_data.id)
+    toolkit = Tools.get_tool_by_id(db, form_data.id)
     if toolkit == None:
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         try:
@@ -81,7 +84,7 @@ async def create_new_toolkit(
             TOOLS[form_data.id] = toolkit_module
 
             specs = get_tools_specs(TOOLS[form_data.id])
-            toolkit = Tools.insert_new_tool(user.id, form_data, specs)
+            toolkit = Tools.insert_new_tool(db, user.id, form_data, specs)
 
             tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
             tool_cache_dir.mkdir(parents=True, exist_ok=True)
@@ -112,8 +115,8 @@ async def create_new_toolkit(
 
 
 @router.get("/id/{id}", response_model=Optional[ToolModel])
-async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
-    toolkit = Tools.get_tool_by_id(id)
+async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
+    toolkit = Tools.get_tool_by_id(db, id)
 
     if toolkit:
         return toolkit
@@ -131,7 +134,11 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
 
 @router.post("/id/{id}/update", response_model=Optional[ToolModel])
 async def update_toolkit_by_id(
-    request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
+    request: Request,
+    id: str,
+    form_data: ToolForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
     toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
 
@@ -153,7 +160,7 @@ async def update_toolkit_by_id(
         }
 
         print(updated)
-        toolkit = Tools.update_tool_by_id(id, updated)
+        toolkit = Tools.update_tool_by_id(db, id, updated)
 
         if toolkit:
             return toolkit
@@ -176,8 +183,10 @@ async def update_toolkit_by_id(
 
 
 @router.delete("/id/{id}/delete", response_model=bool)
-async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
-    result = Tools.delete_tool_by_id(id)
+async def delete_toolkit_by_id(
+    request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    result = Tools.delete_tool_by_id(db, id)
 
     if result:
         TOOLS = request.app.state.TOOLS

+ 44 - 25
backend/apps/webui/routers/users.py

@@ -9,6 +9,7 @@ import time
 import uuid
 import logging
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.users import (
     UserModel,
     UserUpdateForm,
@@ -40,8 +41,10 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
-    return Users.get_users(skip, limit)
+async def get_users(
+    skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    return Users.get_users(db, skip, limit)
 
 
 ############################
@@ -68,10 +71,12 @@ async def update_user_permissions(
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
 
-    if user.id != form_data.id and form_data.id != Users.get_first_user().id:
-        return Users.update_user_role_by_id(form_data.id, form_data.role)
+    if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
+        return Users.update_user_role_by_id(db, form_data.id, form_data.role)
 
     raise HTTPException(
         status_code=status.HTTP_403_FORBIDDEN,
@@ -85,8 +90,10 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
 
 
 @router.get("/user/settings", response_model=Optional[UserSettings])
-async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
-    user = Users.get_user_by_id(user.id)
+async def get_user_settings_by_session_user(
+    user=Depends(get_verified_user), db=Depends(get_db)
+):
+    user = Users.get_user_by_id(db, user.id)
     if user:
         return user.settings
     else:
@@ -103,9 +110,9 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
 
 @router.post("/user/settings/update", response_model=UserSettings)
 async def update_user_settings_by_session_user(
-    form_data: UserSettings, user=Depends(get_verified_user)
+    form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
 ):
-    user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
+    user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
     if user:
         return user.settings
     else:
@@ -121,8 +128,10 @@ async def update_user_settings_by_session_user(
 
 
 @router.get("/user/info", response_model=Optional[dict])
-async def get_user_info_by_session_user(user=Depends(get_verified_user)):
-    user = Users.get_user_by_id(user.id)
+async def get_user_info_by_session_user(
+    user=Depends(get_verified_user), db=Depends(get_db)
+):
+    user = Users.get_user_by_id(db, user.id)
     if user:
         return user.info
     else:
@@ -138,15 +147,17 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
 
 
 @router.post("/user/info/update", response_model=Optional[dict])
-async def update_user_settings_by_session_user(
-    form_data: dict, user=Depends(get_verified_user)
+async def update_user_info_by_session_user(
+    form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
 ):
-    user = Users.get_user_by_id(user.id)
+    user = Users.get_user_by_id(db, user.id)
     if user:
         if user.info is None:
             user.info = {}
 
-        user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
+        user = Users.update_user_by_id(
+            db, user.id, {"info": {**user.info, **form_data}}
+        )
         if user:
             return user.info
         else:
@@ -172,13 +183,15 @@ class UserResponse(BaseModel):
 
 
 @router.get("/{user_id}", response_model=UserResponse)
-async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
+async def get_user_by_id(
+    user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
+):
 
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
     if user_id.startswith("shared-"):
         chat_id = user_id.replace("shared-", "")
-        chat = Chats.get_chat_by_id(chat_id)
+        chat = Chats.get_chat_by_id(db, chat_id)
         if chat:
             user_id = chat.user_id
         else:
@@ -187,7 +200,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
                 detail=ERROR_MESSAGES.USER_NOT_FOUND,
             )
 
-    user = Users.get_user_by_id(user_id)
+    user = Users.get_user_by_id(db, user_id)
 
     if user:
         return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
@@ -205,13 +218,16 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
 
 @router.post("/{user_id}/update", response_model=Optional[UserModel])
 async def update_user_by_id(
-    user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
+    user_id: str,
+    form_data: UserUpdateForm,
+    session_user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    user = Users.get_user_by_id(user_id)
+    user = Users.get_user_by_id(db, user_id)
 
     if user:
         if form_data.email.lower() != user.email:
-            email_user = Users.get_user_by_email(form_data.email.lower())
+            email_user = Users.get_user_by_email(db, form_data.email.lower())
             if email_user:
                 raise HTTPException(
                     status_code=status.HTTP_400_BAD_REQUEST,
@@ -221,10 +237,11 @@ async def update_user_by_id(
         if form_data.password:
             hashed = get_password_hash(form_data.password)
             log.debug(f"hashed: {hashed}")
-            Auths.update_user_password_by_id(user_id, hashed)
+            Auths.update_user_password_by_id(db, user_id, hashed)
 
-        Auths.update_email_by_id(user_id, form_data.email.lower())
+        Auths.update_email_by_id(db, user_id, form_data.email.lower())
         updated_user = Users.update_user_by_id(
+            db,
             user_id,
             {
                 "name": form_data.name,
@@ -253,9 +270,11 @@ async def update_user_by_id(
 
 
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+async def delete_user_by_id(
+    user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
     if user.id != user_id:
-        result = Auths.delete_auth_by_id(user_id)
+        result = Auths.delete_auth_by_id(db, user_id)
 
         if result:
             return True

+ 4 - 4
backend/apps/webui/routers/utils.py

@@ -1,6 +1,5 @@
 from fastapi import APIRouter, UploadFile, File, Response
 from fastapi import Depends, HTTPException, status
-from peewee import SqliteDatabase
 from starlette.responses import StreamingResponse, FileResponse
 from pydantic import BaseModel
 
@@ -10,7 +9,6 @@ import markdown
 import black
 
 
-from apps.webui.internal.db import DB
 from utils.utils import get_admin_user
 from utils.misc import calculate_sha256, get_gravatar_url
 
@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
-    if not isinstance(DB, SqliteDatabase):
+    from apps.webui.internal.db import engine
+
+    if engine.name != "sqlite":
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DB_NOT_SQLITE,
         )
     return FileResponse(
-        DB.database,
+        engine.url.database,
         media_type="application/octet-stream",
         filename="webui.db",
     )

+ 45 - 11
backend/main.py

@@ -1,5 +1,6 @@
 import base64
 import uuid
+import subprocess
 from contextlib import asynccontextmanager
 
 from authlib.integrations.starlette_client import OAuth
@@ -27,6 +28,8 @@ from fastapi.responses import JSONResponse
 from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
+from sqlalchemy import text
+from sqlalchemy.orm import Session
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
@@ -54,6 +57,7 @@ from apps.webui.main import (
     get_pipe_models,
     generate_function_chat_completion,
 )
+from apps.webui.internal.db import get_db, SessionLocal
 
 
 from pydantic import BaseModel
@@ -124,6 +128,8 @@ from config import (
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
     AppConfig,
+    BACKEND_DIR,
+    DATABASE_URL,
 )
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from utils.webhook import post_webhook
@@ -166,8 +172,19 @@ https://github.com/open-webui/open-webui
 )
 
 
+def run_migrations():
+    from alembic.config import Config
+    from alembic import command
+
+    alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini")
+    alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL)
+    alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations")
+    command.upgrade(alembic_cfg, "head")
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    run_migrations()
     yield
 
 
@@ -393,6 +410,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             user = get_current_user(
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
+                SessionLocal(),
             )
             # Flag to skip RAG completions if file_handler is present in tools/functions
             skip_files = False
@@ -736,6 +754,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
             user = get_current_user(
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
+                SessionLocal(),
             )
 
             try:
@@ -781,7 +800,9 @@ app.add_middleware(
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
-        await get_all_models()
+        db = SessionLocal()
+        await get_all_models(db)
+        db.commit()
     else:
         pass
 
@@ -815,12 +836,12 @@ app.mount("/api/v1", webui_app)
 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
-async def get_all_models():
+async def get_all_models(db: Session):
     pipe_models = []
     openai_models = []
     ollama_models = []
 
-    pipe_models = await get_pipe_models()
+    pipe_models = await get_pipe_models(db)
 
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
@@ -842,7 +863,7 @@ async def get_all_models():
 
     models = pipe_models + openai_models + ollama_models
 
-    custom_models = Models.get_all_models()
+    custom_models = Models.get_all_models(db)
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
             for model in models:
@@ -882,8 +903,8 @@ async def get_all_models():
 
 
 @app.get("/api/models")
-async def get_models(user=Depends(get_verified_user)):
-    models = await get_all_models()
+async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
+    models = await get_all_models(db)
 
     # Filter out filter pipelines
     models = [
@@ -1584,9 +1605,12 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
 
 @app.get("/api/pipelines/{pipeline_id}/valves")
 async def get_pipeline_valves(
-    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    models = await get_all_models()
+    models = await get_all_models(db)
     r = None
     try:
 
@@ -1622,9 +1646,12 @@ async def get_pipeline_valves(
 
 @app.get("/api/pipelines/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
-    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    models = await get_all_models()
+    models = await get_all_models(db)
 
     r = None
     try:
@@ -1663,8 +1690,9 @@ async def update_pipeline_valves(
     pipeline_id: str,
     form_data: dict,
     user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
-    models = await get_all_models()
+    models = await get_all_models(db)
 
     r = None
     try:
@@ -2011,6 +2039,12 @@ async def healthcheck():
     return {"status": True}
 
 
+@app.get("/health/db")
+async def healthcheck_with_db(db: Session = Depends(get_db)):
+    result = db.execute(text("SELECT 1;")).all()
+    return {"status": True}
+
+
 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
 

+ 4 - 0
backend/migrations/README

@@ -0,0 +1,4 @@
+Generic single-database configuration.
+
+Create new migrations with
+DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"

+ 93 - 0
backend/migrations/env.py

@@ -0,0 +1,93 @@
+import os
+from logging.config import fileConfig
+
+from sqlalchemy import engine_from_config
+from sqlalchemy import pool
+
+from alembic import context
+
+from apps.webui.models.auths import Auth
+from apps.webui.models.chats import Chat
+from apps.webui.models.documents import Document
+from apps.webui.models.memories import Memory
+from apps.webui.models.models import Model
+from apps.webui.models.prompts import Prompt
+from apps.webui.models.tags import Tag, ChatIdTag
+from apps.webui.models.tools import Tool
+from apps.webui.models.users import User
+from apps.webui.models.files import File
+from apps.webui.models.functions import Function
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+
+# add your model's MetaData object here
+# for 'autogenerate' support
+# from myapp import mymodel
+# target_metadata = mymodel.Base.metadata
+target_metadata = Auth.metadata
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+database_url = os.getenv("DATABASE_URL", None)
+if database_url:
+    config.set_main_option("sqlalchemy.url", database_url)
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well.  By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def run_migrations_online() -> None:
+    """Run migrations in 'online' mode.
+
+    In this scenario we need to create an Engine
+    and associate a connection with the context.
+
+    """
+    connectable = engine_from_config(
+        config.get_section(config.config_ini_section, {}),
+        prefix="sqlalchemy.",
+        poolclass=pool.NullPool,
+    )
+
+    with connectable.connect() as connection:
+        context.configure(connection=connection, target_metadata=target_metadata)
+
+        with context.begin_transaction():
+            context.run_migrations()
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()

+ 27 - 0
backend/migrations/script.py.mako

@@ -0,0 +1,27 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import apps.webui.internal.db
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision: str = ${repr(up_revision)}
+down_revision: Union[str, None] = ${repr(down_revision)}
+branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
+depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    ${downgrades if downgrades else "pass"}

+ 188 - 0
backend/migrations/versions/22b5ab2667b8_init.py

@@ -0,0 +1,188 @@
+"""init
+
+Revision ID: 22b5ab2667b8
+Revises: 
+Create Date: 2024-06-20 13:22:40.397002
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.engine.reflection import Inspector
+
+import apps.webui.internal.db
+
+
+# revision identifiers, used by Alembic.
+revision: str = "22b5ab2667b8"
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    con = op.get_bind()
+    inspector = Inspector.from_engine(con)
+    tables = set(inspector.get_table_names())
+
+    # ### commands auto generated by Alembic - please adjust! ###
+    if not "auth" in tables:
+        op.create_table(
+            "auth",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("email", sa.String(), nullable=True),
+            sa.Column("password", sa.String(), nullable=True),
+            sa.Column("active", sa.Boolean(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "chat" in tables:
+        op.create_table(
+            "chat",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("title", sa.String(), nullable=True),
+            sa.Column("chat", sa.String(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("share_id", sa.String(), nullable=True),
+            sa.Column("archived", sa.Boolean(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+            sa.UniqueConstraint("share_id"),
+        )
+
+    if not "chatidtag" in tables:
+        op.create_table(
+            "chatidtag",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("tag_name", sa.String(), nullable=True),
+            sa.Column("chat_id", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "document" in tables:
+        op.create_table(
+            "document",
+            sa.Column("collection_name", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("title", sa.String(), nullable=True),
+            sa.Column("filename", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("collection_name"),
+            sa.UniqueConstraint("name"),
+        )
+
+    if not "memory" in tables:
+        op.create_table(
+            "memory",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "model" in tables:
+        op.create_table(
+            "model",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("base_model_id", sa.String(), nullable=True),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "prompt" in tables:
+        op.create_table(
+            "prompt",
+            sa.Column("command", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("title", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("command"),
+        )
+
+    if not "tag" in tables:
+        op.create_table(
+            "tag",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("data", sa.String(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "tool" in tables:
+        op.create_table(
+            "tool",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "user" in tables:
+        op.create_table(
+            "user",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("email", sa.String(), nullable=True),
+            sa.Column("role", sa.String(), nullable=True),
+            sa.Column("profile_image_url", sa.String(), nullable=True),
+            sa.Column("last_active_at", sa.BigInteger(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.Column("api_key", sa.String(), nullable=True),
+            sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+            sa.UniqueConstraint("api_key"),
+        )
+
+    if not "file" in tables:
+        op.create_table('file',
+                        sa.Column('id', sa.String(), nullable=False),
+                        sa.Column('user_id', sa.String(), nullable=True),
+                        sa.Column('filename', sa.String(), nullable=True),
+                        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+                        sa.Column('created_at', sa.BigInteger(), nullable=True),
+                        sa.PrimaryKeyConstraint('id')
+                        )
+
+    if not "function" in tables:
+        op.create_table('function',
+                        sa.Column('id', sa.String(), nullable=False),
+                        sa.Column('user_id', sa.String(), nullable=True),
+                        sa.Column('name', sa.Text(), nullable=True),
+                        sa.Column('type', sa.Text(), nullable=True),
+                        sa.Column('content', sa.Text(), nullable=True),
+                        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+                        sa.Column('updated_at', sa.BigInteger(), nullable=True),
+                        sa.Column('created_at', sa.BigInteger(), nullable=True),
+                        sa.PrimaryKeyConstraint('id')
+                        )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # do nothing as we assume we had previous migrations from peewee-migrate
+    pass
+    # ### end Alembic commands ###

+ 10 - 3
backend/requirements.txt

@@ -12,8 +12,10 @@ passlib[bcrypt]==1.7.4
 
 requests==2.32.2
 aiohttp==3.9.5
-peewee==3.17.5
-peewee-migrate==1.12.2
+sqlalchemy==2.0.30
+alembic==1.13.1
+# peewee==3.17.5
+# peewee-migrate==1.12.2
 psycopg2-binary==2.9.9
 PyMySQL==1.1.1
 bcrypt==4.1.3
@@ -67,4 +69,9 @@ pytube==15.0.0
 
 extract_msg
 pydub
-duckduckgo-search~=6.1.5
+duckduckgo-search~=6.1.5
+
+## Tests
+docker~=7.1.0
+pytest~=8.2.1
+pytest-docker~=3.1.1

+ 0 - 0
backend/test/__init__.py


+ 209 - 0
backend/test/apps/webui/routers/test_auths.py

@@ -0,0 +1,209 @@
+import pytest
+
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestAuths(AbstractPostgresTest):
+    BASE_PATH = "/api/v1/auths"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.users import Users
+        from apps.webui.models.auths import Auths
+
+        cls.users = Users
+        cls.auths = Auths
+
+    def test_get_session_user(self):
+        with mock_webui_user():
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert response.json() == {
+            "id": "1",
+            "name": "John Doe",
+            "email": "john.doe@openwebui.com",
+            "role": "user",
+            "profile_image_url": "/user.png",
+        }
+
+    def test_update_profile(self):
+        from utils.utils import get_password_hash
+
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password=get_password_hash("old_password"),
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="user",
+        )
+
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.post(
+                self.create_url("/update/profile"),
+                json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
+            )
+        assert response.status_code == 200
+        db_user = self.users.get_user_by_id(self.db_session, user.id)
+        assert db_user.name == "John Doe 2"
+        assert db_user.profile_image_url == "/user2.png"
+
+    def test_update_password(self):
+        from utils.utils import get_password_hash
+
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password=get_password_hash("old_password"),
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="user",
+        )
+
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.post(
+                self.create_url("/update/password"),
+                json={"password": "old_password", "new_password": "new_password"},
+            )
+        assert response.status_code == 200
+
+        old_auth = self.auths.authenticate_user(
+            self.db_session, "john.doe@openwebui.com", "old_password"
+        )
+        assert old_auth is None
+        new_auth = self.auths.authenticate_user(
+            self.db_session, "john.doe@openwebui.com", "new_password"
+        )
+        assert new_auth is not None
+
+    def test_signin(self):
+        from utils.utils import get_password_hash
+
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password=get_password_hash("password"),
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="user",
+        )
+        response = self.fast_api_client.post(
+            self.create_url("/signin"),
+            json={"email": "john.doe@openwebui.com", "password": "password"},
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == user.id
+        assert data["name"] == "John Doe"
+        assert data["email"] == "john.doe@openwebui.com"
+        assert data["role"] == "user"
+        assert data["profile_image_url"] == "/user.png"
+        assert data["token"] is not None and len(data["token"]) > 0
+        assert data["token_type"] == "Bearer"
+
+    def test_signup(self):
+        response = self.fast_api_client.post(
+            self.create_url("/signup"),
+            json={
+                "name": "John Doe",
+                "email": "john.doe@openwebui.com",
+                "password": "password",
+            },
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] is not None and len(data["id"]) > 0
+        assert data["name"] == "John Doe"
+        assert data["email"] == "john.doe@openwebui.com"
+        assert data["role"] in ["admin", "user", "pending"]
+        assert data["profile_image_url"] == "/user.png"
+        assert data["token"] is not None and len(data["token"]) > 0
+        assert data["token_type"] == "Bearer"
+
+    def test_add_user(self):
+        with mock_webui_user():
+            response = self.fast_api_client.post(
+                self.create_url("/add"),
+                json={
+                    "name": "John Doe 2",
+                    "email": "john.doe2@openwebui.com",
+                    "password": "password2",
+                    "role": "admin",
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] is not None and len(data["id"]) > 0
+        assert data["name"] == "John Doe 2"
+        assert data["email"] == "john.doe2@openwebui.com"
+        assert data["role"] == "admin"
+        assert data["profile_image_url"] == "/user.png"
+        assert data["token"] is not None and len(data["token"]) > 0
+        assert data["token_type"] == "Bearer"
+
+    def test_get_admin_details(self):
+        self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        with mock_webui_user():
+            response = self.fast_api_client.get(self.create_url("/admin/details"))
+
+        assert response.status_code == 200
+        assert response.json() == {
+            "name": "John Doe",
+            "email": "john.doe@openwebui.com",
+        }
+
+    def test_create_api_key_(self):
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.post(self.create_url("/api_key"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["api_key"] is not None
+        assert len(data["api_key"]) > 0
+
+    def test_delete_api_key(self):
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.delete(self.create_url("/api_key"))
+        assert response.status_code == 200
+        assert response.json() == True
+        db_user = self.users.get_user_by_id(self.db_session, user.id)
+        assert db_user.api_key is None
+
+    def test_get_api_key(self):
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.get(self.create_url("/api_key"))
+        assert response.status_code == 200
+        assert response.json() == {"api_key": "abc"}

+ 239 - 0
backend/test/apps/webui/routers/test_chats.py

@@ -0,0 +1,239 @@
+import uuid
+
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestChats(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/chats"
+
+    def setup_class(cls):
+        super().setup_class()
+
+    def setup_method(self):
+        super().setup_method()
+        from apps.webui.models.chats import ChatForm
+        from apps.webui.models.chats import Chats
+
+        self.chats = Chats
+        self.chats.insert_new_chat(
+            self.db_session,
+            "2",
+            ChatForm(
+                **{
+                    "chat": {
+                        "name": "chat1",
+                        "description": "chat1 description",
+                        "tags": ["tag1", "tag2"],
+                        "history": {"currentId": "1", "messages": []},
+                    }
+                }
+            ),
+        )
+
+    def test_get_session_user_chat_list(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        first_chat = response.json()[0]
+        assert first_chat["id"] is not None
+        assert first_chat["title"] == "New Chat"
+        assert first_chat["created_at"] is not None
+        assert first_chat["updated_at"] is not None
+
+    def test_delete_all_user_chats(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(self.chats.get_chats(self.db_session)) == 0
+
+    def test_get_user_chat_list_by_user_id(self):
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url("/list/user/2"))
+        assert response.status_code == 200
+        first_chat = response.json()[0]
+        assert first_chat["id"] is not None
+        assert first_chat["title"] == "New Chat"
+        assert first_chat["created_at"] is not None
+        assert first_chat["updated_at"] is not None
+
+    def test_create_new_chat(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/new"),
+                json={
+                    "chat": {
+                        "name": "chat2",
+                        "description": "chat2 description",
+                        "tags": ["tag1", "tag2"],
+                    }
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["archived"] is False
+        assert data["chat"] == {
+            "name": "chat2",
+            "description": "chat2 description",
+            "tags": ["tag1", "tag2"],
+        }
+        assert data["user_id"] == "2"
+        assert data["id"] is not None
+        assert data["share_id"] is None
+        assert data["title"] == "New Chat"
+        assert data["updated_at"] is not None
+        assert data["created_at"] is not None
+        assert len(self.chats.get_chats(self.db_session)) == 2
+
+    def test_get_user_chats(self):
+        self.test_get_session_user_chat_list()
+
+    def test_get_user_archived_chats(self):
+        self.chats.archive_all_chats_by_user_id(self.db_session, "2")
+        self.db_session.commit()
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/all/archived"))
+        assert response.status_code == 200
+        first_chat = response.json()[0]
+        assert first_chat["id"] is not None
+        assert first_chat["title"] == "New Chat"
+        assert first_chat["created_at"] is not None
+        assert first_chat["updated_at"] is not None
+
+    def test_get_all_user_chats_in_db(self):
+        with mock_webui_user(id="4"):
+            response = self.fast_api_client.get(self.create_url("/all/db"))
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+    def test_get_archived_session_user_chat_list(self):
+        self.test_get_user_archived_chats()
+
+    def test_archive_all_chats(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(self.create_url("/archive/all"))
+        assert response.status_code == 200
+        assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1
+
+    def test_get_shared_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id)
+        self.db_session.commit()
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == chat_id
+        assert data["chat"] == {
+            "name": "chat1",
+            "description": "chat1 description",
+            "tags": ["tag1", "tag2"],
+            "history": {"currentId": "1", "messages": []},
+        }
+        assert data["id"] == chat_id
+        assert data["share_id"] == chat_id
+        assert data["title"] == "New Chat"
+
+    def test_get_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == chat_id
+        assert data["chat"] == {
+            "name": "chat1",
+            "description": "chat1 description",
+            "tags": ["tag1", "tag2"],
+            "history": {"currentId": "1", "messages": []},
+        }
+        assert data["share_id"] is None
+        assert data["title"] == "New Chat"
+        assert data["user_id"] == "2"
+
+    def test_update_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url(f"/{chat_id}"),
+                json={
+                    "chat": {
+                        "name": "chat2",
+                        "description": "chat2 description",
+                        "tags": ["tag2", "tag4"],
+                        "title": "Just another title",
+                    }
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == chat_id
+        assert data["chat"] == {
+            "name": "chat2",
+            "title": "Just another title",
+            "description": "chat2 description",
+            "tags": ["tag2", "tag4"],
+            "history": {"currentId": "1", "messages": []},
+        }
+        assert data["share_id"] is None
+        assert data["title"] == "Just another title"
+        assert data["user_id"] == "2"
+
+    def test_delete_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
+        assert response.status_code == 200
+        assert response.json() is True
+
+    def test_clone_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] != chat_id
+        assert data["chat"] == {
+            "branchPointMessageId": "1",
+            "description": "chat1 description",
+            "history": {"currentId": "1", "messages": []},
+            "name": "chat1",
+            "originalChatId": chat_id,
+            "tags": ["tag1", "tag2"],
+            "title": "Clone of New Chat",
+        }
+        assert data["share_id"] is None
+        assert data["title"] == "Clone of New Chat"
+        assert data["user_id"] == "2"
+
+    def test_archive_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
+        assert response.status_code == 200
+
+        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        assert chat.archived is True
+
+    def test_share_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
+        assert response.status_code == 200
+
+        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        assert chat.share_id is not None
+
+    def test_delete_shared_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        share_id = str(uuid.uuid4())
+        self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id)
+        self.db_session.commit()
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
+        assert response.status_code
+
+        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        assert chat.share_id is None

+ 106 - 0
backend/test/apps/webui/routers/test_documents.py

@@ -0,0 +1,106 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestDocuments(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/documents"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.documents import Documents
+
+        cls.documents = Documents
+
+    def test_documents(self):
+        # Empty database
+        assert len(self.documents.get_docs(self.db_session)) == 0
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0
+
+        # Create a new document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "name": "doc_name",
+                    "title": "doc title",
+                    "collection_name": "custom collection",
+                    "filename": "doc_name.pdf",
+                    "content": "",
+                },
+            )
+        assert response.status_code == 200
+        assert response.json()["name"] == "doc_name"
+        assert len(self.documents.get_docs(self.db_session)) == 1
+
+        # Get the document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["collection_name"] == "custom collection"
+        assert data["name"] == "doc_name"
+        assert data["title"] == "doc title"
+        assert data["filename"] == "doc_name.pdf"
+        assert data["content"] == {}
+
+        # Create another document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "name": "doc_name 2",
+                    "title": "doc title 2",
+                    "collection_name": "custom collection 2",
+                    "filename": "doc_name2.pdf",
+                    "content": "",
+                },
+            )
+        assert response.status_code == 200
+        assert response.json()["name"] == "doc_name 2"
+        assert len(self.documents.get_docs(self.db_session)) == 2
+
+        # Get all documents
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+
+        # Update the first document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/doc/update?name=doc_name"),
+                json={"name": "doc_name rework", "title": "updated title"},
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["name"] == "doc_name rework"
+        assert data["title"] == "updated title"
+
+        # Tag the first document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/doc/tags"),
+                json={
+                    "name": "doc_name rework",
+                    "tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["name"] == "doc_name rework"
+        assert data["content"] == {
+            "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
+        }
+        assert len(self.documents.get_docs(self.db_session)) == 2
+
+        # Delete the first document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(
+                self.create_url("/doc/delete?name=doc_name rework")
+            )
+        assert response.status_code == 200
+        assert len(self.documents.get_docs(self.db_session)) == 1

+ 60 - 0
backend/test/apps/webui/routers/test_models.py

@@ -0,0 +1,60 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestModels(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/models"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.models import Model
+
+        cls.models = Model
+
+    def test_models(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/add"),
+                json={
+                    "id": "my-model",
+                    "base_model_id": "base-model-id",
+                    "name": "Hello World",
+                    "meta": {
+                        "profile_image_url": "/favicon.png",
+                        "description": "description",
+                        "capabilities": None,
+                        "model_config": {},
+                    },
+                    "params": {},
+                },
+            )
+        assert response.status_code == 200
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/my-model"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == "my-model"
+        assert data["name"] == "Hello World"
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(
+                self.create_url("/delete?id=my-model")
+            )
+        assert response.status_code == 200
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0

+ 82 - 0
backend/test/apps/webui/routers/test_prompts.py

@@ -0,0 +1,82 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestPrompts(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/prompts"
+
+    def test_prompts(self):
+        # Get all prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0
+
+        # Create a two new prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "command": "/my-command",
+                    "title": "Hello World",
+                    "content": "description",
+                },
+            )
+        assert response.status_code == 200
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "command": "/my-command2",
+                    "title": "Hello World 2",
+                    "content": "description 2",
+                },
+            )
+        assert response.status_code == 200
+
+        # Get all prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+
+        # Get prompt by command
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/command/my-command"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["command"] == "/my-command"
+        assert data["title"] == "Hello World"
+        assert data["content"] == "description"
+        assert data["user_id"] == "2"
+
+        # Update prompt
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/command/my-command2/update"),
+                json={
+                    "command": "irrelevant for request",
+                    "title": "Hello World Updated",
+                    "content": "description Updated",
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["command"] == "/my-command2"
+        assert data["title"] == "Hello World Updated"
+        assert data["content"] == "description Updated"
+        assert data["user_id"] == "3"
+
+        # Delete prompt
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(
+                self.create_url("/command/my-command/delete")
+            )
+        assert response.status_code == 200
+
+        # Get all prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 1

+ 170 - 0
backend/test/apps/webui/routers/test_users.py

@@ -0,0 +1,170 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+def _get_user_by_id(data, param):
+    return next((item for item in data if item["id"] == param), None)
+
+
+def _assert_user(data, id, **kwargs):
+    user = _get_user_by_id(data, id)
+    assert user is not None
+    comparison_data = {
+        "name": f"user {id}",
+        "email": f"user{id}@openwebui.com",
+        "profile_image_url": f"/user{id}.png",
+        "role": "user",
+        **kwargs,
+    }
+    for key, value in comparison_data.items():
+        assert user[key] == value
+
+
+class TestUsers(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/users"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.users import Users
+
+        cls.users = Users
+
+    def setup_method(self):
+        super().setup_method()
+        self.users.insert_new_user(
+            self.db_session,
+            id="1",
+            name="user 1",
+            email="user1@openwebui.com",
+            profile_image_url="/user1.png",
+            role="user",
+        )
+        self.users.insert_new_user(
+            self.db_session,
+            id="2",
+            name="user 2",
+            email="user2@openwebui.com",
+            profile_image_url="/user2.png",
+            role="user",
+        )
+
+    def test_users(self):
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+        data = response.json()
+        _assert_user(data, "1")
+        _assert_user(data, "2")
+
+        # update role
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.post(
+                self.create_url("/update/role"), json={"id": "2", "role": "admin"}
+            )
+        assert response.status_code == 200
+        _assert_user([response.json()], "2", role="admin")
+
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+        data = response.json()
+        _assert_user(data, "1")
+        _assert_user(data, "2", role="admin")
+
+        # Get (empty) user settings
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/user/settings"))
+        assert response.status_code == 200
+        assert response.json() is None
+
+        # Update user settings
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/user/settings/update"),
+                json={
+                    "ui": {"attr1": "value1", "attr2": "value2"},
+                    "model_config": {"attr3": "value3", "attr4": "value4"},
+                },
+            )
+        assert response.status_code == 200
+
+        # Get user settings
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/user/settings"))
+        assert response.status_code == 200
+        assert response.json() == {
+            "ui": {"attr1": "value1", "attr2": "value2"},
+            "model_config": {"attr3": "value3", "attr4": "value4"},
+        }
+
+        # Get (empty) user info
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.get(self.create_url("/user/info"))
+        assert response.status_code == 200
+        assert response.json() is None
+
+        # Update user info
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/user/info/update"),
+                json={"attr1": "value1", "attr2": "value2"},
+            )
+        assert response.status_code == 200
+
+        # Get user info
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.get(self.create_url("/user/info"))
+        assert response.status_code == 200
+        assert response.json() == {"attr1": "value1", "attr2": "value2"}
+
+        # Get user by id
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.get(self.create_url("/2"))
+        assert response.status_code == 200
+        assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"}
+
+        # Update user by id
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/2/update"),
+                json={
+                    "name": "user 2 updated",
+                    "email": "user2-updated@openwebui.com",
+                    "profile_image_url": "/user2-updated.png",
+                },
+            )
+        assert response.status_code == 200
+
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+        data = response.json()
+        _assert_user(data, "1")
+        _assert_user(
+            data,
+            "2",
+            role="admin",
+            name="user 2 updated",
+            email="user2-updated@openwebui.com",
+            profile_image_url="/user2-updated.png",
+        )
+
+        # Delete user by id
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.delete(self.create_url("/2"))
+        assert response.status_code == 200
+
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+        data = response.json()
+        _assert_user(data, "1")

+ 155 - 0
backend/test/util/abstract_integration_test.py

@@ -0,0 +1,155 @@
+import logging
+import os
+import time
+
+import docker
+import pytest
+from docker import DockerClient
+from pytest_docker.plugin import get_docker_ip
+from fastapi.testclient import TestClient
+from sqlalchemy import text, create_engine
+
+log = logging.getLogger(__name__)
+
+
+def get_fast_api_client():
+    from main import app
+
+    with TestClient(app) as c:
+        return c
+
+
+class AbstractIntegrationTest:
+    BASE_PATH = None
+
+    def create_url(self, path):
+        if self.BASE_PATH is None:
+            raise Exception("BASE_PATH is not set")
+        parts = self.BASE_PATH.split("/")
+        parts = [part.strip() for part in parts if part.strip() != ""]
+        path_parts = path.split("/")
+        path_parts = [part.strip() for part in path_parts if part.strip() != ""]
+        return "/".join(parts + path_parts)
+
+    @classmethod
+    def setup_class(cls):
+        pass
+
+    def setup_method(self):
+        pass
+
+    @classmethod
+    def teardown_class(cls):
+        pass
+
+    def teardown_method(self):
+        pass
+
+
+class AbstractPostgresTest(AbstractIntegrationTest):
+    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
+    docker_client: DockerClient
+
+    def get_db(self):
+        from apps.webui.internal.db import SessionLocal
+
+        return SessionLocal()
+
+    @classmethod
+    def _create_db_url(cls, env_vars_postgres: dict) -> str:
+        host = get_docker_ip()
+        user = env_vars_postgres["POSTGRES_USER"]
+        pw = env_vars_postgres["POSTGRES_PASSWORD"]
+        port = 8081
+        db = env_vars_postgres["POSTGRES_DB"]
+        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
+
+    @classmethod
+    def setup_class(cls):
+        super().setup_class()
+        try:
+            env_vars_postgres = {
+                "POSTGRES_USER": "user",
+                "POSTGRES_PASSWORD": "example",
+                "POSTGRES_DB": "openwebui",
+            }
+            cls.docker_client = docker.from_env()
+            cls.docker_client.containers.run(
+                "postgres:16.2",
+                detach=True,
+                environment=env_vars_postgres,
+                name=cls.DOCKER_CONTAINER_NAME,
+                ports={5432: ("0.0.0.0", 8081)},
+                command="postgres -c log_statement=all",
+            )
+            time.sleep(0.5)
+
+            database_url = cls._create_db_url(env_vars_postgres)
+            os.environ["DATABASE_URL"] = database_url
+            retries = 10
+            db = None
+            while retries > 0:
+                try:
+                    from config import BACKEND_DIR
+                    db = create_engine(database_url, pool_pre_ping=True)
+                    db = db.connect()
+                    log.info("postgres is ready!")
+                    break
+                except Exception as e:
+                    log.warning(e)
+                    time.sleep(3)
+                    retries -= 1
+
+            if db:
+                # import must be after setting env!
+                cls.fast_api_client = get_fast_api_client()
+                db.close()
+            else:
+                raise Exception("Could not connect to Postgres")
+        except Exception as ex:
+            log.error(ex)
+            cls.teardown_class()
+            pytest.fail(f"Could not setup test environment: {ex}")
+
+    def _check_db_connection(self):
+        retries = 10
+        while retries > 0:
+            try:
+                self.db_session.execute(text("SELECT 1"))
+                self.db_session.commit()
+                break
+            except Exception as e:
+                self.db_session.rollback()
+                log.warning(e)
+                time.sleep(3)
+                retries -= 1
+
+    def setup_method(self):
+        super().setup_method()
+        self.db_session = self.get_db()
+        self._check_db_connection()
+
+    @classmethod
+    def teardown_class(cls) -> None:
+        super().teardown_class()
+        cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
+
+    def teardown_method(self):
+        # rollback everything not yet committed
+        self.db_session.commit()
+
+        # truncate all tables
+        tables = [
+            "auth",
+            "chat",
+            "chatidtag",
+            "document",
+            "memory",
+            "model",
+            "prompt",
+            "tag",
+            '"user"',
+        ]
+        for table in tables:
+            self.db_session.execute(text(f"TRUNCATE TABLE {table}"))
+        self.db_session.commit()

+ 45 - 0
backend/test/util/mock_user.py

@@ -0,0 +1,45 @@
+from contextlib import contextmanager
+
+from fastapi import FastAPI
+
+
+@contextmanager
+def mock_webui_user(**kwargs):
+    from apps.webui.main import app
+
+    with mock_user(app, **kwargs):
+        yield
+
+
+@contextmanager
+def mock_user(app: FastAPI, **kwargs):
+    from utils.utils import (
+        get_current_user,
+        get_verified_user,
+        get_admin_user,
+        get_current_user_by_api_key,
+    )
+    from apps.webui.models.users import User
+
+    def create_user():
+        user_parameters = {
+            "id": "1",
+            "name": "John Doe",
+            "email": "john.doe@openwebui.com",
+            "role": "user",
+            "profile_image_url": "/user.png",
+            "last_active_at": 1627351200,
+            "updated_at": 1627351200,
+            "created_at": 162735120,
+            **kwargs,
+        }
+        return User(**user_parameters)
+
+    app.dependency_overrides = {
+        get_current_user: create_user,
+        get_verified_user: create_user,
+        get_admin_user: create_user,
+        get_current_user_by_api_key: create_user,
+    }
+    yield
+    app.dependency_overrides = {}

+ 9 - 6
backend/utils/utils.py

@@ -1,6 +1,8 @@
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi import HTTPException, status, Depends, Request
+from sqlalchemy.orm import Session
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.users import Users
 
 from pydantic import BaseModel
@@ -77,6 +79,7 @@ def get_http_authorization_cred(auth_header: str):
 def get_current_user(
     request: Request,
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
+    db=Depends(get_db),
 ):
     token = None
 
@@ -91,19 +94,19 @@ def get_current_user(
 
     # auth by api key
     if token.startswith("sk-"):
-        return get_current_user_by_api_key(token)
+        return get_current_user_by_api_key(db, token)
 
     # auth by jwt token
     data = decode_token(token)
     if data != None and "id" in data:
-        user = Users.get_user_by_id(data["id"])
+        user = Users.get_user_by_id(db, data["id"])
         if user is None:
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
         else:
-            Users.update_user_last_active_by_id(user.id)
+            Users.update_user_last_active_by_id(db, user.id)
         return user
     else:
         raise HTTPException(
@@ -112,8 +115,8 @@ def get_current_user(
         )
 
 
-def get_current_user_by_api_key(api_key: str):
-    user = Users.get_user_by_api_key(api_key)
+def get_current_user_by_api_key(db: Session, api_key: str):
+    user = Users.get_user_by_api_key(db, api_key)
 
     if user is None:
         raise HTTPException(
@@ -121,7 +124,7 @@ def get_current_user_by_api_key(api_key: str):
             detail=ERROR_MESSAGES.INVALID_TOKEN,
         )
     else:
-        Users.update_user_last_active_by_id(user.id)
+        Users.update_user_last_active_by_id(db, user.id)
 
     return user
 

+ 1 - 4
src/lib/apis/models/index.ts

@@ -63,10 +63,7 @@ export const getModelInfos = async (token: string = '') => {
 export const getModelById = async (token: string, id: string) => {
 	let error = null;
 
-	const searchParams = new URLSearchParams();
-	searchParams.append('id', id);
-
-	const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',