소스 검색

feat(sqlalchemy): Replace peewee with sqlalchemy

Jonathan Rohde 10 달 전
부모
커밋
df09d0830a
47개의 변경된 파일2577개의 추가작업 그리고 1000개의 파일을 삭제
  1. 2 2
      .github/workflows/integration-test.yml
  2. 114 0
      backend/alembic.ini
  3. 5 2
      backend/apps/ollama/main.py
  4. 3 1
      backend/apps/openai/main.py
  5. 3 1
      backend/apps/socket/main.py
  6. 40 26
      backend/apps/webui/internal/db.py
  7. 0 72
      backend/apps/webui/internal/wrappers.py
  8. 3 3
      backend/apps/webui/main.py
  9. 40 37
      backend/apps/webui/models/auths.py
  10. 139 162
      backend/apps/webui/models/chats.py
  11. 48 56
      backend/apps/webui/models/documents.py
  12. 30 32
      backend/apps/webui/models/files.py
  13. 36 38
      backend/apps/webui/models/functions.py
  14. 43 44
      backend/apps/webui/models/memories.py
  15. 38 41
      backend/apps/webui/models/models.py
  16. 38 48
      backend/apps/webui/models/prompts.py
  17. 109 89
      backend/apps/webui/models/tags.py
  18. 37 41
      backend/apps/webui/models/tools.py
  19. 92 83
      backend/apps/webui/models/users.py
  20. 40 26
      backend/apps/webui/routers/auths.py
  21. 88 58
      backend/apps/webui/routers/chats.py
  22. 27 13
      backend/apps/webui/routers/documents.py
  23. 14 11
      backend/apps/webui/routers/files.py
  24. 14 13
      backend/apps/webui/routers/functions.py
  25. 18 11
      backend/apps/webui/routers/memories.py
  26. 22 13
      backend/apps/webui/routers/models.py
  27. 21 11
      backend/apps/webui/routers/prompts.py
  28. 22 13
      backend/apps/webui/routers/tools.py
  29. 44 25
      backend/apps/webui/routers/users.py
  30. 4 4
      backend/apps/webui/routers/utils.py
  31. 45 11
      backend/main.py
  32. 4 0
      backend/migrations/README
  33. 93 0
      backend/migrations/env.py
  34. 27 0
      backend/migrations/script.py.mako
  35. 188 0
      backend/migrations/versions/22b5ab2667b8_init.py
  36. 10 3
      backend/requirements.txt
  37. 0 0
      backend/test/__init__.py
  38. 209 0
      backend/test/apps/webui/routers/test_auths.py
  39. 239 0
      backend/test/apps/webui/routers/test_chats.py
  40. 106 0
      backend/test/apps/webui/routers/test_documents.py
  41. 60 0
      backend/test/apps/webui/routers/test_models.py
  42. 82 0
      backend/test/apps/webui/routers/test_prompts.py
  43. 170 0
      backend/test/apps/webui/routers/test_users.py
  44. 155 0
      backend/test/util/abstract_integration_test.py
  45. 45 0
      backend/test/util/mock_user.py
  46. 9 6
      backend/utils/utils.py
  47. 1 4
      src/lib/apis/models/index.ts

+ 2 - 2
.github/workflows/integration-test.yml

@@ -171,7 +171,7 @@ jobs:
           fi
           fi
 
 
           # Check that service will reconnect to postgres when connection will be closed
           # Check that service will reconnect to postgres when connection will be closed
-          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
+          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
           if [[ "$status_code" -ne 200 ]] ; then
           if [[ "$status_code" -ne 200 ]] ; then
             echo "Server has failed before postgres reconnect check"
             echo "Server has failed before postgres reconnect check"
             exit 1
             exit 1
@@ -183,7 +183,7 @@ jobs:
             cur = conn.cursor(); \
             cur = conn.cursor(); \
             cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
             cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
 
 
-          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
+          status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
           if [[ "$status_code" -ne 200 ]] ; then
           if [[ "$status_code" -ne 200 ]] ; then
             echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
             echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
             exit 1
             exit 1

+ 114 - 0
backend/alembic.ini

@@ -0,0 +1,114 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+script_location = migrations
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python>=3.9 or backports.zoneinfo library.
+# Any required deps can installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to ZoneInfo()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the
+# "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to migrations/versions.  When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are:
+#
+# version_path_separator = :
+# version_path_separator = ;
+# version_path_separator = space
+version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
+
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+sqlalchemy.url = REPLACE_WITH_DATABASE_URL
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts.  See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+# hooks = ruff
+# ruff.type = exec
+# ruff.executable = %(here)s/.venv/bin/ruff
+# ruff.options = --fix REVISION_SCRIPT_FILENAME
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S

+ 5 - 2
backend/apps/ollama/main.py

@@ -31,6 +31,7 @@ from typing import Optional, List, Union
 
 
 from starlette.background import BackgroundTask
 from starlette.background import BackgroundTask
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -711,6 +712,7 @@ async def generate_chat_completion(
     form_data: GenerateChatCompletionForm,
     form_data: GenerateChatCompletionForm,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
 ):
 
 
     log.debug(
     log.debug(
@@ -724,7 +726,7 @@ async def generate_chat_completion(
     }
     }
 
 
     model_id = form_data.model
     model_id = form_data.model
-    model_info = Models.get_model_by_id(model_id)
+    model_info = Models.get_model_by_id(db, model_id)
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:
@@ -883,6 +885,7 @@ async def generate_openai_chat_completion(
     form_data: dict,
     form_data: dict,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
     form_data = OpenAIChatCompletionForm(**form_data)
 
 
@@ -891,7 +894,7 @@ async def generate_openai_chat_completion(
     }
     }
 
 
     model_id = form_data.model
     model_id = form_data.model
-    model_info = Models.get_model_by_id(model_id)
+    model_info = Models.get_model_by_id(db, model_id)
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:

+ 3 - 1
backend/apps/openai/main.py

@@ -11,6 +11,7 @@ import logging
 from pydantic import BaseModel
 from pydantic import BaseModel
 from starlette.background import BackgroundTask
 from starlette.background import BackgroundTask
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -353,12 +354,13 @@ async def generate_chat_completion(
     form_data: dict,
     form_data: dict,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
 ):
     idx = 0
     idx = 0
     payload = {**form_data}
     payload = {**form_data}
 
 
     model_id = form_data.get("model")
     model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
+    model_info = Models.get_model_by_id(db, model_id)
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:

+ 3 - 1
backend/apps/socket/main.py

@@ -24,7 +24,9 @@ async def connect(sid, environ, auth):
         data = decode_token(auth["token"])
         data = decode_token(auth["token"])
 
 
         if data is not None and "id" in data:
         if data is not None and "id" in data:
-            user = Users.get_user_by_id(data["id"])
+            from apps.webui.internal.db import SessionLocal
+
+            user = Users.get_user_by_id(SessionLocal(), data["id"])
 
 
         if user:
         if user:
             SESSION_POOL[sid] = user.id
             SESSION_POOL[sid] = user.id

+ 40 - 26
backend/apps/webui/internal/db.py

@@ -1,18 +1,34 @@
 import os
 import os
 import logging
 import logging
 import json
 import json
+from typing import Optional, Any
+from typing_extensions import Self
 
 
-from peewee import *
-from peewee_migrate import Router
+from sqlalchemy import create_engine, types, Dialect
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.sql.type_api import _T
 
 
-from apps.webui.internal.wrappers import register_connection
 from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 log.setLevel(SRC_LOG_LEVELS["DB"])
 
 
 
 
-class JSONField(TextField):
+class JSONField(types.TypeDecorator):
+    impl = types.Text
+    cache_ok = True
+
+    def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
+        return json.dumps(value)
+
+    def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
+        if value is not None:
+            return json.loads(value)
+
+    def copy(self, **kw: Any) -> Self:
+        return JSONField(self.impl.length)
+
     def db_value(self, value):
     def db_value(self, value):
         return json.dumps(value)
         return json.dumps(value)
 
 
@@ -29,26 +45,24 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
 else:
 else:
     pass
     pass
 
 
+SQLALCHEMY_DATABASE_URL = DATABASE_URL
+if "sqlite" in SQLALCHEMY_DATABASE_URL:
+    engine = create_engine(
+        SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
+    )
+else:
+    engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+Base = declarative_base()
 
 
-# The `register_connection` function encapsulates the logic for setting up
-# the database connection based on the connection string, while `connect`
-# is a Peewee-specific method to manage the connection state and avoid errors
-# when a connection is already open.
-try:
-    DB = register_connection(DATABASE_URL)
-    log.info(f"Connected to a {DB.__class__.__name__} database.")
-except Exception as e:
-    log.error(f"Failed to initialize the database connection: {e}")
-    raise
-
-router = Router(
-    DB,
-    migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
-    logger=log,
-)
-router.run()
-try:
-    DB.connect(reuse_if_open=True)
-except OperationalError as e:
-    log.info(f"Failed to connect to database again due to: {e}")
-    pass
+
+def get_db():
+    db = SessionLocal()
+    try:
+        yield db
+        db.commit()
+    except Exception as e:
+        db.rollback()
+        raise e
+    finally:
+        db.close()

+ 0 - 72
backend/apps/webui/internal/wrappers.py

@@ -1,72 +0,0 @@
-from contextvars import ContextVar
-from peewee import *
-from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
-
-import logging
-from playhouse.db_url import connect, parse
-from playhouse.shortcuts import ReconnectMixin
-
-from config import SRC_LOG_LEVELS
-
-log = logging.getLogger(__name__)
-log.setLevel(SRC_LOG_LEVELS["DB"])
-
-db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
-db_state = ContextVar("db_state", default=db_state_default.copy())
-
-
-class PeeweeConnectionState(object):
-    def __init__(self, **kwargs):
-        super().__setattr__("_state", db_state)
-        super().__init__(**kwargs)
-
-    def __setattr__(self, name, value):
-        self._state.get()[name] = value
-
-    def __getattr__(self, name):
-        value = self._state.get()[name]
-        return value
-
-
-class CustomReconnectMixin(ReconnectMixin):
-    reconnect_errors = (
-        # psycopg2
-        (OperationalError, "termin"),
-        (InterfaceError, "closed"),
-        # peewee
-        (PeeWeeInterfaceError, "closed"),
-    )
-
-
-class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
-    pass
-
-
-def register_connection(db_url):
-    db = connect(db_url)
-    if isinstance(db, PostgresqlDatabase):
-        # Enable autoconnect for SQLite databases, managed by Peewee
-        db.autoconnect = True
-        db.reuse_if_open = True
-        log.info("Connected to PostgreSQL database")
-
-        # Get the connection details
-        connection = parse(db_url)
-
-        # Use our custom database class that supports reconnection
-        db = ReconnectingPostgresqlDatabase(
-            connection["database"],
-            user=connection["user"],
-            password=connection["password"],
-            host=connection["host"],
-            port=connection["port"],
-        )
-        db.connect(reuse_if_open=True)
-    elif isinstance(db, SqliteDatabase):
-        # Enable autoconnect for SQLite databases, managed by Peewee
-        db.autoconnect = True
-        db.reuse_if_open = True
-        log.info("Connected to SQLite database")
-    else:
-        raise ValueError("Unsupported database connection")
-    return db

+ 3 - 3
backend/apps/webui/main.py

@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
 from fastapi.responses import StreamingResponse
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.sessions import SessionMiddleware
-
+from sqlalchemy.orm import Session
 from apps.webui.routers import (
 from apps.webui.routers import (
     auths,
     auths,
     users,
     users,
@@ -114,8 +114,8 @@ async def get_status():
     }
     }
 
 
 
 
-async def get_pipe_models():
-    pipes = Functions.get_functions_by_type("pipe", active_only=True)
+async def get_pipe_models(db: Session):
+    pipes = Functions.get_functions_by_type(db, "pipe", active_only=True)
     pipe_models = []
     pipe_models = []
 
 
     for pipe in pipes:
     for pipe in pipes:

+ 40 - 37
backend/apps/webui/models/auths.py

@@ -1,14 +1,14 @@
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import List, Union, Optional
-import time
+from typing import Optional
 import uuid
 import uuid
 import logging
 import logging
-from peewee import *
+from sqlalchemy import String, Column, Boolean
+from sqlalchemy.orm import Session
 
 
 from apps.webui.models.users import UserModel, Users
 from apps.webui.models.users import UserModel, Users
 from utils.utils import verify_password
 from utils.utils import verify_password
 
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import Base
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -20,14 +20,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 
 
 
 
-class Auth(Model):
-    id = CharField(unique=True)
-    email = CharField()
-    password = TextField()
-    active = BooleanField()
+class Auth(Base):
+    __tablename__ = "auth"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    email = Column(String)
+    password = Column(String)
+    active = Column(Boolean)
 
 
 
 
 class AuthModel(BaseModel):
 class AuthModel(BaseModel):
@@ -94,12 +93,10 @@ class AddUserForm(SignupForm):
 
 
 
 
 class AuthsTable:
 class AuthsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Auth])
 
 
     def insert_new_auth(
     def insert_new_auth(
         self,
         self,
+        db: Session,
         email: str,
         email: str,
         password: str,
         password: str,
         name: str,
         name: str,
@@ -114,24 +111,30 @@ class AuthsTable:
         auth = AuthModel(
         auth = AuthModel(
             **{"id": id, "email": email, "password": password, "active": True}
             **{"id": id, "email": email, "password": password, "active": True}
         )
         )
-        result = Auth.create(**auth.model_dump())
+        result = Auth(**auth.model_dump())
+        db.add(result)
 
 
         user = Users.insert_new_user(
         user = Users.insert_new_user(
-            id, name, email, profile_image_url, role, oauth_sub
+            db, id, name, email, profile_image_url, role, oauth_sub
         )
         )
 
 
+        db.commit()
+        db.refresh(result)
+
         if result and user:
         if result and user:
             return user
             return user
         else:
         else:
             return None
             return None
 
 
-    def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
+    def authenticate_user(
+        self, db: Session, email: str, password: str
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
         log.info(f"authenticate_user: {email}")
         try:
         try:
-            auth = Auth.get(Auth.email == email, Auth.active == True)
+            auth = db.query(Auth).filter_by(email=email, active=True).first()
             if auth:
             if auth:
                 if verify_password(password, auth.password):
                 if verify_password(password, auth.password):
-                    user = Users.get_user_by_id(auth.id)
+                    user = Users.get_user_by_id(db, auth.id)
                     return user
                     return user
                 else:
                 else:
                     return None
                     return None
@@ -140,55 +143,55 @@ class AuthsTable:
         except:
         except:
             return None
             return None
 
 
-    def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
+    def authenticate_user_by_api_key(
+        self, db: Session, api_key: str
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_api_key: {api_key}")
         log.info(f"authenticate_user_by_api_key: {api_key}")
         # if no api_key, return None
         # if no api_key, return None
         if not api_key:
         if not api_key:
             return None
             return None
 
 
         try:
         try:
-            user = Users.get_user_by_api_key(api_key)
+            user = Users.get_user_by_api_key(db, api_key)
             return user if user else None
             return user if user else None
         except:
         except:
             return False
             return False
 
 
-    def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
+    def authenticate_user_by_trusted_header(
+        self, db: Session, email: str
+    ) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
         log.info(f"authenticate_user_by_trusted_header: {email}")
         try:
         try:
-            auth = Auth.get(Auth.email == email, Auth.active == True)
+            auth = db.query(Auth).filter(email=email, active=True).first()
             if auth:
             if auth:
                 user = Users.get_user_by_id(auth.id)
                 user = Users.get_user_by_id(auth.id)
                 return user
                 return user
         except:
         except:
             return None
             return None
 
 
-    def update_user_password_by_id(self, id: str, new_password: str) -> bool:
+    def update_user_password_by_id(
+        self, db: Session, id: str, new_password: str
+    ) -> bool:
         try:
         try:
-            query = Auth.update(password=new_password).where(Auth.id == id)
-            result = query.execute()
-
+            result = db.query(Auth).filter_by(id=id).update({"password": new_password})
             return True if result == 1 else False
             return True if result == 1 else False
         except:
         except:
             return False
             return False
 
 
-    def update_email_by_id(self, id: str, email: str) -> bool:
+    def update_email_by_id(self, db: Session, id: str, email: str) -> bool:
         try:
         try:
-            query = Auth.update(email=email).where(Auth.id == id)
-            result = query.execute()
-
+            result = db.query(Auth).filter_by(id=id).update({"email": email})
             return True if result == 1 else False
             return True if result == 1 else False
         except:
         except:
             return False
             return False
 
 
-    def delete_auth_by_id(self, id: str) -> bool:
+    def delete_auth_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
             # Delete User
             # Delete User
-            result = Users.delete_user_by_id(id)
+            result = Users.delete_user_by_id(db, id)
 
 
             if result:
             if result:
-                # Delete Auth
-                query = Auth.delete().where(Auth.id == id)
-                query.execute()  # Remove the rows, return number of rows removed.
+                db.query(Auth).filter_by(id=id).delete()
 
 
                 return True
                 return True
             else:
             else:
@@ -197,4 +200,4 @@ class AuthsTable:
             return False
             return False
 
 
 
 
-Auths = AuthsTable(DB)
+Auths = AuthsTable()

+ 139 - 162
backend/apps/webui/models/chats.py

@@ -1,36 +1,39 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
-from peewee import *
-from playhouse.shortcuts import model_to_dict
 
 
 import json
 import json
 import uuid
 import uuid
 import time
 import time
 
 
-from apps.webui.internal.db import DB
+from sqlalchemy import Column, String, BigInteger, Boolean
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base
+
 
 
 ####################
 ####################
 # Chat DB Schema
 # Chat DB Schema
 ####################
 ####################
 
 
 
 
-class Chat(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    title = TextField()
-    chat = TextField()  # Save Chat JSON as Text
+class Chat(Base):
+    __tablename__ = "chat"
 
 
-    created_at = BigIntegerField()
-    updated_at = BigIntegerField()
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    title = Column(String)
+    chat = Column(String)  # Save Chat JSON as Text
 
 
-    share_id = CharField(null=True, unique=True)
-    archived = BooleanField(default=False)
+    created_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
 
 
-    class Meta:
-        database = DB
+    share_id = Column(String, unique=True, nullable=True)
+    archived = Column(Boolean, default=False)
 
 
 
 
 class ChatModel(BaseModel):
 class ChatModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     id: str
     id: str
     user_id: str
     user_id: str
     title: str
     title: str
@@ -75,11 +78,10 @@ class ChatTitleIdResponse(BaseModel):
 
 
 
 
 class ChatTable:
 class ChatTable:
-    def __init__(self, db):
-        self.db = db
-        db.create_tables([Chat])
 
 
-    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
+    def insert_new_chat(
+        self, db: Session, user_id: str, form_data: ChatForm
+    ) -> Optional[ChatModel]:
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         chat = ChatModel(
         chat = ChatModel(
             **{
             **{
@@ -94,29 +96,36 @@ class ChatTable:
             }
             }
         )
         )
 
 
-        result = Chat.create(**chat.model_dump())
-        return chat if result else None
+        result = Chat(**chat.model_dump())
+        db.add(result)
+        db.commit()
+        db.refresh(result)
+        return ChatModel.model_validate(result) if result else None
 
 
-    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
+    def update_chat_by_id(
+        self, db: Session, id: str, chat: dict
+    ) -> Optional[ChatModel]:
         try:
         try:
-            query = Chat.update(
-                chat=json.dumps(chat),
-                title=chat["title"] if "title" in chat else "New Chat",
-                updated_at=int(time.time()),
-            ).where(Chat.id == id)
-            query.execute()
-
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            db.query(Chat).filter_by(id=id).update(
+                {
+                    "chat": json.dumps(chat),
+                    "title": chat["title"] if "title" in chat else "New Chat",
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            return self.get_chat_by_id(db, id)
         except:
         except:
             return None
             return None
 
 
-    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
+    def insert_shared_chat_by_chat_id(
+        self, db: Session, chat_id: str
+    ) -> Optional[ChatModel]:
         # Get the existing chat to share
         # Get the existing chat to share
-        chat = Chat.get(Chat.id == chat_id)
+        chat = db.get(Chat, chat_id)
         # Check if the chat is already shared
         # Check if the chat is already shared
         if chat.share_id:
         if chat.share_id:
-            return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+            return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared")
         # Create a new chat with the same data, but with a new ID
         # Create a new chat with the same data, but with a new ID
         shared_chat = ChatModel(
         shared_chat = ChatModel(
             **{
             **{
@@ -128,228 +137,196 @@ class ChatTable:
                 "updated_at": int(time.time()),
                 "updated_at": int(time.time()),
             }
             }
         )
         )
-        shared_result = Chat.create(**shared_chat.model_dump())
+        shared_result = Chat(**shared_chat.model_dump())
+        db.add(shared_result)
+        db.commit()
+        db.refresh(shared_result)
         # Update the original chat with the share_id
         # Update the original chat with the share_id
         result = (
         result = (
-            Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
+            db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
         )
         )
 
 
         return shared_chat if (shared_result and result) else None
         return shared_chat if (shared_result and result) else None
 
 
-    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
+    def update_shared_chat_by_chat_id(
+        self, db: Session, chat_id: str
+    ) -> Optional[ChatModel]:
         try:
         try:
             print("update_shared_chat_by_id")
             print("update_shared_chat_by_id")
-            chat = Chat.get(Chat.id == chat_id)
+            chat = db.get(Chat, chat_id)
             print(chat)
             print(chat)
 
 
-            query = Chat.update(
-                title=chat.title,
-                chat=chat.chat,
-            ).where(Chat.id == chat.share_id)
+            db.query(Chat).filter_by(id=chat.share_id).update(
+                {"title": chat.title, "chat": chat.chat}
+            )
 
 
-            query.execute()
-
-            chat = Chat.get(Chat.id == chat.share_id)
-            return ChatModel(**model_to_dict(chat))
+            return self.get_chat_by_id(db, chat.share_id)
         except:
         except:
             return None
             return None
 
 
-    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
+    def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool:
         try:
         try:
-            query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def update_chat_share_id_by_id(
     def update_chat_share_id_by_id(
-        self, id: str, share_id: Optional[str]
+        self, db: Session, id: str, share_id: Optional[str]
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
         try:
         try:
-            query = Chat.update(
-                share_id=share_id,
-            ).where(Chat.id == id)
-            query.execute()
+            db.query(Chat).filter_by(id=id).update({"share_id": share_id})
 
 
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            return self.get_chat_by_id(db, id)
         except:
         except:
             return None
             return None
 
 
-    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
+    def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
         try:
         try:
-            chat = self.get_chat_by_id(id)
-            query = Chat.update(
-                archived=(not chat.archived),
-            ).where(Chat.id == id)
+            chat = self.get_chat_by_id(db, id)
+            db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
 
 
-            query.execute()
-
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            return self.get_chat_by_id(db, id)
         except:
         except:
             return None
             return None
 
 
-    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
+    def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
         try:
-            chats = self.get_chats_by_user_id(user_id)
-            for chat in chats:
-                query = Chat.update(
-                    archived=True,
-                ).where(Chat.id == chat.id)
-
-                query.execute()
+            db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
 
 
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def get_archived_chat_list_by_user_id(
     def get_archived_chat_list_by_user_id(
-        self, user_id: str, skip: int = 0, limit: int = 50
+        self, db: Session, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == True)
-            .where(Chat.user_id == user_id)
+        all_chats = (
+            db.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
             .order_by(Chat.updated_at.desc())
             .order_by(Chat.updated_at.desc())
-            # .limit(limit)
-            # .offset(skip)
-        ]
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chat_list_by_user_id(
     def get_chat_list_by_user_id(
         self,
         self,
+        db: Session,
         user_id: str,
         user_id: str,
         include_archived: bool = False,
         include_archived: bool = False,
         skip: int = 0,
         skip: int = 0,
         limit: int = 50,
         limit: int = 50,
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        if include_archived:
-            return [
-                ChatModel(**model_to_dict(chat))
-                for chat in Chat.select()
-                .where(Chat.user_id == user_id)
-                .order_by(Chat.updated_at.desc())
-                # .limit(limit)
-                # .offset(skip)
-            ]
-        else:
-            return [
-                ChatModel(**model_to_dict(chat))
-                for chat in Chat.select()
-                .where(Chat.archived == False)
-                .where(Chat.user_id == user_id)
-                .order_by(Chat.updated_at.desc())
-                # .limit(limit)
-                # .offset(skip)
-            ]
+        query = db.query(Chat).filter_by(user_id=user_id)
+        if not include_archived:
+            query = query.filter_by(archived=False)
+        all_chats = (
+            query.order_by(Chat.updated_at.desc())
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chat_list_by_chat_ids(
     def get_chat_list_by_chat_ids(
-        self, chat_ids: List[str], skip: int = 0, limit: int = 50
+        self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == False)
-            .where(Chat.id.in_(chat_ids))
+        all_chats = (
+            db.query(Chat)
+            .filter(Chat.id.in_(chat_ids))
+            .filter_by(archived=False)
             .order_by(Chat.updated_at.desc())
             .order_by(Chat.updated_at.desc())
-        ]
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
-    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
+    def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
         try:
         try:
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
+            chat = db.get(Chat, id)
+            return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
-    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
+    def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]:
         try:
         try:
-            chat = Chat.get(Chat.share_id == id)
+            chat = db.query(Chat).filter_by(share_id=id).first()
 
 
             if chat:
             if chat:
-                chat = Chat.get(Chat.id == id)
-                return ChatModel(**model_to_dict(chat))
+                return self.get_chat_by_id(db, id)
             else:
             else:
                 return None
                 return None
-        except:
+        except Exception as e:
             return None
             return None
 
 
-    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
+    def get_chat_by_id_and_user_id(
+        self, db: Session, id: str, user_id: str
+    ) -> Optional[ChatModel]:
         try:
         try:
-            chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
-            return ChatModel(**model_to_dict(chat))
+            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
+            return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
-    def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select().order_by(Chat.updated_at.desc())
+    def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]:
+        all_chats = (
+            db.query(Chat)
             # .limit(limit).offset(skip)
             # .limit(limit).offset(skip)
-        ]
-
-    def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.user_id == user_id)
             .order_by(Chat.updated_at.desc())
             .order_by(Chat.updated_at.desc())
-            # .limit(limit).offset(skip)
-        ]
-
-    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        return [
-            ChatModel(**model_to_dict(chat))
-            for chat in Chat.select()
-            .where(Chat.archived == True)
-            .where(Chat.user_id == user_id)
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]:
+        all_chats = (
+            db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def get_archived_chats_by_user_id(
+        self, db: Session, user_id: str
+    ) -> List[ChatModel]:
+        all_chats = (
+            db.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
             .order_by(Chat.updated_at.desc())
             .order_by(Chat.updated_at.desc())
-        ]
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
-    def delete_chat_by_id(self, id: str) -> bool:
+    def delete_chat_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
-            query = Chat.delete().where((Chat.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
+            db.query(Chat).filter_by(id=id).delete()
 
 
-            return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(db, id)
         except:
         except:
             return False
             return False
 
 
-    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+    def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool:
         try:
         try:
-            query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
-            query.execute()  # Remove the rows, return number of rows removed.
+            db.query(Chat).filter_by(id=id, user_id=user_id).delete()
 
 
-            return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(db, id)
         except:
         except:
             return False
             return False
 
 
-    def delete_chats_by_user_id(self, user_id: str) -> bool:
+    def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
         try:
 
 
-            self.delete_shared_chats_by_user_id(user_id)
-
-            query = Chat.delete().where(Chat.user_id == user_id)
-            query.execute()  # Remove the rows, return number of rows removed.
+            self.delete_shared_chats_by_user_id(db, user_id)
 
 
+            db.query(Chat).filter_by(user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
-    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
+    def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
         try:
-            shared_chat_ids = [
-                f"shared-{chat.id}"
-                for chat in Chat.select().where(Chat.user_id == user_id)
-            ]
+            chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
+            shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 
 
-            query = Chat.delete().where(Chat.user_id << shared_chat_ids)
-            query.execute()  # Remove the rows, return number of rows removed.
+            db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
 
 
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Chats = ChatTable(DB)
+Chats = ChatTable()

+ 48 - 56
backend/apps/webui/models/documents.py

@@ -1,14 +1,12 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
-from typing import List, Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 import time
 import time
 import logging
 import logging
 
 
-from utils.utils import decode_token
-from utils.misc import get_gravatar_url
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import Base
 
 
 import json
 import json
 
 
@@ -22,20 +20,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 
 
 
 
-class Document(Model):
-    collection_name = CharField(unique=True)
-    name = CharField(unique=True)
-    title = TextField()
-    filename = TextField()
-    content = TextField(null=True)
-    user_id = CharField()
-    timestamp = BigIntegerField()
+class Document(Base):
+    __tablename__ = "document"
 
 
-    class Meta:
-        database = DB
+    collection_name = Column(String, primary_key=True)
+    name = Column(String, unique=True)
+    title = Column(String)
+    filename = Column(String)
+    content = Column(String, nullable=True)
+    user_id = Column(String)
+    timestamp = Column(BigInteger)
 
 
 
 
 class DocumentModel(BaseModel):
 class DocumentModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     collection_name: str
     collection_name: str
     name: str
     name: str
     title: str
     title: str
@@ -72,12 +71,9 @@ class DocumentForm(DocumentUpdateForm):
 
 
 
 
 class DocumentsTable:
 class DocumentsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Document])
 
 
     def insert_new_doc(
     def insert_new_doc(
-        self, user_id: str, form_data: DocumentForm
+        self, db: Session, user_id: str, form_data: DocumentForm
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         document = DocumentModel(
         document = DocumentModel(
             **{
             **{
@@ -88,73 +84,69 @@ class DocumentsTable:
         )
         )
 
 
         try:
         try:
-            result = Document.create(**document.model_dump())
+            result = Document(**document.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return document
+                return DocumentModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except:
         except:
             return None
             return None
 
 
-    def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
+    def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]:
         try:
         try:
-            document = Document.get(Document.name == name)
-            return DocumentModel(**model_to_dict(document))
+            document = db.query(Document).filter_by(name=name).first()
+            return DocumentModel.model_validate(document) if document else None
         except:
         except:
             return None
             return None
 
 
-    def get_docs(self) -> List[DocumentModel]:
-        return [
-            DocumentModel(**model_to_dict(doc))
-            for doc in Document.select()
-            # .limit(limit).offset(skip)
-        ]
+    def get_docs(self, db: Session) -> List[DocumentModel]:
+        return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
 
 
     def update_doc_by_name(
     def update_doc_by_name(
-        self, name: str, form_data: DocumentUpdateForm
+        self, db: Session, name: str, form_data: DocumentUpdateForm
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         try:
         try:
-            query = Document.update(
-                title=form_data.title,
-                name=form_data.name,
-                timestamp=int(time.time()),
-            ).where(Document.name == name)
-            query.execute()
-
-            doc = Document.get(Document.name == form_data.name)
-            return DocumentModel(**model_to_dict(doc))
+            db.query(Document).filter_by(name=name).update(
+                {
+                    "title": form_data.title,
+                    "name": form_data.name,
+                    "timestamp": int(time.time()),
+                }
+            )
+            return self.get_doc_by_name(db, form_data.name)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             return None
             return None
 
 
     def update_doc_content_by_name(
     def update_doc_content_by_name(
-        self, name: str, updated: dict
+        self, db: Session, name: str, updated: dict
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         try:
         try:
-            doc = self.get_doc_by_name(name)
+            doc = self.get_doc_by_name(db, name)
             doc_content = json.loads(doc.content if doc.content else "{}")
             doc_content = json.loads(doc.content if doc.content else "{}")
             doc_content = {**doc_content, **updated}
             doc_content = {**doc_content, **updated}
 
 
-            query = Document.update(
-                content=json.dumps(doc_content),
-                timestamp=int(time.time()),
-            ).where(Document.name == name)
-            query.execute()
+            db.query(Document).filter_by(name=name).update(
+                {
+                    "content": json.dumps(doc_content),
+                    "timestamp": int(time.time()),
+                }
+            )
 
 
-            doc = Document.get(Document.name == name)
-            return DocumentModel(**model_to_dict(doc))
+            return self.get_doc_by_name(db, name)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             return None
             return None
 
 
-    def delete_doc_by_name(self, name: str) -> bool:
+    def delete_doc_by_name(self, db: Session, name: str) -> bool:
         try:
         try:
-            query = Document.delete().where((Document.name == name))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Document).filter_by(name=name).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Documents = DocumentsTable(DB)
+Documents = DocumentsTable()

+ 30 - 32
backend/apps/webui/models/files.py

@@ -1,10 +1,12 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 import time
 import time
 import logging
 import logging
-from apps.webui.internal.db import DB, JSONField
+
+from sqlalchemy import Column, String, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import JSONField, Base
 
 
 import json
 import json
 
 
@@ -18,15 +20,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 
 
 
 
-class File(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    filename = TextField()
-    meta = JSONField()
-    created_at = BigIntegerField()
+class File(Base):
+    __tablename__ = "file"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    filename = Column(String)
+    meta = Column(JSONField)
+    created_at = Column(BigInteger)
 
 
 
 
 class FileModel(BaseModel):
 class FileModel(BaseModel):
@@ -36,6 +37,7 @@ class FileModel(BaseModel):
     meta: dict
     meta: dict
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
+    model_config = ConfigDict(from_attributes=True)
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -57,11 +59,8 @@ class FileForm(BaseModel):
 
 
 
 
 class FilesTable:
 class FilesTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([File])
 
 
-    def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
+    def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         file = FileModel(
         file = FileModel(
             **{
             **{
                 **form_data.model_dump(),
                 **form_data.model_dump(),
@@ -71,42 +70,41 @@ class FilesTable:
         )
         )
 
 
         try:
         try:
-            result = File.create(**file.model_dump())
+            result = File(**file.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return file
+                return FileModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
-    def get_file_by_id(self, id: str) -> Optional[FileModel]:
+    def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]:
         try:
         try:
-            file = File.get(File.id == id)
-            return FileModel(**model_to_dict(file))
+            file = db.get(File, id)
+            return FileModel.model_validate(file)
         except:
         except:
             return None
             return None
 
 
-    def get_files(self) -> List[FileModel]:
-        return [FileModel(**model_to_dict(file)) for file in File.select()]
+    def get_files(self, db: Session) -> List[FileModel]:
+        return [FileModel.model_validate(file) for file in db.query(File).all()]
 
 
-    def delete_file_by_id(self, id: str) -> bool:
+    def delete_file_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
-            query = File.delete().where((File.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(File).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
-    def delete_all_files(self) -> bool:
+    def delete_all_files(self, db: Session) -> bool:
         try:
         try:
-            query = File.delete()
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(File).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Files = FilesTable(DB)
+Files = FilesTable()

+ 36 - 38
backend/apps/webui/models/functions.py

@@ -1,10 +1,12 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 import time
 import time
 import logging
 import logging
-from apps.webui.internal.db import DB, JSONField
+
+from sqlalchemy import Column, String, Text, BigInteger, Boolean
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import JSONField, Base
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 import json
 import json
@@ -21,20 +23,19 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 
 
 
 
-class Function(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    name = TextField()
-    type = TextField()
-    content = TextField()
-    meta = JSONField()
-    valves = JSONField()
-    is_active = BooleanField(default=False)
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+class Function(Base):
+    __tablename__ = "function"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    name = Column(Text)
+    type = Column(Text)
+    content = Column(Text)
+    meta = Column(JSONField)
+    valves = Column(JSONField)
+    is_active = Column(Boolean)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 
 
 class FunctionMeta(BaseModel):
 class FunctionMeta(BaseModel):
@@ -53,6 +54,8 @@ class FunctionModel(BaseModel):
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -82,12 +85,9 @@ class FunctionValves(BaseModel):
 
 
 
 
 class FunctionsTable:
 class FunctionsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Function])
 
 
     def insert_new_function(
     def insert_new_function(
-        self, user_id: str, type: str, form_data: FunctionForm
+        self, db: Session, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
     ) -> Optional[FunctionModel]:
         function = FunctionModel(
         function = FunctionModel(
             **{
             **{
@@ -100,19 +100,22 @@ class FunctionsTable:
         )
         )
 
 
         try:
         try:
-            result = Function.create(**function.model_dump())
+            result = Function(**function.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return function
+                return FunctionModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
-    def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
+    def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]:
         try:
         try:
-            function = Function.get(Function.id == id)
-            return FunctionModel(**model_to_dict(function))
+            function = db.get(Function, id)
+            return FunctionModel.model_validate(function)
         except:
         except:
             return None
             return None
 
 
@@ -211,14 +214,11 @@ class FunctionsTable:
 
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
         try:
-            query = Function.update(
+            db.query(Function).filter_by(id=id).update({
                 **updated,
                 **updated,
-                updated_at=int(time.time()),
-            ).where(Function.id == id)
-            query.execute()
-
-            function = Function.get(Function.id == id)
-            return FunctionModel(**model_to_dict(function))
+                "updated_at": int(time.time()),
+            })
+            return self.get_function_by_id(db, id)
         except:
         except:
             return None
             return None
 
 
@@ -235,14 +235,12 @@ class FunctionsTable:
         except:
         except:
             return None
             return None
 
 
-    def delete_function_by_id(self, id: str) -> bool:
+    def delete_function_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
-            query = Function.delete().where((Function.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Function).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Functions = FunctionsTable(DB)
+Functions = FunctionsTable()

+ 43 - 44
backend/apps/webui/models/memories.py

@@ -1,9 +1,10 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 
 
-from apps.webui.internal.db import DB
+from sqlalchemy import Column, String, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base
 from apps.webui.models.chats import Chats
 from apps.webui.models.chats import Chats
 
 
 import time
 import time
@@ -14,15 +15,14 @@ import uuid
 ####################
 ####################
 
 
 
 
-class Memory(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    content = TextField()
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+class Memory(Base):
+    __tablename__ = "memory"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    content = Column(String)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 
 
 class MemoryModel(BaseModel):
 class MemoryModel(BaseModel):
@@ -32,6 +32,8 @@ class MemoryModel(BaseModel):
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -39,12 +41,10 @@ class MemoryModel(BaseModel):
 
 
 
 
 class MemoriesTable:
 class MemoriesTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Memory])
 
 
     def insert_new_memory(
     def insert_new_memory(
         self,
         self,
+        db: Session,
         user_id: str,
         user_id: str,
         content: str,
         content: str,
     ) -> Optional[MemoryModel]:
     ) -> Optional[MemoryModel]:
@@ -59,74 +59,73 @@ class MemoriesTable:
                 "updated_at": int(time.time()),
                 "updated_at": int(time.time()),
             }
             }
         )
         )
-        result = Memory.create(**memory.model_dump())
+        result = Memory(**memory.dict())
+        db.add(result)
+        db.commit()
+        db.refresh(result)
         if result:
         if result:
-            return memory
+            return MemoryModel.model_validate(result)
         else:
         else:
             return None
             return None
 
 
     def update_memory_by_id(
     def update_memory_by_id(
         self,
         self,
+        db: Session,
         id: str,
         id: str,
         content: str,
         content: str,
     ) -> Optional[MemoryModel]:
     ) -> Optional[MemoryModel]:
         try:
         try:
-            memory = Memory.get(Memory.id == id)
-            memory.content = content
-            memory.updated_at = int(time.time())
-            memory.save()
-            return MemoryModel(**model_to_dict(memory))
+            db.query(Memory).filter_by(id=id).update(
+                {"content": content, "updated_at": int(time.time())}
+            )
+            return self.get_memory_by_id(db, id)
         except:
         except:
             return None
             return None
 
 
-    def get_memories(self) -> List[MemoryModel]:
+    def get_memories(self, db: Session) -> List[MemoryModel]:
         try:
         try:
-            memories = Memory.select()
-            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
+            memories = db.query(Memory).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
         except:
             return None
             return None
 
 
-    def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
+    def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]:
         try:
         try:
-            memories = Memory.select().where(Memory.user_id == user_id)
-            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
+            memories = db.query(Memory).filter_by(user_id=user_id).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
         except:
             return None
             return None
 
 
-    def get_memory_by_id(self, id) -> Optional[MemoryModel]:
+    def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]:
         try:
         try:
-            memory = Memory.get(Memory.id == id)
-            return MemoryModel(**model_to_dict(memory))
+            memory = db.get(Memory, id)
+            return MemoryModel.model_validate(memory)
         except:
         except:
             return None
             return None
 
 
-    def delete_memory_by_id(self, id: str) -> bool:
+    def delete_memory_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
-            query = Memory.delete().where(Memory.id == id)
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Memory).filter_by(id=id).delete()
             return True
             return True
 
 
         except:
         except:
             return False
             return False
 
 
-    def delete_memories_by_user_id(self, user_id: str) -> bool:
+    def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
         try:
-            query = Memory.delete().where(Memory.user_id == user_id)
-            query.execute()
-
+            db.query(Memory).filter_by(user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
-    def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+    def delete_memory_by_id_and_user_id(
+        self, db: Session, id: str, user_id: str
+    ) -> bool:
         try:
         try:
-            query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
-            query.execute()
-
+            db.query(Memory).filter_by(id=id, user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Memories = MemoriesTable(DB)
+Memories = MemoriesTable()

+ 38 - 41
backend/apps/webui/models/models.py

@@ -2,13 +2,11 @@ import json
 import logging
 import logging
 from typing import Optional
 from typing import Optional
 
 
-import peewee as pw
-from peewee import *
-
-from playhouse.shortcuts import model_to_dict
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import DB, JSONField
+from apps.webui.internal.db import Base, JSONField
 
 
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
@@ -46,41 +44,42 @@ class ModelMeta(BaseModel):
     pass
     pass
 
 
 
 
-class Model(pw.Model):
-    id = pw.TextField(unique=True)
+class Model(Base):
+    __tablename__ = "model"
+
+    id = Column(String, primary_key=True)
     """
     """
         The model's id as used in the API. If set to an existing model, it will override the model.
         The model's id as used in the API. If set to an existing model, it will override the model.
     """
     """
-    user_id = pw.TextField()
+    user_id = Column(String)
 
 
-    base_model_id = pw.TextField(null=True)
+    base_model_id = Column(String, nullable=True)
     """
     """
         An optional pointer to the actual model that should be used when proxying requests.
         An optional pointer to the actual model that should be used when proxying requests.
     """
     """
 
 
-    name = pw.TextField()
+    name = Column(String)
     """
     """
         The human-readable display name of the model.
         The human-readable display name of the model.
     """
     """
 
 
-    params = JSONField()
+    params = Column(JSONField)
     """
     """
         Holds a JSON encoded blob of parameters, see `ModelParams`.
         Holds a JSON encoded blob of parameters, see `ModelParams`.
     """
     """
 
 
-    meta = JSONField()
+    meta = Column(JSONField)
     """
     """
         Holds a JSON encoded blob of metadata, see `ModelMeta`.
         Holds a JSON encoded blob of metadata, see `ModelMeta`.
     """
     """
 
 
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
-
-    class Meta:
-        database = DB
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 
 
 class ModelModel(BaseModel):
 class ModelModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     id: str
     id: str
     user_id: str
     user_id: str
     base_model_id: Optional[str] = None
     base_model_id: Optional[str] = None
@@ -115,15 +114,9 @@ class ModelForm(BaseModel):
 
 
 
 
 class ModelsTable:
 class ModelsTable:
-    def __init__(
-        self,
-        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
-    ):
-        self.db = db
-        self.db.create_tables([Model])
 
 
     def insert_new_model(
     def insert_new_model(
-        self, form_data: ModelForm, user_id: str
+        self, db: Session, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
     ) -> Optional[ModelModel]:
         model = ModelModel(
         model = ModelModel(
             **{
             **{
@@ -134,46 +127,50 @@ class ModelsTable:
             }
             }
         )
         )
         try:
         try:
-            result = Model.create(**model.model_dump())
+            result = Model(**model.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
 
 
             if result:
             if result:
-                return model
+                return ModelModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
             return None
             return None
 
 
-    def get_all_models(self) -> List[ModelModel]:
-        return [ModelModel(**model_to_dict(model)) for model in Model.select()]
+    def get_all_models(self, db: Session) -> List[ModelModel]:
+        return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
 
-    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
+    def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]:
         try:
         try:
-            model = Model.get(Model.id == id)
-            return ModelModel(**model_to_dict(model))
+            model = db.get(Model, id)
+            return ModelModel.model_validate(model)
         except:
         except:
             return None
             return None
 
 
-    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
+    def update_model_by_id(
+        self, db: Session, id: str, model: ModelForm
+    ) -> Optional[ModelModel]:
         try:
         try:
             # update only the fields that are present in the model
             # update only the fields that are present in the model
-            query = Model.update(**model.model_dump()).where(Model.id == id)
-            query.execute()
-
-            model = Model.get(Model.id == id)
-            return ModelModel(**model_to_dict(model))
+            model = db.query(Model).get(id)
+            model.update(**model.model_dump())
+            db.commit()
+            db.refresh(model)
+            return ModelModel.model_validate(model)
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
 
 
             return None
             return None
 
 
-    def delete_model_by_id(self, id: str) -> bool:
+    def delete_model_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
-            query = Model.delete().where(Model.id == id)
-            query.execute()
+            db.query(Model).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Models = ModelsTable(DB)
+Models = ModelsTable()

+ 38 - 48
backend/apps/webui/models/prompts.py

@@ -1,13 +1,11 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
-from typing import List, Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 import time
 import time
 
 
-from utils.utils import decode_token
-from utils.misc import get_gravatar_url
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import DB
+from apps.webui.internal.db import Base
 
 
 import json
 import json
 
 
@@ -16,15 +14,14 @@ import json
 ####################
 ####################
 
 
 
 
-class Prompt(Model):
-    command = CharField(unique=True)
-    user_id = CharField()
-    title = TextField()
-    content = TextField()
-    timestamp = BigIntegerField()
+class Prompt(Base):
+    __tablename__ = "prompt"
 
 
-    class Meta:
-        database = DB
+    command = Column(String, primary_key=True)
+    user_id = Column(String)
+    title = Column(String)
+    content = Column(String)
+    timestamp = Column(BigInteger)
 
 
 
 
 class PromptModel(BaseModel):
 class PromptModel(BaseModel):
@@ -34,6 +31,8 @@ class PromptModel(BaseModel):
     content: str
     content: str
     timestamp: int  # timestamp in epoch
     timestamp: int  # timestamp in epoch
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -48,12 +47,8 @@ class PromptForm(BaseModel):
 
 
 class PromptsTable:
 class PromptsTable:
 
 
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Prompt])
-
     def insert_new_prompt(
     def insert_new_prompt(
-        self, user_id: str, form_data: PromptForm
+        self, db: Session, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:
         prompt = PromptModel(
         prompt = PromptModel(
             **{
             **{
@@ -66,53 +61,48 @@ class PromptsTable:
         )
         )
 
 
         try:
         try:
-            result = Prompt.create(**prompt.model_dump())
+            result = Prompt(**prompt.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return prompt
+                return PromptModel.model_validate(result)
             else:
             else:
                 return None
                 return None
-        except:
+        except Exception as e:
             return None
             return None
 
 
-    def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
+    def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]:
         try:
         try:
-            prompt = Prompt.get(Prompt.command == command)
-            return PromptModel(**model_to_dict(prompt))
+            prompt = db.query(Prompt).filter_by(command=command).first()
+            return PromptModel.model_validate(prompt)
         except:
         except:
             return None
             return None
 
 
-    def get_prompts(self) -> List[PromptModel]:
-        return [
-            PromptModel(**model_to_dict(prompt))
-            for prompt in Prompt.select()
-            # .limit(limit).offset(skip)
-        ]
+    def get_prompts(self, db: Session) -> List[PromptModel]:
+        return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
 
 
     def update_prompt_by_command(
     def update_prompt_by_command(
-        self, command: str, form_data: PromptForm
+        self, db: Session, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:
         try:
         try:
-            query = Prompt.update(
-                title=form_data.title,
-                content=form_data.content,
-                timestamp=int(time.time()),
-            ).where(Prompt.command == command)
-
-            query.execute()
-
-            prompt = Prompt.get(Prompt.command == command)
-            return PromptModel(**model_to_dict(prompt))
+            db.query(Prompt).filter_by(command=command).update(
+                {
+                    "title": form_data.title,
+                    "content": form_data.content,
+                    "timestamp": int(time.time()),
+                }
+            )
+            return self.get_prompt_by_command(db, command)
         except:
         except:
             return None
             return None
 
 
-    def delete_prompt_by_command(self, command: str) -> bool:
+    def delete_prompt_by_command(self, db: Session, command: str) -> bool:
         try:
         try:
-            query = Prompt.delete().where((Prompt.command == command))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Prompt).filter_by(command=command).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Prompts = PromptsTable(DB)
+Prompts = PromptsTable()

+ 109 - 89
backend/apps/webui/models/tags.py

@@ -1,14 +1,15 @@
-from pydantic import BaseModel
-from typing import List, Union, Optional
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 
 
 import json
 import json
 import uuid
 import uuid
 import time
 import time
 import logging
 import logging
 
 
-from apps.webui.internal.db import DB
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -20,25 +21,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 
 
 
 
-class Tag(Model):
-    id = CharField(unique=True)
-    name = CharField()
-    user_id = CharField()
-    data = TextField(null=True)
+class Tag(Base):
+    __tablename__ = "tag"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    name = Column(String)
+    user_id = Column(String)
+    data = Column(String, nullable=True)
 
 
 
 
-class ChatIdTag(Model):
-    id = CharField(unique=True)
-    tag_name = CharField()
-    chat_id = CharField()
-    user_id = CharField()
-    timestamp = BigIntegerField()
+class ChatIdTag(Base):
+    __tablename__ = "chatidtag"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    tag_name = Column(String)
+    chat_id = Column(String)
+    user_id = Column(String)
+    timestamp = Column(BigInteger)
 
 
 
 
 class TagModel(BaseModel):
 class TagModel(BaseModel):
@@ -47,6 +46,8 @@ class TagModel(BaseModel):
     user_id: str
     user_id: str
     data: Optional[str] = None
     data: Optional[str] = None
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 class ChatIdTagModel(BaseModel):
 class ChatIdTagModel(BaseModel):
     id: str
     id: str
@@ -55,6 +56,8 @@ class ChatIdTagModel(BaseModel):
     user_id: str
     user_id: str
     timestamp: int
     timestamp: int
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -75,37 +78,39 @@ class ChatTagsResponse(BaseModel):
 
 
 
 
 class TagTable:
 class TagTable:
-    def __init__(self, db):
-        self.db = db
-        db.create_tables([Tag, ChatIdTag])
 
 
-    def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
+    def insert_new_tag(
+        self, db: Session, name: str, user_id: str
+    ) -> Optional[TagModel]:
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
         try:
-            result = Tag.create(**tag.model_dump())
+            result = Tag(**tag.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return tag
+                return TagModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
     def get_tag_by_name_and_user_id(
     def get_tag_by_name_and_user_id(
-        self, name: str, user_id: str
+        self, db: Session, name: str, user_id: str
     ) -> Optional[TagModel]:
     ) -> Optional[TagModel]:
         try:
         try:
-            tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
-            return TagModel(**model_to_dict(tag))
+            tag = db.query(Tag).filter(name=name, user_id=user_id).first()
+            return TagModel.model_validate(tag)
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
     def add_tag_to_chat(
     def add_tag_to_chat(
-        self, user_id: str, form_data: ChatIdTagForm
+        self, db: Session, user_id: str, form_data: ChatIdTagForm
     ) -> Optional[ChatIdTagModel]:
     ) -> Optional[ChatIdTagModel]:
-        tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
+        tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id)
         if tag == None:
         if tag == None:
-            tag = self.insert_new_tag(form_data.tag_name, user_id)
+            tag = self.insert_new_tag(db, form_data.tag_name, user_id)
 
 
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         chatIdTag = ChatIdTagModel(
         chatIdTag = ChatIdTagModel(
@@ -118,120 +123,135 @@ class TagTable:
             }
             }
         )
         )
         try:
         try:
-            result = ChatIdTag.create(**chatIdTag.model_dump())
+            result = ChatIdTag(**chatIdTag.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return chatIdTag
+                return ChatIdTagModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except:
         except:
             return None
             return None
 
 
-    def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
+    def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]:
         tag_names = [
         tag_names = [
-            ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
-            for chat_id_tag in ChatIdTag.select()
-            .where(ChatIdTag.user_id == user_id)
-            .order_by(ChatIdTag.timestamp.desc())
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                db.query(ChatIdTag)
+                .filter_by(user_id=user_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
         ]
         ]
 
 
         return [
         return [
-            TagModel(**model_to_dict(tag))
-            for tag in Tag.select()
-            .where(Tag.user_id == user_id)
-            .where(Tag.name.in_(tag_names))
+            TagModel.model_validate(tag)
+            for tag in (
+                db.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
         ]
         ]
 
 
     def get_tags_by_chat_id_and_user_id(
     def get_tags_by_chat_id_and_user_id(
-        self, chat_id: str, user_id: str
+        self, db: Session, chat_id: str, user_id: str
     ) -> List[TagModel]:
     ) -> List[TagModel]:
         tag_names = [
         tag_names = [
-            ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
-            for chat_id_tag in ChatIdTag.select()
-            .where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
-            .order_by(ChatIdTag.timestamp.desc())
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                db.query(ChatIdTag)
+                .filter_by(user_id=user_id, chat_id=chat_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
         ]
         ]
 
 
         return [
         return [
-            TagModel(**model_to_dict(tag))
-            for tag in Tag.select()
-            .where(Tag.user_id == user_id)
-            .where(Tag.name.in_(tag_names))
+            TagModel.model_validate(tag)
+            for tag in (
+                db.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
         ]
         ]
 
 
     def get_chat_ids_by_tag_name_and_user_id(
     def get_chat_ids_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
-    ) -> Optional[ChatIdTagModel]:
+        self, db: Session, tag_name: str, user_id: str
+    ) -> List[ChatIdTagModel]:
         return [
         return [
-            ChatIdTagModel(**model_to_dict(chat_id_tag))
-            for chat_id_tag in ChatIdTag.select()
-            .where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
-            .order_by(ChatIdTag.timestamp.desc())
+            ChatIdTagModel.model_validate(chat_id_tag)
+            for chat_id_tag in (
+                db.query(ChatIdTag)
+                .filter_by(user_id=user_id, tag_name=tag_name)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
         ]
         ]
 
 
     def count_chat_ids_by_tag_name_and_user_id(
     def count_chat_ids_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
+        self, db: Session, tag_name: str, user_id: str
     ) -> int:
     ) -> int:
-        return (
-            ChatIdTag.select()
-            .where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
-            .count()
-        )
+        return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
 
 
-    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
+    def delete_tag_by_tag_name_and_user_id(
+        self, db: Session, tag_name: str, user_id: str
+    ) -> bool:
         try:
         try:
-            query = ChatIdTag.delete().where(
-                (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
+            res = (
+                db.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, user_id=user_id)
+                .delete()
             )
             )
-            res = query.execute()  # Remove the rows, return number of rows removed.
             log.debug(f"res: {res}")
             log.debug(f"res: {res}")
 
 
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                db, tag_name, user_id
+            )
             if tag_count == 0:
             if tag_count == 0:
                 # Remove tag item from Tag col as well
                 # Remove tag item from Tag col as well
-                query = Tag.delete().where(
-                    (Tag.name == tag_name) & (Tag.user_id == user_id)
-                )
-                query.execute()  # Remove the rows, return number of rows removed.
-
+                db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
             return True
             return True
         except Exception as e:
         except Exception as e:
             log.error(f"delete_tag: {e}")
             log.error(f"delete_tag: {e}")
             return False
             return False
 
 
     def delete_tag_by_tag_name_and_chat_id_and_user_id(
     def delete_tag_by_tag_name_and_chat_id_and_user_id(
-        self, tag_name: str, chat_id: str, user_id: str
+        self, db: Session, tag_name: str, chat_id: str, user_id: str
     ) -> bool:
     ) -> bool:
         try:
         try:
-            query = ChatIdTag.delete().where(
-                (ChatIdTag.tag_name == tag_name)
-                & (ChatIdTag.chat_id == chat_id)
-                & (ChatIdTag.user_id == user_id)
+            res = (
+                db.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
+                .delete()
             )
             )
-            res = query.execute()  # Remove the rows, return number of rows removed.
             log.debug(f"res: {res}")
             log.debug(f"res: {res}")
 
 
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                db, tag_name, user_id
+            )
             if tag_count == 0:
             if tag_count == 0:
                 # Remove tag item from Tag col as well
                 # Remove tag item from Tag col as well
-                query = Tag.delete().where(
-                    (Tag.name == tag_name) & (Tag.user_id == user_id)
-                )
-                query.execute()  # Remove the rows, return number of rows removed.
+                db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
 
 
             return True
             return True
         except Exception as e:
         except Exception as e:
             log.error(f"delete_tag: {e}")
             log.error(f"delete_tag: {e}")
             return False
             return False
 
 
-    def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
-        tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
+    def delete_tags_by_chat_id_and_user_id(
+        self, db: Session, chat_id: str, user_id: str
+    ) -> bool:
+        tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id)
 
 
         for tag in tags:
         for tag in tags:
             self.delete_tag_by_tag_name_and_chat_id_and_user_id(
             self.delete_tag_by_tag_name_and_chat_id_and_user_id(
-                tag.tag_name, chat_id, user_id
+                db, tag.tag_name, chat_id, user_id
             )
             )
 
 
         return True
         return True
 
 
 
 
-Tags = TagTable(DB)
+Tags = TagTable()

+ 37 - 41
backend/apps/webui/models/tools.py

@@ -1,10 +1,11 @@
-from pydantic import BaseModel
-from peewee import *
-from playhouse.shortcuts import model_to_dict
-from typing import List, Union, Optional
+from pydantic import BaseModel, ConfigDict
+from typing import List, Optional
 import time
 import time
 import logging
 import logging
-from apps.webui.internal.db import DB, JSONField
+from sqlalchemy import String, Column, BigInteger
+from sqlalchemy.orm import Session
+
+from apps.webui.internal.db import Base, JSONField
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 import json
 import json
@@ -21,19 +22,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 
 
 
 
-class Tool(Model):
-    id = CharField(unique=True)
-    user_id = CharField()
-    name = TextField()
-    content = TextField()
-    specs = JSONField()
-    meta = JSONField()
-    valves = JSONField()
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+class Tool(Base):
+    __tablename__ = "tool"
 
 
-    class Meta:
-        database = DB
+    id = Column(String, primary_key=True)
+    user_id = Column(String)
+    name = Column(String)
+    content = Column(String)
+    specs = Column(JSONField)
+    meta = Column(JSONField)
+    valves = Column(JSONField)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
 
 
 class ToolMeta(BaseModel):
 class ToolMeta(BaseModel):
@@ -51,6 +51,8 @@ class ToolModel(BaseModel):
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -78,12 +80,9 @@ class ToolValves(BaseModel):
 
 
 
 
 class ToolsTable:
 class ToolsTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([Tool])
 
 
     def insert_new_tool(
     def insert_new_tool(
-        self, user_id: str, form_data: ToolForm, specs: List[dict]
+        self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict]
     ) -> Optional[ToolModel]:
     ) -> Optional[ToolModel]:
         tool = ToolModel(
         tool = ToolModel(
             **{
             **{
@@ -96,24 +95,27 @@ class ToolsTable:
         )
         )
 
 
         try:
         try:
-            result = Tool.create(**tool.model_dump())
+            result = Tool(**tool.dict())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
             if result:
             if result:
-                return tool
+                return ToolModel.model_validate(result)
             else:
             else:
                 return None
                 return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
-    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
+    def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]:
         try:
         try:
-            tool = Tool.get(Tool.id == id)
-            return ToolModel(**model_to_dict(tool))
+            tool = db.get(Tool, id)
+            return ToolModel.model_validate(tool)
         except:
         except:
             return None
             return None
 
 
-    def get_tools(self) -> List[ToolModel]:
-        return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
+    def get_tools(self, db: Session) -> List[ToolModel]:
+        return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
 
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
         try:
@@ -180,25 +182,19 @@ class ToolsTable:
 
 
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
         try:
-            query = Tool.update(
-                **updated,
-                updated_at=int(time.time()),
-            ).where(Tool.id == id)
-            query.execute()
-
-            tool = Tool.get(Tool.id == id)
-            return ToolModel(**model_to_dict(tool))
+            db.query(Tool).filter_by(id=id).update(
+                {**updated, "updated_at": int(time.time())}
+            )
+            return self.get_tool_by_id(db, id)
         except:
         except:
             return None
             return None
 
 
-    def delete_tool_by_id(self, id: str) -> bool:
+    def delete_tool_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
-            query = Tool.delete().where((Tool.id == id))
-            query.execute()  # Remove the rows, return number of rows removed.
-
+            db.query(Tool).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
 
 
-Tools = ToolsTable(DB)
+Tools = ToolsTable()

+ 92 - 83
backend/apps/webui/models/users.py

@@ -1,11 +1,13 @@
-from pydantic import BaseModel, ConfigDict
-from peewee import *
-from playhouse.shortcuts import model_to_dict
+from pydantic import BaseModel, ConfigDict, parse_obj_as
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 import time
 import time
+
+from sqlalchemy import String, Column, BigInteger, Text
+from sqlalchemy.orm import Session
+
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 
 
-from apps.webui.internal.db import DB, JSONField
+from apps.webui.internal.db import Base, JSONField
 from apps.webui.models.chats import Chats
 from apps.webui.models.chats import Chats
 
 
 ####################
 ####################
@@ -13,25 +15,24 @@ from apps.webui.models.chats import Chats
 ####################
 ####################
 
 
 
 
-class User(Model):
-    id = CharField(unique=True)
-    name = CharField()
-    email = CharField()
-    role = CharField()
-    profile_image_url = TextField()
+class User(Base):
+    __tablename__ = "user"
 
 
-    last_active_at = BigIntegerField()
-    updated_at = BigIntegerField()
-    created_at = BigIntegerField()
+    id = Column(String, primary_key=True)
+    name = Column(String)
+    email = Column(String)
+    role = Column(String)
+    profile_image_url = Column(String)
 
 
-    api_key = CharField(null=True, unique=True)
-    settings = JSONField(null=True)
-    info = JSONField(null=True)
+    last_active_at = Column(BigInteger)
+    updated_at = Column(BigInteger)
+    created_at = Column(BigInteger)
 
 
-    oauth_sub = TextField(null=True, unique=True)
+    api_key = Column(String, nullable=True, unique=True)
+    settings = Column(JSONField, nullable=True)
+    info = Column(JSONField, nullable=True)
 
 
-    class Meta:
-        database = DB
+    oauth_sub = Column(Text, unique=True)
 
 
 
 
 class UserSettings(BaseModel):
 class UserSettings(BaseModel):
@@ -41,6 +42,8 @@ class UserSettings(BaseModel):
 
 
 
 
 class UserModel(BaseModel):
 class UserModel(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
     id: str
     id: str
     name: str
     name: str
     email: str
     email: str
@@ -76,12 +79,10 @@ class UserUpdateForm(BaseModel):
 
 
 
 
 class UsersTable:
 class UsersTable:
-    def __init__(self, db):
-        self.db = db
-        self.db.create_tables([User])
 
 
     def insert_new_user(
     def insert_new_user(
         self,
         self,
+        db: Session,
         id: str,
         id: str,
         name: str,
         name: str,
         email: str,
         email: str,
@@ -102,30 +103,33 @@ class UsersTable:
                 "oauth_sub": oauth_sub,
                 "oauth_sub": oauth_sub,
             }
             }
         )
         )
-        result = User.create(**user.model_dump())
+        result = User(**user.model_dump())
+        db.add(result)
+        db.commit()
+        db.refresh(result)
         if result:
         if result:
             return user
             return user
         else:
         else:
             return None
             return None
 
 
-    def get_user_by_id(self, id: str) -> Optional[UserModel]:
+    def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]:
         try:
         try:
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
-        except:
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except Exception as e:
             return None
             return None
 
 
-    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
+    def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]:
         try:
         try:
-            user = User.get(User.api_key == api_key)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(api_key=api_key).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
-    def get_user_by_email(self, email: str) -> Optional[UserModel]:
+    def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]:
         try:
         try:
-            user = User.get(User.email == email)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(email=email).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
@@ -136,88 +140,94 @@ class UsersTable:
         except:
         except:
             return None
             return None
 
 
-    def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
-        return [
-            UserModel(**model_to_dict(user))
-            for user in User.select()
-            # .limit(limit).offset(skip)
-        ]
+    def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]:
+        users = (
+            db.query(User)
+            # .offset(skip).limit(limit)
+            .all()
+        )
+        return [UserModel.model_validate(user) for user in users]
 
 
-    def get_num_users(self) -> Optional[int]:
-        return User.select().count()
+    def get_num_users(self, db: Session) -> Optional[int]:
+        return db.query(User).count()
 
 
-    def get_first_user(self) -> UserModel:
+    def get_first_user(self, db: Session) -> UserModel:
         try:
         try:
-            user = User.select().order_by(User.created_at).first()
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).order_by(User.created_at).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
-    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
+    def update_user_role_by_id(
+        self, db: Session, id: str, role: str
+    ) -> Optional[UserModel]:
         try:
         try:
-            query = User.update(role=role).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update({"role": role})
+            db.commit()
 
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
     def update_user_profile_image_url_by_id(
     def update_user_profile_image_url_by_id(
-        self, id: str, profile_image_url: str
+        self, db: Session, id: str, profile_image_url: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
         try:
         try:
-            query = User.update(profile_image_url=profile_image_url).where(
-                User.id == id
+            db.query(User).filter_by(id=id).update(
+                {"profile_image_url": profile_image_url}
             )
             )
-            query.execute()
+            db.commit()
 
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
-    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
+    def update_user_last_active_by_id(
+        self, db: Session, id: str
+    ) -> Optional[UserModel]:
         try:
         try:
-            query = User.update(last_active_at=int(time.time())).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
 
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
     def update_user_oauth_sub_by_id(
     def update_user_oauth_sub_by_id(
-        self, id: str, oauth_sub: str
+        self, db: Session, id: str, oauth_sub: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
         try:
         try:
-            query = User.update(oauth_sub=oauth_sub).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
 
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
         except:
         except:
             return None
             return None
 
 
-    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
+    def update_user_by_id(
+        self, db: Session, id: str, updated: dict
+    ) -> Optional[UserModel]:
         try:
         try:
-            query = User.update(**updated).where(User.id == id)
-            query.execute()
+            db.query(User).filter_by(id=id).update(updated)
+            db.commit()
 
 
-            user = User.get(User.id == id)
-            return UserModel(**model_to_dict(user))
-        except:
+            user = db.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+            # return UserModel(**user.dict())
+        except Exception as e:
             return None
             return None
 
 
-    def delete_user_by_id(self, id: str) -> bool:
+    def delete_user_by_id(self, db: Session, id: str) -> bool:
         try:
         try:
             # Delete User Chats
             # Delete User Chats
-            result = Chats.delete_chats_by_user_id(id)
+            result = Chats.delete_chats_by_user_id(db, id)
 
 
             if result:
             if result:
                 # Delete User
                 # Delete User
-                query = User.delete().where(User.id == id)
-                query.execute()  # Remove the rows, return number of rows removed.
+                db.query(User).filter_by(id=id).delete()
+                db.commit()
 
 
                 return True
                 return True
             else:
             else:
@@ -225,21 +235,20 @@ class UsersTable:
         except:
         except:
             return False
             return False
 
 
-    def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
+    def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str:
         try:
         try:
-            query = User.update(api_key=api_key).where(User.id == id)
-            result = query.execute()
-
+            result = db.query(User).filter_by(id=id).update({"api_key": api_key})
+            db.commit()
             return True if result == 1 else False
             return True if result == 1 else False
         except:
         except:
             return False
             return False
 
 
-    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
+    def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]:
         try:
         try:
-            user = User.get(User.id == id)
+            user = db.query(User).filter_by(id=id).first()
             return user.api_key
             return user.api_key
-        except:
+        except Exception as e:
             return None
             return None
 
 
 
 
-Users = UsersTable(DB)
+Users = UsersTable()

+ 40 - 26
backend/apps/webui/routers/auths.py

@@ -10,6 +10,7 @@ import re
 import uuid
 import uuid
 import csv
 import csv
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.auths import (
 from apps.webui.models.auths import (
     SigninForm,
     SigninForm,
     SignupForm,
     SignupForm,
@@ -78,10 +79,13 @@ async def get_session_user(
 
 
 @router.post("/update/profile", response_model=UserResponse)
 @router.post("/update/profile", response_model=UserResponse)
 async def update_profile(
 async def update_profile(
-    form_data: UpdateProfileForm, session_user=Depends(get_current_user)
+    form_data: UpdateProfileForm,
+    session_user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
 ):
     if session_user:
     if session_user:
         user = Users.update_user_by_id(
         user = Users.update_user_by_id(
+            db,
             session_user.id,
             session_user.id,
             {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
             {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
         )
         )
@@ -100,16 +104,18 @@ async def update_profile(
 
 
 @router.post("/update/password", response_model=bool)
 @router.post("/update/password", response_model=bool)
 async def update_password(
 async def update_password(
-    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
+    form_data: UpdatePasswordForm,
+    session_user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
 ):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
     if session_user:
     if session_user:
-        user = Auths.authenticate_user(session_user.email, form_data.password)
+        user = Auths.authenticate_user(db, session_user.email, form_data.password)
 
 
         if user:
         if user:
             hashed = get_password_hash(form_data.new_password)
             hashed = get_password_hash(form_data.new_password)
-            return Auths.update_user_password_by_id(user.id, hashed)
+            return Auths.update_user_password_by_id(db, user.id, hashed)
         else:
         else:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
     else:
     else:
@@ -122,7 +128,7 @@ async def update_password(
 
 
 
 
 @router.post("/signin", response_model=SigninResponse)
 @router.post("/signin", response_model=SigninResponse)
-async def signin(request: Request, response: Response, form_data: SigninForm):
+async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
@@ -133,32 +139,34 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
             trusted_name = request.headers.get(
             trusted_name = request.headers.get(
                 WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
                 WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
             )
             )
-        if not Users.get_user_by_email(trusted_email.lower()):
+        if not Users.get_user_by_email(db, trusted_email.lower()):
             await signup(
             await signup(
                 request,
                 request,
                 SignupForm(
                 SignupForm(
                     email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
                     email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
                 ),
                 ),
+                db,
             )
             )
-        user = Auths.authenticate_user_by_trusted_header(trusted_email)
+        user = Auths.authenticate_user_by_trusted_header(db, trusted_email)
     elif WEBUI_AUTH == False:
     elif WEBUI_AUTH == False:
         admin_email = "admin@localhost"
         admin_email = "admin@localhost"
         admin_password = "admin"
         admin_password = "admin"
 
 
-        if Users.get_user_by_email(admin_email.lower()):
-            user = Auths.authenticate_user(admin_email.lower(), admin_password)
+        if Users.get_user_by_email(db, admin_email.lower()):
+            user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
         else:
         else:
-            if Users.get_num_users() != 0:
+            if Users.get_num_users(db) != 0:
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
 
 
             await signup(
             await signup(
                 request,
                 request,
                 SignupForm(email=admin_email, password=admin_password, name="User"),
                 SignupForm(email=admin_email, password=admin_password, name="User"),
+                db,
             )
             )
 
 
-            user = Auths.authenticate_user(admin_email.lower(), admin_password)
+            user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
     else:
     else:
-        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
+        user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password)
 
 
     if user:
     if user:
         token = create_token(
         token = create_token(
@@ -192,7 +200,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
 
 
 
 
 @router.post("/signup", response_model=SigninResponse)
 @router.post("/signup", response_model=SigninResponse)
-async def signup(request: Request, response: Response, form_data: SignupForm):
+async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)):
     if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
     if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
         raise HTTPException(
         raise HTTPException(
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
@@ -203,17 +211,18 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
         )
         )
 
 
-    if Users.get_user_by_email(form_data.email.lower()):
+    if Users.get_user_by_email(db, form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
 
     try:
     try:
         role = (
         role = (
             "admin"
             "admin"
-            if Users.get_num_users() == 0
+            if Users.get_num_users(db) == 0
             else request.app.state.config.DEFAULT_USER_ROLE
             else request.app.state.config.DEFAULT_USER_ROLE
         )
         )
         hashed = get_password_hash(form_data.password)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
         user = Auths.insert_new_auth(
+            db,
             form_data.email.lower(),
             form_data.email.lower(),
             hashed,
             hashed,
             form_data.name,
             form_data.name,
@@ -267,14 +276,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
 
 
 @router.post("/add", response_model=SigninResponse)
 @router.post("/add", response_model=SigninResponse)
-async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
+async def add_user(
+    form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
 
 
     if not validate_email_format(form_data.email.lower()):
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
         )
         )
 
 
-    if Users.get_user_by_email(form_data.email.lower()):
+    if Users.get_user_by_email(db, form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
 
     try:
     try:
@@ -282,6 +293,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         print(form_data)
         print(form_data)
         hashed = get_password_hash(form_data.password)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
         user = Auths.insert_new_auth(
+            db,
             form_data.email.lower(),
             form_data.email.lower(),
             hashed,
             hashed,
             form_data.name,
             form_data.name,
@@ -312,7 +324,9 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
 
 
 
 
 @router.get("/admin/details")
 @router.get("/admin/details")
-async def get_admin_details(request: Request, user=Depends(get_current_user)):
+async def get_admin_details(
+    request: Request, user=Depends(get_current_user), db=Depends(get_db)
+):
     if request.app.state.config.SHOW_ADMIN_DETAILS:
     if request.app.state.config.SHOW_ADMIN_DETAILS:
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_name = None
         admin_name = None
@@ -320,11 +334,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
         print(admin_email, admin_name)
         print(admin_email, admin_name)
 
 
         if admin_email:
         if admin_email:
-            admin = Users.get_user_by_email(admin_email)
+            admin = Users.get_user_by_email(db, admin_email)
             if admin:
             if admin:
                 admin_name = admin.name
                 admin_name = admin.name
         else:
         else:
-            admin = Users.get_first_user()
+            admin = Users.get_first_user(db)
             if admin:
             if admin:
                 admin_email = admin.email
                 admin_email = admin.email
                 admin_name = admin.name
                 admin_name = admin.name
@@ -397,9 +411,9 @@ async def update_admin_config(
 
 
 # create api key
 # create api key
 @router.post("/api_key", response_model=ApiKey)
 @router.post("/api_key", response_model=ApiKey)
-async def create_api_key_(user=Depends(get_current_user)):
+async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
     api_key = create_api_key()
     api_key = create_api_key()
-    success = Users.update_user_api_key_by_id(user.id, api_key)
+    success = Users.update_user_api_key_by_id(db, user.id, api_key)
     if success:
     if success:
         return {
         return {
             "api_key": api_key,
             "api_key": api_key,
@@ -410,15 +424,15 @@ async def create_api_key_(user=Depends(get_current_user)):
 
 
 # delete api key
 # delete api key
 @router.delete("/api_key", response_model=bool)
 @router.delete("/api_key", response_model=bool)
-async def delete_api_key(user=Depends(get_current_user)):
-    success = Users.update_user_api_key_by_id(user.id, None)
+async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)):
+    success = Users.update_user_api_key_by_id(db, user.id, None)
     return success
     return success
 
 
 
 
 # get api key
 # get api key
 @router.get("/api_key", response_model=ApiKey)
 @router.get("/api_key", response_model=ApiKey)
-async def get_api_key(user=Depends(get_current_user)):
-    api_key = Users.get_user_api_key_by_id(user.id)
+async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)):
+    api_key = Users.get_user_api_key_by_id(db, user.id)
     if api_key:
     if api_key:
         return {
         return {
             "api_key": api_key,
             "api_key": api_key,

+ 88 - 58
backend/apps/webui/routers/chats.py

@@ -1,6 +1,8 @@
 from fastapi import Depends, Request, HTTPException, status
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
+
+from apps.webui.internal.db import get_db
 from utils.utils import get_current_user, get_admin_user
 from utils.utils import get_current_user, get_admin_user
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -43,9 +45,9 @@ router = APIRouter()
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/list", response_model=List[ChatTitleIdResponse])
 @router.get("/list", response_model=List[ChatTitleIdResponse])
 async def get_session_user_chat_list(
 async def get_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
 ):
 ):
-    return Chats.get_chat_list_by_user_id(user.id, skip, limit)
+    return Chats.get_chat_list_by_user_id(db, user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -54,7 +56,9 @@ async def get_session_user_chat_list(
 
 
 
 
 @router.delete("/", response_model=bool)
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
+async def delete_all_user_chats(
+    request: Request, user=Depends(get_current_user), db=Depends(get_db)
+):
 
 
     if (
     if (
         user.role == "user"
         user.role == "user"
@@ -65,7 +69,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
 
 
-    result = Chats.delete_chats_by_user_id(user.id)
+    result = Chats.delete_chats_by_user_id(db, user.id)
     return result
     return result
 
 
 
 
@@ -76,10 +80,14 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
 
 
 @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
 @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
 async def get_user_chat_list_by_user_id(
 async def get_user_chat_list_by_user_id(
-    user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
+    user_id: str,
+    user=Depends(get_admin_user),
+    skip: int = 0,
+    limit: int = 50,
+    db=Depends(get_db),
 ):
 ):
     return Chats.get_chat_list_by_user_id(
     return Chats.get_chat_list_by_user_id(
-        user_id, include_archived=True, skip=skip, limit=limit
+        db, user_id, include_archived=True, skip=skip, limit=limit
     )
     )
 
 
 
 
@@ -89,9 +97,11 @@ async def get_user_chat_list_by_user_id(
 
 
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+async def create_new_chat(
+    form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
+):
     try:
     try:
-        chat = Chats.insert_new_chat(user.id, form_data)
+        chat = Chats.insert_new_chat(db, user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
@@ -106,10 +116,10 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
 
 
 
 
 @router.get("/all", response_model=List[ChatResponse])
 @router.get("/all", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
     return [
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats_by_user_id(user.id)
+        for chat in Chats.get_chats_by_user_id(db, user.id)
     ]
     ]
 
 
 
 
@@ -119,10 +129,10 @@ async def get_user_chats(user=Depends(get_current_user)):
 
 
 
 
 @router.get("/all/archived", response_model=List[ChatResponse])
 @router.get("/all/archived", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)):
     return [
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_archived_chats_by_user_id(user.id)
+        for chat in Chats.get_archived_chats_by_user_id(db, user.id)
     ]
     ]
 
 
 
 
@@ -132,7 +142,7 @@ async def get_user_chats(user=Depends(get_current_user)):
 
 
 
 
 @router.get("/all/db", response_model=List[ChatResponse])
 @router.get("/all/db", response_model=List[ChatResponse])
-async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
+async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)):
     if not ENABLE_ADMIN_EXPORT:
     if not ENABLE_ADMIN_EXPORT:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -140,7 +150,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
         )
         )
     return [
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats()
+        for chat in Chats.get_chats(db)
     ]
     ]
 
 
 
 
@@ -151,9 +161,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
 
 
 @router.get("/archived", response_model=List[ChatTitleIdResponse])
 @router.get("/archived", response_model=List[ChatTitleIdResponse])
 async def get_archived_session_user_chat_list(
 async def get_archived_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
 ):
 ):
-    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
+    return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -162,8 +172,8 @@ async def get_archived_session_user_chat_list(
 
 
 
 
 @router.post("/archive/all", response_model=bool)
 @router.post("/archive/all", response_model=bool)
-async def archive_all_chats(user=Depends(get_current_user)):
-    return Chats.archive_all_chats_by_user_id(user.id)
+async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
+    return Chats.archive_all_chats_by_user_id(db, user.id)
 
 
 
 
 ############################
 ############################
@@ -172,16 +182,18 @@ async def archive_all_chats(user=Depends(get_current_user)):
 
 
 
 
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
+async def get_shared_chat_by_id(
+    share_id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
     if user.role == "pending":
     if user.role == "pending":
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
         )
         )
 
 
     if user.role == "user":
     if user.role == "user":
-        chat = Chats.get_chat_by_share_id(share_id)
+        chat = Chats.get_chat_by_share_id(db, share_id)
     elif user.role == "admin":
     elif user.role == "admin":
-        chat = Chats.get_chat_by_id(share_id)
+        chat = Chats.get_chat_by_id(db, share_id)
 
 
     if chat:
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -204,21 +216,23 @@ class TagNameForm(BaseModel):
 
 
 @router.post("/tags", response_model=List[ChatTitleIdResponse])
 @router.post("/tags", response_model=List[ChatTitleIdResponse])
 async def get_user_chat_list_by_tag_name(
 async def get_user_chat_list_by_tag_name(
-    form_data: TagNameForm, user=Depends(get_current_user)
+    form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db)
 ):
 ):
 
 
     print(form_data)
     print(form_data)
     chat_ids = [
     chat_ids = [
         chat_id_tag.chat_id
         chat_id_tag.chat_id
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
-            form_data.name, user.id
+            db, form_data.name, user.id
         )
         )
     ]
     ]
 
 
-    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
+    chats = Chats.get_chat_list_by_chat_ids(
+        db, chat_ids, form_data.skip, form_data.limit
+    )
 
 
     if len(chats) == 0:
     if len(chats) == 0:
-        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
+        Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id)
 
 
     return chats
     return chats
 
 
@@ -229,9 +243,9 @@ async def get_user_chat_list_by_tag_name(
 
 
 
 
 @router.get("/tags/all", response_model=List[TagModel])
 @router.get("/tags/all", response_model=List[TagModel])
-async def get_all_tags(user=Depends(get_current_user)):
+async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
     try:
     try:
-        tags = Tags.get_tags_by_user_id(user.id)
+        tags = Tags.get_tags_by_user_id(db, user.id)
         return tags
         return tags
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
@@ -246,8 +260,8 @@ async def get_all_tags(user=Depends(get_current_user)):
 
 
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
 
 
     if chat:
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -264,13 +278,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 @router.post("/{id}", response_model=Optional[ChatResponse])
 async def update_chat_by_id(
 async def update_chat_by_id(
-    id: str, form_data: ChatForm, user=Depends(get_current_user)
+    id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
 ):
 ):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
     if chat:
         updated_chat = {**json.loads(chat.chat), **form_data.chat}
         updated_chat = {**json.loads(chat.chat), **form_data.chat}
 
 
-        chat = Chats.update_chat_by_id(id, updated_chat)
+        chat = Chats.update_chat_by_id(db, id, updated_chat)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -285,10 +299,12 @@ async def update_chat_by_id(
 
 
 
 
 @router.delete("/{id}", response_model=bool)
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
+async def delete_chat_by_id(
+    request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
 
 
     if user.role == "admin":
     if user.role == "admin":
-        result = Chats.delete_chat_by_id(id)
+        result = Chats.delete_chat_by_id(db, id)
         return result
         return result
     else:
     else:
         if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
         if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
@@ -297,7 +313,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             )
             )
 
 
-        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+        result = Chats.delete_chat_by_id_and_user_id(db, id, user.id)
         return result
         return result
 
 
 
 
@@ -307,8 +323,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
 
 
 
 
 @router.get("/{id}/clone", response_model=Optional[ChatResponse])
 @router.get("/{id}/clone", response_model=Optional[ChatResponse])
-async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
     if chat:
 
 
         chat_body = json.loads(chat.chat)
         chat_body = json.loads(chat.chat)
@@ -319,7 +335,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
             "title": f"Clone of {chat.title}",
             "title": f"Clone of {chat.title}",
         }
         }
 
 
-        chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
+        chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat}))
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -333,10 +349,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 
 
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
-async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def archive_chat_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
     if chat:
-        chat = Chats.toggle_chat_archive_by_id(id)
+        chat = Chats.toggle_chat_archive_by_id(db, id)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -350,16 +368,16 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 
 
 @router.post("/{id}/share", response_model=Optional[ChatResponse])
 @router.post("/{id}/share", response_model=Optional[ChatResponse])
-async def share_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
     if chat:
         if chat.share_id:
         if chat.share_id:
-            shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
+            shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id)
             return ChatResponse(
             return ChatResponse(
                 **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
                 **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
             )
             )
 
 
-        shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
+        shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id)
         if not shared_chat:
         if not shared_chat:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -382,14 +400,16 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 
 
 @router.delete("/{id}/share", response_model=Optional[bool])
 @router.delete("/{id}/share", response_model=Optional[bool])
-async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
-    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+async def delete_shared_chat_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
     if chat:
     if chat:
         if not chat.share_id:
         if not chat.share_id:
             return False
             return False
 
 
-        result = Chats.delete_shared_chat_by_chat_id(id)
-        update_result = Chats.update_chat_share_id_by_id(id, None)
+        result = Chats.delete_shared_chat_by_chat_id(db, id)
+        update_result = Chats.update_chat_share_id_by_id(db, id, None)
 
 
         return result and update_result != None
         return result and update_result != None
     else:
     else:
@@ -405,8 +425,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 
 
 @router.get("/{id}/tags", response_model=List[TagModel])
 @router.get("/{id}/tags", response_model=List[TagModel])
-async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
+async def get_chat_tags_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
 
 
     if tags != None:
     if tags != None:
         return tags
         return tags
@@ -423,12 +445,15 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 async def add_chat_tag_by_id(
 async def add_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+    id: str,
+    form_data: ChatIdTagForm,
+    user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
 ):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
+    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
 
 
     if form_data.tag_name not in tags:
     if form_data.tag_name not in tags:
-        tag = Tags.add_tag_to_chat(user.id, form_data)
+        tag = Tags.add_tag_to_chat(db, user.id, form_data)
 
 
         if tag:
         if tag:
             return tag
             return tag
@@ -450,10 +475,13 @@ async def add_chat_tag_by_id(
 
 
 @router.delete("/{id}/tags", response_model=Optional[bool])
 @router.delete("/{id}/tags", response_model=Optional[bool])
 async def delete_chat_tag_by_id(
 async def delete_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+    id: str,
+    form_data: ChatIdTagForm,
+    user=Depends(get_current_user),
+    db=Depends(get_db),
 ):
 ):
     result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
     result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
-        form_data.tag_name, id, user.id
+        db, form_data.tag_name, id, user.id
     )
     )
 
 
     if result:
     if result:
@@ -470,8 +498,10 @@ async def delete_chat_tag_by_id(
 
 
 
 
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
-async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
-    result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
+async def delete_all_chat_tags_by_id(
+    id: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id)
 
 
     if result:
     if result:
         return result
         return result

+ 27 - 13
backend/apps/webui/routers/documents.py

@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.documents import (
 from apps.webui.models.documents import (
     Documents,
     Documents,
     DocumentForm,
     DocumentForm,
@@ -25,7 +26,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[DocumentResponse])
 @router.get("/", response_model=List[DocumentResponse])
-async def get_documents(user=Depends(get_current_user)):
+async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
     docs = [
     docs = [
         DocumentResponse(
         DocumentResponse(
             **{
             **{
@@ -33,7 +34,7 @@ async def get_documents(user=Depends(get_current_user)):
                 "content": json.loads(doc.content if doc.content else "{}"),
                 "content": json.loads(doc.content if doc.content else "{}"),
             }
             }
         )
         )
-        for doc in Documents.get_docs()
+        for doc in Documents.get_docs(db)
     ]
     ]
     return docs
     return docs
 
 
@@ -44,10 +45,12 @@ async def get_documents(user=Depends(get_current_user)):
 
 
 
 
 @router.post("/create", response_model=Optional[DocumentResponse])
 @router.post("/create", response_model=Optional[DocumentResponse])
-async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
-    doc = Documents.get_doc_by_name(form_data.name)
+async def create_new_doc(
+    form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    doc = Documents.get_doc_by_name(db, form_data.name)
     if doc == None:
     if doc == None:
-        doc = Documents.insert_new_doc(user.id, form_data)
+        doc = Documents.insert_new_doc(db, user.id, form_data)
 
 
         if doc:
         if doc:
             return DocumentResponse(
             return DocumentResponse(
@@ -74,8 +77,10 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
 
 
 
 
 @router.get("/doc", response_model=Optional[DocumentResponse])
 @router.get("/doc", response_model=Optional[DocumentResponse])
-async def get_doc_by_name(name: str, user=Depends(get_current_user)):
-    doc = Documents.get_doc_by_name(name)
+async def get_doc_by_name(
+    name: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    doc = Documents.get_doc_by_name(db, name)
 
 
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
@@ -106,8 +111,12 @@ class TagDocumentForm(BaseModel):
 
 
 
 
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
-async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
-    doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
+async def tag_doc_by_name(
+    form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db)
+):
+    doc = Documents.update_doc_content_by_name(
+        db, form_data.name, {"tags": form_data.tags}
+    )
 
 
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
@@ -130,9 +139,12 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
 
 
 @router.post("/doc/update", response_model=Optional[DocumentResponse])
 @router.post("/doc/update", response_model=Optional[DocumentResponse])
 async def update_doc_by_name(
 async def update_doc_by_name(
-    name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
+    name: str,
+    form_data: DocumentUpdateForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    doc = Documents.update_doc_by_name(name, form_data)
+    doc = Documents.update_doc_by_name(db, name, form_data)
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
             **{
             **{
@@ -153,6 +165,8 @@ async def update_doc_by_name(
 
 
 
 
 @router.delete("/doc/delete", response_model=bool)
 @router.delete("/doc/delete", response_model=bool)
-async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
-    result = Documents.delete_doc_by_name(name)
+async def delete_doc_by_name(
+    name: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    result = Documents.delete_doc_by_name(db, name)
     return result
     return result

+ 14 - 11
backend/apps/webui/routers/files.py

@@ -20,6 +20,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.files import (
 from apps.webui.models.files import (
     Files,
     Files,
     FileForm,
     FileForm,
@@ -53,6 +54,7 @@ router = APIRouter()
 def upload_file(
 def upload_file(
     file: UploadFile = File(...),
     file: UploadFile = File(...),
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
+    db=Depends(get_db)
 ):
 ):
     log.info(f"file.content_type: {file.content_type}")
     log.info(f"file.content_type: {file.content_type}")
     try:
     try:
@@ -70,6 +72,7 @@ def upload_file(
             f.close()
             f.close()
 
 
         file = Files.insert_new_file(
         file = Files.insert_new_file(
+            db,
             user.id,
             user.id,
             FileForm(
             FileForm(
                 **{
                 **{
@@ -106,8 +109,8 @@ def upload_file(
 
 
 
 
 @router.get("/", response_model=List[FileModel])
 @router.get("/", response_model=List[FileModel])
-async def list_files(user=Depends(get_verified_user)):
-    files = Files.get_files()
+async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
+    files = Files.get_files(db)
     return files
     return files
 
 
 
 
@@ -117,8 +120,8 @@ async def list_files(user=Depends(get_verified_user)):
 
 
 
 
 @router.delete("/all")
 @router.delete("/all")
-async def delete_all_files(user=Depends(get_admin_user)):
-    result = Files.delete_all_files()
+async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
+    result = Files.delete_all_files(db)
 
 
     if result:
     if result:
         folder = f"{UPLOAD_DIR}"
         folder = f"{UPLOAD_DIR}"
@@ -154,8 +157,8 @@ async def delete_all_files(user=Depends(get_admin_user)):
 
 
 
 
 @router.get("/{id}", response_model=Optional[FileModel])
 @router.get("/{id}", response_model=Optional[FileModel])
-async def get_file_by_id(id: str, user=Depends(get_verified_user)):
-    file = Files.get_file_by_id(id)
+async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    file = Files.get_file_by_id(db, id)
 
 
     if file:
     if file:
         return file
         return file
@@ -172,8 +175,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
 
 
 
 
 @router.get("/{id}/content", response_model=Optional[FileModel])
 @router.get("/{id}/content", response_model=Optional[FileModel])
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
-    file = Files.get_file_by_id(id)
+async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    file = Files.get_file_by_id(db, id)
 
 
     if file:
     if file:
         file_path = Path(file.meta["path"])
         file_path = Path(file.meta["path"])
@@ -223,11 +226,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
 
 
 @router.delete("/{id}")
 @router.delete("/{id}")
-async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
-    file = Files.get_file_by_id(id)
+async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    file = Files.get_file_by_id(db, id)
 
 
     if file:
     if file:
-        result = Files.delete_file_by_id(id)
+        result = Files.delete_file_by_id(db, id)
         if result:
         if result:
             return {"message": "File deleted successfully"}
             return {"message": "File deleted successfully"}
         else:
         else:

+ 14 - 13
backend/apps/webui/routers/functions.py

@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.functions import (
 from apps.webui.models.functions import (
     Functions,
     Functions,
     FunctionForm,
     FunctionForm,
@@ -31,8 +32,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[FunctionResponse])
 @router.get("/", response_model=List[FunctionResponse])
-async def get_functions(user=Depends(get_verified_user)):
-    return Functions.get_functions()
+async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
+    return Functions.get_functions(db)
 
 
 
 
 ############################
 ############################
@@ -41,8 +42,8 @@ async def get_functions(user=Depends(get_verified_user)):
 
 
 
 
 @router.get("/export", response_model=List[FunctionModel])
 @router.get("/export", response_model=List[FunctionModel])
-async def get_functions(user=Depends(get_admin_user)):
-    return Functions.get_functions()
+async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
+    return Functions.get_functions(db)
 
 
 
 
 ############################
 ############################
@@ -52,7 +53,7 @@ async def get_functions(user=Depends(get_admin_user)):
 
 
 @router.post("/create", response_model=Optional[FunctionResponse])
 @router.post("/create", response_model=Optional[FunctionResponse])
 async def create_new_function(
 async def create_new_function(
-    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
+    request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
 ):
 ):
     if not form_data.id.isidentifier():
     if not form_data.id.isidentifier():
         raise HTTPException(
         raise HTTPException(
@@ -62,7 +63,7 @@ async def create_new_function(
 
 
     form_data.id = form_data.id.lower()
     form_data.id = form_data.id.lower()
 
 
-    function = Functions.get_function_by_id(form_data.id)
+    function = Functions.get_function_by_id(db, form_data.id)
     if function == None:
     if function == None:
         function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
         function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
         try:
         try:
@@ -77,7 +78,7 @@ async def create_new_function(
             FUNCTIONS = request.app.state.FUNCTIONS
             FUNCTIONS = request.app.state.FUNCTIONS
             FUNCTIONS[form_data.id] = function_module
             FUNCTIONS[form_data.id] = function_module
 
 
-            function = Functions.insert_new_function(user.id, function_type, form_data)
+            function = Functions.insert_new_function(db, user.id, function_type, form_data)
 
 
             function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
             function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
             function_cache_dir.mkdir(parents=True, exist_ok=True)
             function_cache_dir.mkdir(parents=True, exist_ok=True)
@@ -108,8 +109,8 @@ async def create_new_function(
 
 
 
 
 @router.get("/id/{id}", response_model=Optional[FunctionModel])
 @router.get("/id/{id}", response_model=Optional[FunctionModel])
-async def get_function_by_id(id: str, user=Depends(get_admin_user)):
-    function = Functions.get_function_by_id(id)
+async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
+    function = Functions.get_function_by_id(db, id)
 
 
     if function:
     if function:
         return function
         return function
@@ -154,7 +155,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
 
 
 @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
 @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
 async def update_function_by_id(
 async def update_function_by_id(
-    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
+    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
 ):
 ):
     function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
     function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
 
 
@@ -171,7 +172,7 @@ async def update_function_by_id(
         updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
         updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
         print(updated)
         print(updated)
 
 
-        function = Functions.update_function_by_id(id, updated)
+        function = Functions.update_function_by_id(db, id, updated)
 
 
         if function:
         if function:
             return function
             return function
@@ -195,9 +196,9 @@ async def update_function_by_id(
 
 
 @router.delete("/id/{id}/delete", response_model=bool)
 @router.delete("/id/{id}/delete", response_model=bool)
 async def delete_function_by_id(
 async def delete_function_by_id(
-    request: Request, id: str, user=Depends(get_admin_user)
+    request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
 ):
 ):
-    result = Functions.delete_function_by_id(id)
+    result = Functions.delete_function_by_id(db, id)
 
 
     if result:
     if result:
         FUNCTIONS = request.app.state.FUNCTIONS
         FUNCTIONS = request.app.state.FUNCTIONS

+ 18 - 11
backend/apps/webui/routers/memories.py

@@ -7,6 +7,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import logging
 import logging
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.memories import Memories, MemoryModel
 from apps.webui.models.memories import Memories, MemoryModel
 
 
 from utils.utils import get_verified_user
 from utils.utils import get_verified_user
@@ -31,8 +32,8 @@ async def get_embeddings(request: Request):
 
 
 
 
 @router.get("/", response_model=List[MemoryModel])
 @router.get("/", response_model=List[MemoryModel])
-async def get_memories(user=Depends(get_verified_user)):
-    return Memories.get_memories_by_user_id(user.id)
+async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)):
+    return Memories.get_memories_by_user_id(db, user.id)
 
 
 
 
 ############################
 ############################
@@ -50,9 +51,12 @@ class MemoryUpdateModel(BaseModel):
 
 
 @router.post("/add", response_model=Optional[MemoryModel])
 @router.post("/add", response_model=Optional[MemoryModel])
 async def add_memory(
 async def add_memory(
-    request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
+    request: Request,
+    form_data: AddMemoryForm,
+    user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
 ):
-    memory = Memories.insert_new_memory(user.id, form_data.content)
+    memory = Memories.insert_new_memory(db, user.id, form_data.content)
     memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
     memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
 
 
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
@@ -72,8 +76,9 @@ async def update_memory_by_id(
     request: Request,
     request: Request,
     form_data: MemoryUpdateModel,
     form_data: MemoryUpdateModel,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
+    db=Depends(get_db),
 ):
 ):
-    memory = Memories.update_memory_by_id(memory_id, form_data.content)
+    memory = Memories.update_memory_by_id(db, memory_id, form_data.content)
     if memory is None:
     if memory is None:
         raise HTTPException(status_code=404, detail="Memory not found")
         raise HTTPException(status_code=404, detail="Memory not found")
 
 
@@ -124,12 +129,12 @@ async def query_memory(
 ############################
 ############################
 @router.get("/reset", response_model=bool)
 @router.get("/reset", response_model=bool)
 async def reset_memory_from_vector_db(
 async def reset_memory_from_vector_db(
-    request: Request, user=Depends(get_verified_user)
+    request: Request, user=Depends(get_verified_user), db=Depends(get_db)
 ):
 ):
     CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
     CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
 
 
-    memories = Memories.get_memories_by_user_id(user.id)
+    memories = Memories.get_memories_by_user_id(db, user.id)
     for memory in memories:
     for memory in memories:
         memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
         memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
         collection.upsert(
         collection.upsert(
@@ -146,8 +151,8 @@ async def reset_memory_from_vector_db(
 
 
 
 
 @router.delete("/user", response_model=bool)
 @router.delete("/user", response_model=bool)
-async def delete_memory_by_user_id(user=Depends(get_verified_user)):
-    result = Memories.delete_memories_by_user_id(user.id)
+async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)):
+    result = Memories.delete_memories_by_user_id(db, user.id)
 
 
     if result:
     if result:
         try:
         try:
@@ -165,8 +170,10 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
 
 
 
 
 @router.delete("/{memory_id}", response_model=bool)
 @router.delete("/{memory_id}", response_model=bool)
-async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
-    result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
+async def delete_memory_by_id(
+    memory_id: str, user=Depends(get_verified_user), db=Depends(get_db)
+):
+    result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id)
 
 
     if result:
     if result:
         collection = CHROMA_CLIENT.get_or_create_collection(
         collection = CHROMA_CLIENT.get_or_create_collection(

+ 22 - 13
backend/apps/webui/routers/models.py

@@ -5,6 +5,8 @@ from typing import List, Union, Optional
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
+
+from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
 from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
 
 
 from utils.utils import get_verified_user, get_admin_user
 from utils.utils import get_verified_user, get_admin_user
@@ -18,8 +20,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ModelResponse])
 @router.get("/", response_model=List[ModelResponse])
-async def get_models(user=Depends(get_verified_user)):
-    return Models.get_all_models()
+async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
+    return Models.get_all_models(db)
 
 
 
 
 ############################
 ############################
@@ -29,7 +31,10 @@ async def get_models(user=Depends(get_verified_user)):
 
 
 @router.post("/add", response_model=Optional[ModelModel])
 @router.post("/add", response_model=Optional[ModelModel])
 async def add_new_model(
 async def add_new_model(
-    request: Request, form_data: ModelForm, user=Depends(get_admin_user)
+    request: Request,
+    form_data: ModelForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
     if form_data.id in request.app.state.MODELS:
     if form_data.id in request.app.state.MODELS:
         raise HTTPException(
         raise HTTPException(
@@ -37,7 +42,7 @@ async def add_new_model(
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
         )
         )
     else:
     else:
-        model = Models.insert_new_model(form_data, user.id)
+        model = Models.insert_new_model(db, form_data, user.id)
 
 
         if model:
         if model:
             return model
             return model
@@ -53,9 +58,9 @@ async def add_new_model(
 ############################
 ############################
 
 
 
 
-@router.get("/", response_model=Optional[ModelModel])
-async def get_model_by_id(id: str, user=Depends(get_verified_user)):
-    model = Models.get_model_by_id(id)
+@router.get("/{id}", response_model=Optional[ModelModel])
+async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
+    model = Models.get_model_by_id(db, id)
 
 
     if model:
     if model:
         return model
         return model
@@ -73,15 +78,19 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
 
 
 @router.post("/update", response_model=Optional[ModelModel])
 @router.post("/update", response_model=Optional[ModelModel])
 async def update_model_by_id(
 async def update_model_by_id(
-    request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
+    request: Request,
+    id: str,
+    form_data: ModelForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    model = Models.get_model_by_id(id)
+    model = Models.get_model_by_id(db, id)
     if model:
     if model:
-        model = Models.update_model_by_id(id, form_data)
+        model = Models.update_model_by_id(db, id, form_data)
         return model
         return model
     else:
     else:
         if form_data.id in request.app.state.MODELS:
         if form_data.id in request.app.state.MODELS:
-            model = Models.insert_new_model(form_data, user.id)
+            model = Models.insert_new_model(db, form_data, user.id)
             if model:
             if model:
                 return model
                 return model
             else:
             else:
@@ -102,6 +111,6 @@ async def update_model_by_id(
 
 
 
 
 @router.delete("/delete", response_model=bool)
 @router.delete("/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
-    result = Models.delete_model_by_id(id)
+async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
+    result = Models.delete_model_by_id(db, id)
     return result
     return result

+ 21 - 11
backend/apps/webui/routers/prompts.py

@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
 from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
 
 
 from utils.utils import get_current_user, get_admin_user
 from utils.utils import get_current_user, get_admin_user
@@ -19,8 +20,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[PromptModel])
 @router.get("/", response_model=List[PromptModel])
-async def get_prompts(user=Depends(get_current_user)):
-    return Prompts.get_prompts()
+async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
+    return Prompts.get_prompts(db)
 
 
 
 
 ############################
 ############################
@@ -29,10 +30,12 @@ async def get_prompts(user=Depends(get_current_user)):
 
 
 
 
 @router.post("/create", response_model=Optional[PromptModel])
 @router.post("/create", response_model=Optional[PromptModel])
-async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
-    prompt = Prompts.get_prompt_by_command(form_data.command)
+async def create_new_prompt(
+    form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    prompt = Prompts.get_prompt_by_command(db, form_data.command)
     if prompt == None:
     if prompt == None:
-        prompt = Prompts.insert_new_prompt(user.id, form_data)
+        prompt = Prompts.insert_new_prompt(db, user.id, form_data)
 
 
         if prompt:
         if prompt:
             return prompt
             return prompt
@@ -52,8 +55,10 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
 
 
 
 
 @router.get("/command/{command}", response_model=Optional[PromptModel])
 @router.get("/command/{command}", response_model=Optional[PromptModel])
-async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
-    prompt = Prompts.get_prompt_by_command(f"/{command}")
+async def get_prompt_by_command(
+    command: str, user=Depends(get_current_user), db=Depends(get_db)
+):
+    prompt = Prompts.get_prompt_by_command(db, f"/{command}")
 
 
     if prompt:
     if prompt:
         return prompt
         return prompt
@@ -71,9 +76,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
 
 
 @router.post("/command/{command}/update", response_model=Optional[PromptModel])
 @router.post("/command/{command}/update", response_model=Optional[PromptModel])
 async def update_prompt_by_command(
 async def update_prompt_by_command(
-    command: str, form_data: PromptForm, user=Depends(get_admin_user)
+    command: str,
+    form_data: PromptForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
+    prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data)
     if prompt:
     if prompt:
         return prompt
         return prompt
     else:
     else:
@@ -89,6 +97,8 @@ async def update_prompt_by_command(
 
 
 
 
 @router.delete("/command/{command}/delete", response_model=bool)
 @router.delete("/command/{command}/delete", response_model=bool)
-async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
-    result = Prompts.delete_prompt_by_command(f"/{command}")
+async def delete_prompt_by_command(
+    command: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    result = Prompts.delete_prompt_by_command(db, f"/{command}")
     return result
     return result

+ 22 - 13
backend/apps/webui/routers/tools.py

@@ -6,7 +6,7 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-
+from apps.webui.internal.db import get_db
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.utils import load_toolkit_module_by_id
 from apps.webui.utils import load_toolkit_module_by_id
@@ -34,7 +34,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ToolResponse])
 @router.get("/", response_model=List[ToolResponse])
-async def get_toolkits(user=Depends(get_verified_user)):
+async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
     toolkits = [toolkit for toolkit in Tools.get_tools()]
     toolkits = [toolkit for toolkit in Tools.get_tools()]
     return toolkits
     return toolkits
 
 
@@ -45,8 +45,8 @@ async def get_toolkits(user=Depends(get_verified_user)):
 
 
 
 
 @router.get("/export", response_model=List[ToolModel])
 @router.get("/export", response_model=List[ToolModel])
-async def get_toolkits(user=Depends(get_admin_user)):
-    toolkits = [toolkit for toolkit in Tools.get_tools()]
+async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)):
+    toolkits = [toolkit for toolkit in Tools.get_tools(db)]
     return toolkits
     return toolkits
 
 
 
 
@@ -57,7 +57,10 @@ async def get_toolkits(user=Depends(get_admin_user)):
 
 
 @router.post("/create", response_model=Optional[ToolResponse])
 @router.post("/create", response_model=Optional[ToolResponse])
 async def create_new_toolkit(
 async def create_new_toolkit(
-    request: Request, form_data: ToolForm, user=Depends(get_admin_user)
+    request: Request,
+    form_data: ToolForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
     if not form_data.id.isidentifier():
     if not form_data.id.isidentifier():
         raise HTTPException(
         raise HTTPException(
@@ -67,7 +70,7 @@ async def create_new_toolkit(
 
 
     form_data.id = form_data.id.lower()
     form_data.id = form_data.id.lower()
 
 
-    toolkit = Tools.get_tool_by_id(form_data.id)
+    toolkit = Tools.get_tool_by_id(db, form_data.id)
     if toolkit == None:
     if toolkit == None:
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         try:
         try:
@@ -81,7 +84,7 @@ async def create_new_toolkit(
             TOOLS[form_data.id] = toolkit_module
             TOOLS[form_data.id] = toolkit_module
 
 
             specs = get_tools_specs(TOOLS[form_data.id])
             specs = get_tools_specs(TOOLS[form_data.id])
-            toolkit = Tools.insert_new_tool(user.id, form_data, specs)
+            toolkit = Tools.insert_new_tool(db, user.id, form_data, specs)
 
 
             tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
             tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
             tool_cache_dir.mkdir(parents=True, exist_ok=True)
             tool_cache_dir.mkdir(parents=True, exist_ok=True)
@@ -112,8 +115,8 @@ async def create_new_toolkit(
 
 
 
 
 @router.get("/id/{id}", response_model=Optional[ToolModel])
 @router.get("/id/{id}", response_model=Optional[ToolModel])
-async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
-    toolkit = Tools.get_tool_by_id(id)
+async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
+    toolkit = Tools.get_tool_by_id(db, id)
 
 
     if toolkit:
     if toolkit:
         return toolkit
         return toolkit
@@ -131,7 +134,11 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
 
 
 @router.post("/id/{id}/update", response_model=Optional[ToolModel])
 @router.post("/id/{id}/update", response_model=Optional[ToolModel])
 async def update_toolkit_by_id(
 async def update_toolkit_by_id(
-    request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
+    request: Request,
+    id: str,
+    form_data: ToolForm,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
     toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
     toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
 
 
@@ -153,7 +160,7 @@ async def update_toolkit_by_id(
         }
         }
 
 
         print(updated)
         print(updated)
-        toolkit = Tools.update_tool_by_id(id, updated)
+        toolkit = Tools.update_tool_by_id(db, id, updated)
 
 
         if toolkit:
         if toolkit:
             return toolkit
             return toolkit
@@ -176,8 +183,10 @@ async def update_toolkit_by_id(
 
 
 
 
 @router.delete("/id/{id}/delete", response_model=bool)
 @router.delete("/id/{id}/delete", response_model=bool)
-async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
-    result = Tools.delete_tool_by_id(id)
+async def delete_toolkit_by_id(
+    request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    result = Tools.delete_tool_by_id(db, id)
 
 
     if result:
     if result:
         TOOLS = request.app.state.TOOLS
         TOOLS = request.app.state.TOOLS

+ 44 - 25
backend/apps/webui/routers/users.py

@@ -9,6 +9,7 @@ import time
 import uuid
 import uuid
 import logging
 import logging
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.users import (
 from apps.webui.models.users import (
     UserModel,
     UserModel,
     UserUpdateForm,
     UserUpdateForm,
@@ -40,8 +41,10 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[UserModel])
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
-    return Users.get_users(skip, limit)
+async def get_users(
+    skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
+):
+    return Users.get_users(db, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -68,10 +71,12 @@ async def update_user_permissions(
 
 
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
+):
 
 
-    if user.id != form_data.id and form_data.id != Users.get_first_user().id:
-        return Users.update_user_role_by_id(form_data.id, form_data.role)
+    if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
+        return Users.update_user_role_by_id(db, form_data.id, form_data.role)
 
 
     raise HTTPException(
     raise HTTPException(
         status_code=status.HTTP_403_FORBIDDEN,
         status_code=status.HTTP_403_FORBIDDEN,
@@ -85,8 +90,10 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
 
 
 
 
 @router.get("/user/settings", response_model=Optional[UserSettings])
 @router.get("/user/settings", response_model=Optional[UserSettings])
-async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
-    user = Users.get_user_by_id(user.id)
+async def get_user_settings_by_session_user(
+    user=Depends(get_verified_user), db=Depends(get_db)
+):
+    user = Users.get_user_by_id(db, user.id)
     if user:
     if user:
         return user.settings
         return user.settings
     else:
     else:
@@ -103,9 +110,9 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
 
 
 @router.post("/user/settings/update", response_model=UserSettings)
 @router.post("/user/settings/update", response_model=UserSettings)
 async def update_user_settings_by_session_user(
 async def update_user_settings_by_session_user(
-    form_data: UserSettings, user=Depends(get_verified_user)
+    form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
 ):
 ):
-    user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
+    user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
     if user:
     if user:
         return user.settings
         return user.settings
     else:
     else:
@@ -121,8 +128,10 @@ async def update_user_settings_by_session_user(
 
 
 
 
 @router.get("/user/info", response_model=Optional[dict])
 @router.get("/user/info", response_model=Optional[dict])
-async def get_user_info_by_session_user(user=Depends(get_verified_user)):
-    user = Users.get_user_by_id(user.id)
+async def get_user_info_by_session_user(
+    user=Depends(get_verified_user), db=Depends(get_db)
+):
+    user = Users.get_user_by_id(db, user.id)
     if user:
     if user:
         return user.info
         return user.info
     else:
     else:
@@ -138,15 +147,17 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
 
 
 
 
 @router.post("/user/info/update", response_model=Optional[dict])
 @router.post("/user/info/update", response_model=Optional[dict])
-async def update_user_settings_by_session_user(
-    form_data: dict, user=Depends(get_verified_user)
+async def update_user_info_by_session_user(
+    form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
 ):
 ):
-    user = Users.get_user_by_id(user.id)
+    user = Users.get_user_by_id(db, user.id)
     if user:
     if user:
         if user.info is None:
         if user.info is None:
             user.info = {}
             user.info = {}
 
 
-        user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
+        user = Users.update_user_by_id(
+            db, user.id, {"info": {**user.info, **form_data}}
+        )
         if user:
         if user:
             return user.info
             return user.info
         else:
         else:
@@ -172,13 +183,15 @@ class UserResponse(BaseModel):
 
 
 
 
 @router.get("/{user_id}", response_model=UserResponse)
 @router.get("/{user_id}", response_model=UserResponse)
-async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
+async def get_user_by_id(
+    user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
+):
 
 
     # Check if user_id is a shared chat
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
     # If it is, get the user_id from the chat
     if user_id.startswith("shared-"):
     if user_id.startswith("shared-"):
         chat_id = user_id.replace("shared-", "")
         chat_id = user_id.replace("shared-", "")
-        chat = Chats.get_chat_by_id(chat_id)
+        chat = Chats.get_chat_by_id(db, chat_id)
         if chat:
         if chat:
             user_id = chat.user_id
             user_id = chat.user_id
         else:
         else:
@@ -187,7 +200,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
                 detail=ERROR_MESSAGES.USER_NOT_FOUND,
                 detail=ERROR_MESSAGES.USER_NOT_FOUND,
             )
             )
 
 
-    user = Users.get_user_by_id(user_id)
+    user = Users.get_user_by_id(db, user_id)
 
 
     if user:
     if user:
         return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
         return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
@@ -205,13 +218,16 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
 
 
 @router.post("/{user_id}/update", response_model=Optional[UserModel])
 @router.post("/{user_id}/update", response_model=Optional[UserModel])
 async def update_user_by_id(
 async def update_user_by_id(
-    user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
+    user_id: str,
+    form_data: UserUpdateForm,
+    session_user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    user = Users.get_user_by_id(user_id)
+    user = Users.get_user_by_id(db, user_id)
 
 
     if user:
     if user:
         if form_data.email.lower() != user.email:
         if form_data.email.lower() != user.email:
-            email_user = Users.get_user_by_email(form_data.email.lower())
+            email_user = Users.get_user_by_email(db, form_data.email.lower())
             if email_user:
             if email_user:
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_400_BAD_REQUEST,
                     status_code=status.HTTP_400_BAD_REQUEST,
@@ -221,10 +237,11 @@ async def update_user_by_id(
         if form_data.password:
         if form_data.password:
             hashed = get_password_hash(form_data.password)
             hashed = get_password_hash(form_data.password)
             log.debug(f"hashed: {hashed}")
             log.debug(f"hashed: {hashed}")
-            Auths.update_user_password_by_id(user_id, hashed)
+            Auths.update_user_password_by_id(db, user_id, hashed)
 
 
-        Auths.update_email_by_id(user_id, form_data.email.lower())
+        Auths.update_email_by_id(db, user_id, form_data.email.lower())
         updated_user = Users.update_user_by_id(
         updated_user = Users.update_user_by_id(
+            db,
             user_id,
             user_id,
             {
             {
                 "name": form_data.name,
                 "name": form_data.name,
@@ -253,9 +270,11 @@ async def update_user_by_id(
 
 
 
 
 @router.delete("/{user_id}", response_model=bool)
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+async def delete_user_by_id(
+    user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
+):
     if user.id != user_id:
     if user.id != user_id:
-        result = Auths.delete_auth_by_id(user_id)
+        result = Auths.delete_auth_by_id(db, user_id)
 
 
         if result:
         if result:
             return True
             return True

+ 4 - 4
backend/apps/webui/routers/utils.py

@@ -1,6 +1,5 @@
 from fastapi import APIRouter, UploadFile, File, Response
 from fastapi import APIRouter, UploadFile, File, Response
 from fastapi import Depends, HTTPException, status
 from fastapi import Depends, HTTPException, status
-from peewee import SqliteDatabase
 from starlette.responses import StreamingResponse, FileResponse
 from starlette.responses import StreamingResponse, FileResponse
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
@@ -10,7 +9,6 @@ import markdown
 import black
 import black
 
 
 
 
-from apps.webui.internal.db import DB
 from utils.utils import get_admin_user
 from utils.utils import get_admin_user
 from utils.misc import calculate_sha256, get_gravatar_url
 from utils.misc import calculate_sha256, get_gravatar_url
 
 
@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
-    if not isinstance(DB, SqliteDatabase):
+    from apps.webui.internal.db import engine
+
+    if engine.name != "sqlite":
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DB_NOT_SQLITE,
             detail=ERROR_MESSAGES.DB_NOT_SQLITE,
         )
         )
     return FileResponse(
     return FileResponse(
-        DB.database,
+        engine.url.database,
         media_type="application/octet-stream",
         media_type="application/octet-stream",
         filename="webui.db",
         filename="webui.db",
     )
     )

+ 45 - 11
backend/main.py

@@ -1,5 +1,6 @@
 import base64
 import base64
 import uuid
 import uuid
+import subprocess
 from contextlib import asynccontextmanager
 from contextlib import asynccontextmanager
 
 
 from authlib.integrations.starlette_client import OAuth
 from authlib.integrations.starlette_client import OAuth
@@ -27,6 +28,8 @@ from fastapi.responses import JSONResponse
 from fastapi import HTTPException
 from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
+from sqlalchemy import text
+from sqlalchemy.orm import Session
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.sessions import SessionMiddleware
@@ -54,6 +57,7 @@ from apps.webui.main import (
     get_pipe_models,
     get_pipe_models,
     generate_function_chat_completion,
     generate_function_chat_completion,
 )
 )
+from apps.webui.internal.db import get_db, SessionLocal
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -124,6 +128,8 @@ from config import (
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
     WEBUI_SESSION_COOKIE_SECURE,
     AppConfig,
     AppConfig,
+    BACKEND_DIR,
+    DATABASE_URL,
 )
 )
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from utils.webhook import post_webhook
 from utils.webhook import post_webhook
@@ -166,8 +172,19 @@ https://github.com/open-webui/open-webui
 )
 )
 
 
 
 
+def run_migrations():
+    from alembic.config import Config
+    from alembic import command
+
+    alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini")
+    alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL)
+    alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations")
+    command.upgrade(alembic_cfg, "head")
+
+
 @asynccontextmanager
 @asynccontextmanager
 async def lifespan(app: FastAPI):
 async def lifespan(app: FastAPI):
+    run_migrations()
     yield
     yield
 
 
 
 
@@ -393,6 +410,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             user = get_current_user(
             user = get_current_user(
                 request,
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
                 get_http_authorization_cred(request.headers.get("Authorization")),
+                SessionLocal(),
             )
             )
             # Flag to skip RAG completions if file_handler is present in tools/functions
             # Flag to skip RAG completions if file_handler is present in tools/functions
             skip_files = False
             skip_files = False
@@ -736,6 +754,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
             user = get_current_user(
             user = get_current_user(
                 request,
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
                 get_http_authorization_cred(request.headers.get("Authorization")),
+                SessionLocal(),
             )
             )
 
 
             try:
             try:
@@ -781,7 +800,9 @@ app.add_middleware(
 @app.middleware("http")
 @app.middleware("http")
 async def check_url(request: Request, call_next):
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
     if len(app.state.MODELS) == 0:
-        await get_all_models()
+        db = SessionLocal()
+        await get_all_models(db)
+        db.commit()
     else:
     else:
         pass
         pass
 
 
@@ -815,12 +836,12 @@ app.mount("/api/v1", webui_app)
 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 
 
-async def get_all_models():
+async def get_all_models(db: Session):
     pipe_models = []
     pipe_models = []
     openai_models = []
     openai_models = []
     ollama_models = []
     ollama_models = []
 
 
-    pipe_models = await get_pipe_models()
+    pipe_models = await get_pipe_models(db)
 
 
     if app.state.config.ENABLE_OPENAI_API:
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
         openai_models = await get_openai_models()
@@ -842,7 +863,7 @@ async def get_all_models():
 
 
     models = pipe_models + openai_models + ollama_models
     models = pipe_models + openai_models + ollama_models
 
 
-    custom_models = Models.get_all_models()
+    custom_models = Models.get_all_models(db)
     for custom_model in custom_models:
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
         if custom_model.base_model_id == None:
             for model in models:
             for model in models:
@@ -882,8 +903,8 @@ async def get_all_models():
 
 
 
 
 @app.get("/api/models")
 @app.get("/api/models")
-async def get_models(user=Depends(get_verified_user)):
-    models = await get_all_models()
+async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
+    models = await get_all_models(db)
 
 
     # Filter out filter pipelines
     # Filter out filter pipelines
     models = [
     models = [
@@ -1584,9 +1605,12 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
 
 
 @app.get("/api/pipelines/{pipeline_id}/valves")
 @app.get("/api/pipelines/{pipeline_id}/valves")
 async def get_pipeline_valves(
 async def get_pipeline_valves(
-    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    models = await get_all_models()
+    models = await get_all_models(db)
     r = None
     r = None
     try:
     try:
 
 
@@ -1622,9 +1646,12 @@ async def get_pipeline_valves(
 
 
 @app.get("/api/pipelines/{pipeline_id}/valves/spec")
 @app.get("/api/pipelines/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
 async def get_pipeline_valves_spec(
-    urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
+    urlIdx: Optional[int],
+    pipeline_id: str,
+    user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    models = await get_all_models()
+    models = await get_all_models(db)
 
 
     r = None
     r = None
     try:
     try:
@@ -1663,8 +1690,9 @@ async def update_pipeline_valves(
     pipeline_id: str,
     pipeline_id: str,
     form_data: dict,
     form_data: dict,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
+    db=Depends(get_db),
 ):
 ):
-    models = await get_all_models()
+    models = await get_all_models(db)
 
 
     r = None
     r = None
     try:
     try:
@@ -2011,6 +2039,12 @@ async def healthcheck():
     return {"status": True}
     return {"status": True}
 
 
 
 
+@app.get("/health/db")
+async def healthcheck_with_db(db: Session = Depends(get_db)):
+    result = db.execute(text("SELECT 1;")).all()
+    return {"status": True}
+
+
 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
 app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
 app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
 
 

+ 4 - 0
backend/migrations/README

@@ -0,0 +1,4 @@
+Generic single-database configuration.
+
+Create new migrations with
+DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"

+ 93 - 0
backend/migrations/env.py

@@ -0,0 +1,93 @@
+import os
+from logging.config import fileConfig
+
+from sqlalchemy import engine_from_config
+from sqlalchemy import pool
+
+from alembic import context
+
+from apps.webui.models.auths import Auth
+from apps.webui.models.chats import Chat
+from apps.webui.models.documents import Document
+from apps.webui.models.memories import Memory
+from apps.webui.models.models import Model
+from apps.webui.models.prompts import Prompt
+from apps.webui.models.tags import Tag, ChatIdTag
+from apps.webui.models.tools import Tool
+from apps.webui.models.users import User
+from apps.webui.models.files import File
+from apps.webui.models.functions import Function
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+
+# add your model's MetaData object here
+# for 'autogenerate' support
+# from myapp import mymodel
+# target_metadata = mymodel.Base.metadata
+target_metadata = Auth.metadata
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+database_url = os.getenv("DATABASE_URL", None)
+if database_url:
+    config.set_main_option("sqlalchemy.url", database_url)
+
+
+def run_migrations_offline() -> None:
+    """Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well.  By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+def run_migrations_online() -> None:
+    """Run migrations in 'online' mode.
+
+    In this scenario we need to create an Engine
+    and associate a connection with the context.
+
+    """
+    connectable = engine_from_config(
+        config.get_section(config.config_ini_section, {}),
+        prefix="sqlalchemy.",
+        poolclass=pool.NullPool,
+    )
+
+    with connectable.connect() as connection:
+        context.configure(connection=connection, target_metadata=target_metadata)
+
+        with context.begin_transaction():
+            context.run_migrations()
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()

+ 27 - 0
backend/migrations/script.py.mako

@@ -0,0 +1,27 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import apps.webui.internal.db
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision: str = ${repr(up_revision)}
+down_revision: Union[str, None] = ${repr(down_revision)}
+branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
+depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    ${downgrades if downgrades else "pass"}

+ 188 - 0
backend/migrations/versions/22b5ab2667b8_init.py

@@ -0,0 +1,188 @@
+"""init
+
+Revision ID: 22b5ab2667b8
+Revises: 
+Create Date: 2024-06-20 13:22:40.397002
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.engine.reflection import Inspector
+
+import apps.webui.internal.db
+
+
+# revision identifiers, used by Alembic.
+revision: str = "22b5ab2667b8"
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    con = op.get_bind()
+    inspector = Inspector.from_engine(con)
+    tables = set(inspector.get_table_names())
+
+    # ### commands auto generated by Alembic - please adjust! ###
+    if not "auth" in tables:
+        op.create_table(
+            "auth",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("email", sa.String(), nullable=True),
+            sa.Column("password", sa.String(), nullable=True),
+            sa.Column("active", sa.Boolean(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "chat" in tables:
+        op.create_table(
+            "chat",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("title", sa.String(), nullable=True),
+            sa.Column("chat", sa.String(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("share_id", sa.String(), nullable=True),
+            sa.Column("archived", sa.Boolean(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+            sa.UniqueConstraint("share_id"),
+        )
+
+    if not "chatidtag" in tables:
+        op.create_table(
+            "chatidtag",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("tag_name", sa.String(), nullable=True),
+            sa.Column("chat_id", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "document" in tables:
+        op.create_table(
+            "document",
+            sa.Column("collection_name", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("title", sa.String(), nullable=True),
+            sa.Column("filename", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("collection_name"),
+            sa.UniqueConstraint("name"),
+        )
+
+    if not "memory" in tables:
+        op.create_table(
+            "memory",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "model" in tables:
+        op.create_table(
+            "model",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("base_model_id", sa.String(), nullable=True),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "prompt" in tables:
+        op.create_table(
+            "prompt",
+            sa.Column("command", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("title", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("command"),
+        )
+
+    if not "tag" in tables:
+        op.create_table(
+            "tag",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("data", sa.String(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "tool" in tables:
+        op.create_table(
+            "tool",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("content", sa.String(), nullable=True),
+            sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+        )
+
+    if not "user" in tables:
+        op.create_table(
+            "user",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("email", sa.String(), nullable=True),
+            sa.Column("role", sa.String(), nullable=True),
+            sa.Column("profile_image_url", sa.String(), nullable=True),
+            sa.Column("last_active_at", sa.BigInteger(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.Column("api_key", sa.String(), nullable=True),
+            sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+            sa.UniqueConstraint("api_key"),
+        )
+
+    if not "file" in tables:
+        op.create_table('file',
+                        sa.Column('id', sa.String(), nullable=False),
+                        sa.Column('user_id', sa.String(), nullable=True),
+                        sa.Column('filename', sa.String(), nullable=True),
+                        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+                        sa.Column('created_at', sa.BigInteger(), nullable=True),
+                        sa.PrimaryKeyConstraint('id')
+                        )
+
+    if not "function" in tables:
+        op.create_table('function',
+                        sa.Column('id', sa.String(), nullable=False),
+                        sa.Column('user_id', sa.String(), nullable=True),
+                        sa.Column('name', sa.Text(), nullable=True),
+                        sa.Column('type', sa.Text(), nullable=True),
+                        sa.Column('content', sa.Text(), nullable=True),
+                        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+                        sa.Column('updated_at', sa.BigInteger(), nullable=True),
+                        sa.Column('created_at', sa.BigInteger(), nullable=True),
+                        sa.PrimaryKeyConstraint('id')
+                        )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # do nothing as we assume we had previous migrations from peewee-migrate
+    pass
+    # ### end Alembic commands ###

+ 10 - 3
backend/requirements.txt

@@ -12,8 +12,10 @@ passlib[bcrypt]==1.7.4
 
 
 requests==2.32.2
 requests==2.32.2
 aiohttp==3.9.5
 aiohttp==3.9.5
-peewee==3.17.5
-peewee-migrate==1.12.2
+sqlalchemy==2.0.30
+alembic==1.13.1
+# peewee==3.17.5
+# peewee-migrate==1.12.2
 psycopg2-binary==2.9.9
 psycopg2-binary==2.9.9
 PyMySQL==1.1.1
 PyMySQL==1.1.1
 bcrypt==4.1.3
 bcrypt==4.1.3
@@ -67,4 +69,9 @@ pytube==15.0.0
 
 
 extract_msg
 extract_msg
 pydub
 pydub
-duckduckgo-search~=6.1.5
+duckduckgo-search~=6.1.5
+
+## Tests
+docker~=7.1.0
+pytest~=8.2.1
+pytest-docker~=3.1.1

+ 0 - 0
backend/test/__init__.py


+ 209 - 0
backend/test/apps/webui/routers/test_auths.py

@@ -0,0 +1,209 @@
+import pytest
+
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestAuths(AbstractPostgresTest):
+    BASE_PATH = "/api/v1/auths"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.users import Users
+        from apps.webui.models.auths import Auths
+
+        cls.users = Users
+        cls.auths = Auths
+
+    def test_get_session_user(self):
+        with mock_webui_user():
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert response.json() == {
+            "id": "1",
+            "name": "John Doe",
+            "email": "john.doe@openwebui.com",
+            "role": "user",
+            "profile_image_url": "/user.png",
+        }
+
+    def test_update_profile(self):
+        from utils.utils import get_password_hash
+
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password=get_password_hash("old_password"),
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="user",
+        )
+
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.post(
+                self.create_url("/update/profile"),
+                json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
+            )
+        assert response.status_code == 200
+        db_user = self.users.get_user_by_id(self.db_session, user.id)
+        assert db_user.name == "John Doe 2"
+        assert db_user.profile_image_url == "/user2.png"
+
+    def test_update_password(self):
+        from utils.utils import get_password_hash
+
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password=get_password_hash("old_password"),
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="user",
+        )
+
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.post(
+                self.create_url("/update/password"),
+                json={"password": "old_password", "new_password": "new_password"},
+            )
+        assert response.status_code == 200
+
+        old_auth = self.auths.authenticate_user(
+            self.db_session, "john.doe@openwebui.com", "old_password"
+        )
+        assert old_auth is None
+        new_auth = self.auths.authenticate_user(
+            self.db_session, "john.doe@openwebui.com", "new_password"
+        )
+        assert new_auth is not None
+
+    def test_signin(self):
+        from utils.utils import get_password_hash
+
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password=get_password_hash("password"),
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="user",
+        )
+        response = self.fast_api_client.post(
+            self.create_url("/signin"),
+            json={"email": "john.doe@openwebui.com", "password": "password"},
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == user.id
+        assert data["name"] == "John Doe"
+        assert data["email"] == "john.doe@openwebui.com"
+        assert data["role"] == "user"
+        assert data["profile_image_url"] == "/user.png"
+        assert data["token"] is not None and len(data["token"]) > 0
+        assert data["token_type"] == "Bearer"
+
+    def test_signup(self):
+        response = self.fast_api_client.post(
+            self.create_url("/signup"),
+            json={
+                "name": "John Doe",
+                "email": "john.doe@openwebui.com",
+                "password": "password",
+            },
+        )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] is not None and len(data["id"]) > 0
+        assert data["name"] == "John Doe"
+        assert data["email"] == "john.doe@openwebui.com"
+        assert data["role"] in ["admin", "user", "pending"]
+        assert data["profile_image_url"] == "/user.png"
+        assert data["token"] is not None and len(data["token"]) > 0
+        assert data["token_type"] == "Bearer"
+
+    def test_add_user(self):
+        with mock_webui_user():
+            response = self.fast_api_client.post(
+                self.create_url("/add"),
+                json={
+                    "name": "John Doe 2",
+                    "email": "john.doe2@openwebui.com",
+                    "password": "password2",
+                    "role": "admin",
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] is not None and len(data["id"]) > 0
+        assert data["name"] == "John Doe 2"
+        assert data["email"] == "john.doe2@openwebui.com"
+        assert data["role"] == "admin"
+        assert data["profile_image_url"] == "/user.png"
+        assert data["token"] is not None and len(data["token"]) > 0
+        assert data["token_type"] == "Bearer"
+
+    def test_get_admin_details(self):
+        self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        with mock_webui_user():
+            response = self.fast_api_client.get(self.create_url("/admin/details"))
+
+        assert response.status_code == 200
+        assert response.json() == {
+            "name": "John Doe",
+            "email": "john.doe@openwebui.com",
+        }
+
+    def test_create_api_key_(self):
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.post(self.create_url("/api_key"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["api_key"] is not None
+        assert len(data["api_key"]) > 0
+
+    def test_delete_api_key(self):
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.delete(self.create_url("/api_key"))
+        assert response.status_code == 200
+        assert response.json() == True
+        db_user = self.users.get_user_by_id(self.db_session, user.id)
+        assert db_user.api_key is None
+
+    def test_get_api_key(self):
+        user = self.auths.insert_new_auth(
+            self.db_session,
+            email="john.doe@openwebui.com",
+            password="password",
+            name="John Doe",
+            profile_image_url="/user.png",
+            role="admin",
+        )
+        self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
+        with mock_webui_user(id=user.id):
+            response = self.fast_api_client.get(self.create_url("/api_key"))
+        assert response.status_code == 200
+        assert response.json() == {"api_key": "abc"}

+ 239 - 0
backend/test/apps/webui/routers/test_chats.py

@@ -0,0 +1,239 @@
+import uuid
+
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestChats(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/chats"
+
+    def setup_class(cls):
+        super().setup_class()
+
+    def setup_method(self):
+        super().setup_method()
+        from apps.webui.models.chats import ChatForm
+        from apps.webui.models.chats import Chats
+
+        self.chats = Chats
+        self.chats.insert_new_chat(
+            self.db_session,
+            "2",
+            ChatForm(
+                **{
+                    "chat": {
+                        "name": "chat1",
+                        "description": "chat1 description",
+                        "tags": ["tag1", "tag2"],
+                        "history": {"currentId": "1", "messages": []},
+                    }
+                }
+            ),
+        )
+
+    def test_get_session_user_chat_list(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        first_chat = response.json()[0]
+        assert first_chat["id"] is not None
+        assert first_chat["title"] == "New Chat"
+        assert first_chat["created_at"] is not None
+        assert first_chat["updated_at"] is not None
+
+    def test_delete_all_user_chats(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(self.chats.get_chats(self.db_session)) == 0
+
+    def test_get_user_chat_list_by_user_id(self):
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url("/list/user/2"))
+        assert response.status_code == 200
+        first_chat = response.json()[0]
+        assert first_chat["id"] is not None
+        assert first_chat["title"] == "New Chat"
+        assert first_chat["created_at"] is not None
+        assert first_chat["updated_at"] is not None
+
+    def test_create_new_chat(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/new"),
+                json={
+                    "chat": {
+                        "name": "chat2",
+                        "description": "chat2 description",
+                        "tags": ["tag1", "tag2"],
+                    }
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["archived"] is False
+        assert data["chat"] == {
+            "name": "chat2",
+            "description": "chat2 description",
+            "tags": ["tag1", "tag2"],
+        }
+        assert data["user_id"] == "2"
+        assert data["id"] is not None
+        assert data["share_id"] is None
+        assert data["title"] == "New Chat"
+        assert data["updated_at"] is not None
+        assert data["created_at"] is not None
+        assert len(self.chats.get_chats(self.db_session)) == 2
+
+    def test_get_user_chats(self):
+        self.test_get_session_user_chat_list()
+
+    def test_get_user_archived_chats(self):
+        self.chats.archive_all_chats_by_user_id(self.db_session, "2")
+        self.db_session.commit()
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/all/archived"))
+        assert response.status_code == 200
+        first_chat = response.json()[0]
+        assert first_chat["id"] is not None
+        assert first_chat["title"] == "New Chat"
+        assert first_chat["created_at"] is not None
+        assert first_chat["updated_at"] is not None
+
+    def test_get_all_user_chats_in_db(self):
+        with mock_webui_user(id="4"):
+            response = self.fast_api_client.get(self.create_url("/all/db"))
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+    def test_get_archived_session_user_chat_list(self):
+        self.test_get_user_archived_chats()
+
+    def test_archive_all_chats(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(self.create_url("/archive/all"))
+        assert response.status_code == 200
+        assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1
+
+    def test_get_shared_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id)
+        self.db_session.commit()
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == chat_id
+        assert data["chat"] == {
+            "name": "chat1",
+            "description": "chat1 description",
+            "tags": ["tag1", "tag2"],
+            "history": {"currentId": "1", "messages": []},
+        }
+        assert data["id"] == chat_id
+        assert data["share_id"] == chat_id
+        assert data["title"] == "New Chat"
+
+    def test_get_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == chat_id
+        assert data["chat"] == {
+            "name": "chat1",
+            "description": "chat1 description",
+            "tags": ["tag1", "tag2"],
+            "history": {"currentId": "1", "messages": []},
+        }
+        assert data["share_id"] is None
+        assert data["title"] == "New Chat"
+        assert data["user_id"] == "2"
+
+    def test_update_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url(f"/{chat_id}"),
+                json={
+                    "chat": {
+                        "name": "chat2",
+                        "description": "chat2 description",
+                        "tags": ["tag2", "tag4"],
+                        "title": "Just another title",
+                    }
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == chat_id
+        assert data["chat"] == {
+            "name": "chat2",
+            "title": "Just another title",
+            "description": "chat2 description",
+            "tags": ["tag2", "tag4"],
+            "history": {"currentId": "1", "messages": []},
+        }
+        assert data["share_id"] is None
+        assert data["title"] == "Just another title"
+        assert data["user_id"] == "2"
+
+    def test_delete_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
+        assert response.status_code == 200
+        assert response.json() is True
+
+    def test_clone_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] != chat_id
+        assert data["chat"] == {
+            "branchPointMessageId": "1",
+            "description": "chat1 description",
+            "history": {"currentId": "1", "messages": []},
+            "name": "chat1",
+            "originalChatId": chat_id,
+            "tags": ["tag1", "tag2"],
+            "title": "Clone of New Chat",
+        }
+        assert data["share_id"] is None
+        assert data["title"] == "Clone of New Chat"
+        assert data["user_id"] == "2"
+
+    def test_archive_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
+        assert response.status_code == 200
+
+        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        assert chat.archived is True
+
+    def test_share_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
+        assert response.status_code == 200
+
+        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        assert chat.share_id is not None
+
+    def test_delete_shared_chat_by_id(self):
+        chat_id = self.chats.get_chats(self.db_session)[0].id
+        share_id = str(uuid.uuid4())
+        self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id)
+        self.db_session.commit()
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
+        assert response.status_code
+
+        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        assert chat.share_id is None

+ 106 - 0
backend/test/apps/webui/routers/test_documents.py

@@ -0,0 +1,106 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestDocuments(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/documents"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.documents import Documents
+
+        cls.documents = Documents
+
+    def test_documents(self):
+        # Empty database
+        assert len(self.documents.get_docs(self.db_session)) == 0
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0
+
+        # Create a new document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "name": "doc_name",
+                    "title": "doc title",
+                    "collection_name": "custom collection",
+                    "filename": "doc_name.pdf",
+                    "content": "",
+                },
+            )
+        assert response.status_code == 200
+        assert response.json()["name"] == "doc_name"
+        assert len(self.documents.get_docs(self.db_session)) == 1
+
+        # Get the document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["collection_name"] == "custom collection"
+        assert data["name"] == "doc_name"
+        assert data["title"] == "doc title"
+        assert data["filename"] == "doc_name.pdf"
+        assert data["content"] == {}
+
+        # Create another document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "name": "doc_name 2",
+                    "title": "doc title 2",
+                    "collection_name": "custom collection 2",
+                    "filename": "doc_name2.pdf",
+                    "content": "",
+                },
+            )
+        assert response.status_code == 200
+        assert response.json()["name"] == "doc_name 2"
+        assert len(self.documents.get_docs(self.db_session)) == 2
+
+        # Get all documents
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+
+        # Update the first document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/doc/update?name=doc_name"),
+                json={"name": "doc_name rework", "title": "updated title"},
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["name"] == "doc_name rework"
+        assert data["title"] == "updated title"
+
+        # Tag the first document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/doc/tags"),
+                json={
+                    "name": "doc_name rework",
+                    "tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["name"] == "doc_name rework"
+        assert data["content"] == {
+            "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
+        }
+        assert len(self.documents.get_docs(self.db_session)) == 2
+
+        # Delete the first document
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(
+                self.create_url("/doc/delete?name=doc_name rework")
+            )
+        assert response.status_code == 200
+        assert len(self.documents.get_docs(self.db_session)) == 1

+ 60 - 0
backend/test/apps/webui/routers/test_models.py

@@ -0,0 +1,60 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestModels(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/models"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.models import Model
+
+        cls.models = Model
+
+    def test_models(self):
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/add"),
+                json={
+                    "id": "my-model",
+                    "base_model_id": "base-model-id",
+                    "name": "Hello World",
+                    "meta": {
+                        "profile_image_url": "/favicon.png",
+                        "description": "description",
+                        "capabilities": None,
+                        "model_config": {},
+                    },
+                    "params": {},
+                },
+            )
+        assert response.status_code == 200
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/my-model"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["id"] == "my-model"
+        assert data["name"] == "Hello World"
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(
+                self.create_url("/delete?id=my-model")
+            )
+        assert response.status_code == 200
+
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0

+ 82 - 0
backend/test/apps/webui/routers/test_prompts.py

@@ -0,0 +1,82 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+class TestPrompts(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/prompts"
+
+    def test_prompts(self):
+        # Get all prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 0
+
+        # Create a two new prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "command": "/my-command",
+                    "title": "Hello World",
+                    "content": "description",
+                },
+            )
+        assert response.status_code == 200
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.post(
+                self.create_url("/create"),
+                json={
+                    "command": "/my-command2",
+                    "title": "Hello World 2",
+                    "content": "description 2",
+                },
+            )
+        assert response.status_code == 200
+
+        # Get all prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+
+        # Get prompt by command
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/command/my-command"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["command"] == "/my-command"
+        assert data["title"] == "Hello World"
+        assert data["content"] == "description"
+        assert data["user_id"] == "2"
+
+        # Update prompt
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/command/my-command2/update"),
+                json={
+                    "command": "irrelevant for request",
+                    "title": "Hello World Updated",
+                    "content": "description Updated",
+                },
+            )
+        assert response.status_code == 200
+        data = response.json()
+        assert data["command"] == "/my-command2"
+        assert data["title"] == "Hello World Updated"
+        assert data["content"] == "description Updated"
+        assert data["user_id"] == "3"
+
+        # Delete prompt
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.delete(
+                self.create_url("/command/my-command/delete")
+            )
+        assert response.status_code == 200
+
+        # Get all prompts
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/"))
+        assert response.status_code == 200
+        assert len(response.json()) == 1

+ 170 - 0
backend/test/apps/webui/routers/test_users.py

@@ -0,0 +1,170 @@
+from test.util.abstract_integration_test import AbstractPostgresTest
+from test.util.mock_user import mock_webui_user
+
+
+def _get_user_by_id(data, param):
+    return next((item for item in data if item["id"] == param), None)
+
+
+def _assert_user(data, id, **kwargs):
+    user = _get_user_by_id(data, id)
+    assert user is not None
+    comparison_data = {
+        "name": f"user {id}",
+        "email": f"user{id}@openwebui.com",
+        "profile_image_url": f"/user{id}.png",
+        "role": "user",
+        **kwargs,
+    }
+    for key, value in comparison_data.items():
+        assert user[key] == value
+
+
+class TestUsers(AbstractPostgresTest):
+
+    BASE_PATH = "/api/v1/users"
+
+    def setup_class(cls):
+        super().setup_class()
+        from apps.webui.models.users import Users
+
+        cls.users = Users
+
+    def setup_method(self):
+        super().setup_method()
+        self.users.insert_new_user(
+            self.db_session,
+            id="1",
+            name="user 1",
+            email="user1@openwebui.com",
+            profile_image_url="/user1.png",
+            role="user",
+        )
+        self.users.insert_new_user(
+            self.db_session,
+            id="2",
+            name="user 2",
+            email="user2@openwebui.com",
+            profile_image_url="/user2.png",
+            role="user",
+        )
+
+    def test_users(self):
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+        data = response.json()
+        _assert_user(data, "1")
+        _assert_user(data, "2")
+
+        # update role
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.post(
+                self.create_url("/update/role"), json={"id": "2", "role": "admin"}
+            )
+        assert response.status_code == 200
+        _assert_user([response.json()], "2", role="admin")
+
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+        data = response.json()
+        _assert_user(data, "1")
+        _assert_user(data, "2", role="admin")
+
+        # Get (empty) user settings
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/user/settings"))
+        assert response.status_code == 200
+        assert response.json() is None
+
+        # Update user settings
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.post(
+                self.create_url("/user/settings/update"),
+                json={
+                    "ui": {"attr1": "value1", "attr2": "value2"},
+                    "model_config": {"attr3": "value3", "attr4": "value4"},
+                },
+            )
+        assert response.status_code == 200
+
+        # Get user settings
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/user/settings"))
+        assert response.status_code == 200
+        assert response.json() == {
+            "ui": {"attr1": "value1", "attr2": "value2"},
+            "model_config": {"attr3": "value3", "attr4": "value4"},
+        }
+
+        # Get (empty) user info
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.get(self.create_url("/user/info"))
+        assert response.status_code == 200
+        assert response.json() is None
+
+        # Update user info
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/user/info/update"),
+                json={"attr1": "value1", "attr2": "value2"},
+            )
+        assert response.status_code == 200
+
+        # Get user info
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.get(self.create_url("/user/info"))
+        assert response.status_code == 200
+        assert response.json() == {"attr1": "value1", "attr2": "value2"}
+
+        # Get user by id
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.get(self.create_url("/2"))
+        assert response.status_code == 200
+        assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"}
+
+        # Update user by id
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.post(
+                self.create_url("/2/update"),
+                json={
+                    "name": "user 2 updated",
+                    "email": "user2-updated@openwebui.com",
+                    "profile_image_url": "/user2-updated.png",
+                },
+            )
+        assert response.status_code == 200
+
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 2
+        data = response.json()
+        _assert_user(data, "1")
+        _assert_user(
+            data,
+            "2",
+            role="admin",
+            name="user 2 updated",
+            email="user2-updated@openwebui.com",
+            profile_image_url="/user2-updated.png",
+        )
+
+        # Delete user by id
+        with mock_webui_user(id="1"):
+            response = self.fast_api_client.delete(self.create_url("/2"))
+        assert response.status_code == 200
+
+        # Get all users
+        with mock_webui_user(id="3"):
+            response = self.fast_api_client.get(self.create_url(""))
+        assert response.status_code == 200
+        assert len(response.json()) == 1
+        data = response.json()
+        _assert_user(data, "1")

+ 155 - 0
backend/test/util/abstract_integration_test.py

@@ -0,0 +1,155 @@
+import logging
+import os
+import time
+
+import docker
+import pytest
+from docker import DockerClient
+from pytest_docker.plugin import get_docker_ip
+from fastapi.testclient import TestClient
+from sqlalchemy import text, create_engine
+
+log = logging.getLogger(__name__)
+
+
+def get_fast_api_client():
+    from main import app
+
+    with TestClient(app) as c:
+        return c
+
+
+class AbstractIntegrationTest:
+    BASE_PATH = None
+
+    def create_url(self, path):
+        if self.BASE_PATH is None:
+            raise Exception("BASE_PATH is not set")
+        parts = self.BASE_PATH.split("/")
+        parts = [part.strip() for part in parts if part.strip() != ""]
+        path_parts = path.split("/")
+        path_parts = [part.strip() for part in path_parts if part.strip() != ""]
+        return "/".join(parts + path_parts)
+
+    @classmethod
+    def setup_class(cls):
+        pass
+
+    def setup_method(self):
+        pass
+
+    @classmethod
+    def teardown_class(cls):
+        pass
+
+    def teardown_method(self):
+        pass
+
+
+class AbstractPostgresTest(AbstractIntegrationTest):
+    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
+    docker_client: DockerClient
+
+    def get_db(self):
+        from apps.webui.internal.db import SessionLocal
+
+        return SessionLocal()
+
+    @classmethod
+    def _create_db_url(cls, env_vars_postgres: dict) -> str:
+        host = get_docker_ip()
+        user = env_vars_postgres["POSTGRES_USER"]
+        pw = env_vars_postgres["POSTGRES_PASSWORD"]
+        port = 8081
+        db = env_vars_postgres["POSTGRES_DB"]
+        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"
+
+    @classmethod
+    def setup_class(cls):
+        super().setup_class()
+        try:
+            env_vars_postgres = {
+                "POSTGRES_USER": "user",
+                "POSTGRES_PASSWORD": "example",
+                "POSTGRES_DB": "openwebui",
+            }
+            cls.docker_client = docker.from_env()
+            cls.docker_client.containers.run(
+                "postgres:16.2",
+                detach=True,
+                environment=env_vars_postgres,
+                name=cls.DOCKER_CONTAINER_NAME,
+                ports={5432: ("0.0.0.0", 8081)},
+                command="postgres -c log_statement=all",
+            )
+            time.sleep(0.5)
+
+            database_url = cls._create_db_url(env_vars_postgres)
+            os.environ["DATABASE_URL"] = database_url
+            retries = 10
+            db = None
+            while retries > 0:
+                try:
+                    from config import BACKEND_DIR
+                    db = create_engine(database_url, pool_pre_ping=True)
+                    db = db.connect()
+                    log.info("postgres is ready!")
+                    break
+                except Exception as e:
+                    log.warning(e)
+                    time.sleep(3)
+                    retries -= 1
+
+            if db:
+                # import must be after setting env!
+                cls.fast_api_client = get_fast_api_client()
+                db.close()
+            else:
+                raise Exception("Could not connect to Postgres")
+        except Exception as ex:
+            log.error(ex)
+            cls.teardown_class()
+            pytest.fail(f"Could not setup test environment: {ex}")
+
+    def _check_db_connection(self):
+        retries = 10
+        while retries > 0:
+            try:
+                self.db_session.execute(text("SELECT 1"))
+                self.db_session.commit()
+                break
+            except Exception as e:
+                self.db_session.rollback()
+                log.warning(e)
+                time.sleep(3)
+                retries -= 1
+
+    def setup_method(self):
+        super().setup_method()
+        self.db_session = self.get_db()
+        self._check_db_connection()
+
+    @classmethod
+    def teardown_class(cls) -> None:
+        super().teardown_class()
+        cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
+
+    def teardown_method(self):
+        # rollback everything not yet committed
+        self.db_session.commit()
+
+        # truncate all tables
+        tables = [
+            "auth",
+            "chat",
+            "chatidtag",
+            "document",
+            "memory",
+            "model",
+            "prompt",
+            "tag",
+            '"user"',
+        ]
+        for table in tables:
+            self.db_session.execute(text(f"TRUNCATE TABLE {table}"))
+        self.db_session.commit()

+ 45 - 0
backend/test/util/mock_user.py

@@ -0,0 +1,45 @@
+from contextlib import contextmanager
+
+from fastapi import FastAPI
+
+
+@contextmanager
+def mock_webui_user(**kwargs):
+    from apps.webui.main import app
+
+    with mock_user(app, **kwargs):
+        yield
+
+
+@contextmanager
+def mock_user(app: FastAPI, **kwargs):
+    from utils.utils import (
+        get_current_user,
+        get_verified_user,
+        get_admin_user,
+        get_current_user_by_api_key,
+    )
+    from apps.webui.models.users import User
+
+    def create_user():
+        user_parameters = {
+            "id": "1",
+            "name": "John Doe",
+            "email": "john.doe@openwebui.com",
+            "role": "user",
+            "profile_image_url": "/user.png",
+            "last_active_at": 1627351200,
+            "updated_at": 1627351200,
+            "created_at": 162735120,
+            **kwargs,
+        }
+        return User(**user_parameters)
+
+    app.dependency_overrides = {
+        get_current_user: create_user,
+        get_verified_user: create_user,
+        get_admin_user: create_user,
+        get_current_user_by_api_key: create_user,
+    }
+    yield
+    app.dependency_overrides = {}

+ 9 - 6
backend/utils/utils.py

@@ -1,6 +1,8 @@
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi import HTTPException, status, Depends, Request
 from fastapi import HTTPException, status, Depends, Request
+from sqlalchemy.orm import Session
 
 
+from apps.webui.internal.db import get_db
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -77,6 +79,7 @@ def get_http_authorization_cred(auth_header: str):
 def get_current_user(
 def get_current_user(
     request: Request,
     request: Request,
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
+    db=Depends(get_db),
 ):
 ):
     token = None
     token = None
 
 
@@ -91,19 +94,19 @@ def get_current_user(
 
 
     # auth by api key
     # auth by api key
     if token.startswith("sk-"):
     if token.startswith("sk-"):
-        return get_current_user_by_api_key(token)
+        return get_current_user_by_api_key(db, token)
 
 
     # auth by jwt token
     # auth by jwt token
     data = decode_token(token)
     data = decode_token(token)
     if data != None and "id" in data:
     if data != None and "id" in data:
-        user = Users.get_user_by_id(data["id"])
+        user = Users.get_user_by_id(db, data["id"])
         if user is None:
         if user is None:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
             )
         else:
         else:
-            Users.update_user_last_active_by_id(user.id)
+            Users.update_user_last_active_by_id(db, user.id)
         return user
         return user
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -112,8 +115,8 @@ def get_current_user(
         )
         )
 
 
 
 
-def get_current_user_by_api_key(api_key: str):
-    user = Users.get_user_by_api_key(api_key)
+def get_current_user_by_api_key(db: Session, api_key: str):
+    user = Users.get_user_by_api_key(db, api_key)
 
 
     if user is None:
     if user is None:
         raise HTTPException(
         raise HTTPException(
@@ -121,7 +124,7 @@ def get_current_user_by_api_key(api_key: str):
             detail=ERROR_MESSAGES.INVALID_TOKEN,
             detail=ERROR_MESSAGES.INVALID_TOKEN,
         )
         )
     else:
     else:
-        Users.update_user_last_active_by_id(user.id)
+        Users.update_user_last_active_by_id(db, user.id)
 
 
     return user
     return user
 
 

+ 1 - 4
src/lib/apis/models/index.ts

@@ -63,10 +63,7 @@ export const getModelInfos = async (token: string = '') => {
 export const getModelById = async (token: string, id: string) => {
 export const getModelById = async (token: string, id: string) => {
 	let error = null;
 	let error = null;
 
 
-	const searchParams = new URLSearchParams();
-	searchParams.append('id', id);
-
-	const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, {
 		method: 'GET',
 		method: 'GET',
 		headers: {
 		headers: {
 			Accept: 'application/json',
 			Accept: 'application/json',