Browse Source

feat(sqlalchemy): use session factory instead of context manager

Jonathan Rohde 10 tháng trước cách đây
mục cha
commit
da403f3e3c

+ 1 - 11
backend/apps/webui/internal/db.py

@@ -57,14 +57,4 @@ SessionLocal = sessionmaker(
     autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
     autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
 )
 )
 Base = declarative_base()
 Base = declarative_base()
-
-
-@contextmanager
-def get_session():
-    session = scoped_session(SessionLocal)
-    try:
-        yield session
-        session.commit()
-    except Exception as e:
-        session.rollback()
-        raise e
+Session = scoped_session(SessionLocal)

+ 61 - 70
backend/apps/webui/models/auths.py

@@ -3,12 +3,11 @@ from typing import Optional
 import uuid
 import uuid
 import logging
 import logging
 from sqlalchemy import String, Column, Boolean
 from sqlalchemy import String, Column, Boolean
-from sqlalchemy.orm import Session
 
 
 from apps.webui.models.users import UserModel, Users
 from apps.webui.models.users import UserModel, Users
 from utils.utils import verify_password
 from utils.utils import verify_password
 
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -103,101 +102,93 @@ class AuthsTable:
         role: str = "pending",
         role: str = "pending",
         oauth_sub: Optional[str] = None,
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            log.info("insert_new_auth")
+        log.info("insert_new_auth")
 
 
-            id = str(uuid.uuid4())
+        id = str(uuid.uuid4())
 
 
-            auth = AuthModel(
-                **{"id": id, "email": email, "password": password, "active": True}
-            )
-            result = Auth(**auth.model_dump())
-            db.add(result)
+        auth = AuthModel(
+            **{"id": id, "email": email, "password": password, "active": True}
+        )
+        result = Auth(**auth.model_dump())
+        Session.add(result)
 
 
-            user = Users.insert_new_user(
-                id, name, email, profile_image_url, role, oauth_sub
-            )
+        user = Users.insert_new_user(
+            id, name, email, profile_image_url, role, oauth_sub)
 
 
-            db.commit()
-            db.refresh(result)
+        Session.commit()
+        Session.refresh(result)
 
 
-            if result and user:
-                return user
-            else:
-                return None
+        if result and user:
+            return user
+        else:
+            return None
 
 
     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
         log.info(f"authenticate_user: {email}")
-        with get_session() as db:
-            try:
-                auth = db.query(Auth).filter_by(email=email, active=True).first()
-                if auth:
-                    if verify_password(password, auth.password):
-                        user = Users.get_user_by_id(auth.id)
-                        return user
-                    else:
-                        return None
+        try:
+            auth = Session.query(Auth).filter_by(email=email, active=True).first()
+            if auth:
+                if verify_password(password, auth.password):
+                    user = Users.get_user_by_id(auth.id)
+                    return user
                 else:
                 else:
                     return None
                     return None
-            except:
+            else:
                 return None
                 return None
+        except:
+            return None
 
 
     def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
     def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_api_key: {api_key}")
         log.info(f"authenticate_user_by_api_key: {api_key}")
-        with get_session() as db:
-            # if no api_key, return None
-            if not api_key:
-                return None
+        # if no api_key, return None
+        if not api_key:
+            return None
 
 
-            try:
-                user = Users.get_user_by_api_key(api_key)
-                return user if user else None
-            except:
-                return False
+        try:
+            user = Users.get_user_by_api_key(api_key)
+            return user if user else None
+        except:
+            return False
 
 
     def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
     def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
         log.info(f"authenticate_user_by_trusted_header: {email}")
-        with get_session() as db:
-            try:
-                auth = db.query(Auth).filter(email=email, active=True).first()
-                if auth:
-                    user = Users.get_user_by_id(auth.id)
-                    return user
-            except:
-                return None
+        try:
+            auth = Session.query(Auth).filter(email=email, active=True).first()
+            if auth:
+                user = Users.get_user_by_id(auth.id)
+                return user
+        except:
+            return None
 
 
     def update_user_password_by_id(self, id: str, new_password: str) -> bool:
     def update_user_password_by_id(self, id: str, new_password: str) -> bool:
-        with get_session() as db:
-            try:
-                result = (
-                    db.query(Auth).filter_by(id=id).update({"password": new_password})
-                )
-                return True if result == 1 else False
-            except:
-                return False
+        try:
+            result = (
+                Session.query(Auth).filter_by(id=id).update({"password": new_password})
+            )
+            return True if result == 1 else False
+        except:
+            return False
 
 
     def update_email_by_id(self, id: str, email: str) -> bool:
     def update_email_by_id(self, id: str, email: str) -> bool:
-        with get_session() as db:
-            try:
-                result = db.query(Auth).filter_by(id=id).update({"email": email})
-                return True if result == 1 else False
-            except:
-                return False
+        try:
+            result = Session.query(Auth).filter_by(id=id).update({"email": email})
+            return True if result == 1 else False
+        except:
+            return False
 
 
     def delete_auth_by_id(self, id: str) -> bool:
     def delete_auth_by_id(self, id: str) -> bool:
-        with get_session() as db:
-            try:
-                # Delete User
-                result = Users.delete_user_by_id(id)
+        try:
+            # Delete User
+            result = Users.delete_user_by_id(id)
 
 
-                if result:
-                    db.query(Auth).filter_by(id=id).delete()
+            if result:
+                Session.query(Auth).filter_by(id=id).delete()
 
 
-                    return True
-                else:
-                    return False
-            except:
+                return True
+            else:
                 return False
                 return False
+        except:
+            return False
 
 
 
 
 Auths = AuthsTable()
 Auths = AuthsTable()

+ 139 - 161
backend/apps/webui/models/chats.py

@@ -6,9 +6,8 @@ import uuid
 import time
 import time
 
 
 from sqlalchemy import Column, String, BigInteger, Boolean
 from sqlalchemy import Column, String, BigInteger, Boolean
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 
 
 
 ####################
 ####################
@@ -80,93 +79,88 @@ class ChatTitleIdResponse(BaseModel):
 class ChatTable:
 class ChatTable:
 
 
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
-        with get_session() as db:
-            id = str(uuid.uuid4())
-            chat = ChatModel(
-                **{
-                    "id": id,
-                    "user_id": user_id,
-                    "title": (
-                        form_data.chat["title"]
-                        if "title" in form_data.chat
-                        else "New Chat"
-                    ),
-                    "chat": json.dumps(form_data.chat),
-                    "created_at": int(time.time()),
-                    "updated_at": int(time.time()),
-                }
-            )
-
-            result = Chat(**chat.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            return ChatModel.model_validate(result) if result else None
+        id = str(uuid.uuid4())
+        chat = ChatModel(
+            **{
+                "id": id,
+                "user_id": user_id,
+                "title": (
+                    form_data.chat["title"]
+                    if "title" in form_data.chat
+                    else "New Chat"
+                ),
+                "chat": json.dumps(form_data.chat),
+                "created_at": int(time.time()),
+                "updated_at": int(time.time()),
+            }
+        )
+
+        result = Chat(**chat.model_dump())
+        Session.add(result)
+        Session.commit()
+        Session.refresh(result)
+        return ChatModel.model_validate(result) if result else None
 
 
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
-        with get_session() as db:
-            try:
-                chat_obj = db.get(Chat, id)
-                chat_obj.chat = json.dumps(chat)
-                chat_obj.title = chat["title"] if "title" in chat else "New Chat"
-                chat_obj.updated_at = int(time.time())
-                db.commit()
-                db.refresh(chat_obj)
-
-                return ChatModel.model_validate(chat_obj)
-            except Exception as e:
-                return None
+        try:
+            chat_obj = Session.get(Chat, id)
+            chat_obj.chat = json.dumps(chat)
+            chat_obj.title = chat["title"] if "title" in chat else "New Chat"
+            chat_obj.updated_at = int(time.time())
+            Session.commit()
+            Session.refresh(chat_obj)
+
+            return ChatModel.model_validate(chat_obj)
+        except Exception as e:
+            return None
 
 
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
-        with get_session() as db:
-            # Get the existing chat to share
-            chat = db.get(Chat, chat_id)
-            # Check if the chat is already shared
-            if chat.share_id:
-                return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
-            # Create a new chat with the same data, but with a new ID
-            shared_chat = ChatModel(
-                **{
-                    "id": str(uuid.uuid4()),
-                    "user_id": f"shared-{chat_id}",
-                    "title": chat.title,
-                    "chat": chat.chat,
-                    "created_at": chat.created_at,
-                    "updated_at": int(time.time()),
-                }
-            )
-            shared_result = Chat(**shared_chat.model_dump())
-            db.add(shared_result)
-            db.commit()
-            db.refresh(shared_result)
-            # Update the original chat with the share_id
-            result = (
-                db.query(Chat)
-                .filter_by(id=chat_id)
-                .update({"share_id": shared_chat.id})
-            )
-
-            return shared_chat if (shared_result and result) else None
+        # Get the existing chat to share
+        chat = Session.get(Chat, chat_id)
+        # Check if the chat is already shared
+        if chat.share_id:
+            return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+        # Create a new chat with the same data, but with a new ID
+        shared_chat = ChatModel(
+            **{
+                "id": str(uuid.uuid4()),
+                "user_id": f"shared-{chat_id}",
+                "title": chat.title,
+                "chat": chat.chat,
+                "created_at": chat.created_at,
+                "updated_at": int(time.time()),
+            }
+        )
+        shared_result = Chat(**shared_chat.model_dump())
+        Session.add(shared_result)
+        Session.commit()
+        Session.refresh(shared_result)
+        # Update the original chat with the share_id
+        result = (
+            Session.query(Chat)
+            .filter_by(id=chat_id)
+            .update({"share_id": shared_chat.id})
+        )
+
+        return shared_chat if (shared_result and result) else None
 
 
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
-        with get_session() as db:
-            try:
-                print("update_shared_chat_by_id")
-                chat = db.get(Chat, chat_id)
-                print(chat)
-                chat.title = chat.title
-                chat.chat = chat.chat
-                db.commit()
-                db.refresh(chat)
-
-                return self.get_chat_by_id(chat.share_id)
-            except:
-                return None
+        try:
+            print("update_shared_chat_by_id")
+            chat = Session.get(Chat, chat_id)
+            print(chat)
+            chat.title = chat.title
+            chat.chat = chat.chat
+            Session.commit()
+            Session.refresh(chat)
+
+            return self.get_chat_by_id(chat.share_id)
+        except:
+            return None
 
 
     def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
     def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
+            Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
             return True
             return True
         except:
         except:
             return False
             return False
@@ -175,30 +169,27 @@ class ChatTable:
         self, id: str, share_id: Optional[str]
         self, id: str, share_id: Optional[str]
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
         try:
         try:
-            with get_session() as db:
-                chat = db.get(Chat, id)
-                chat.share_id = share_id
-                db.commit()
-                db.refresh(chat)
-                return chat
+            chat = Session.get(Chat, id)
+            chat.share_id = share_id
+            Session.commit()
+            Session.refresh(chat)
+            return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
         try:
-            with get_session() as db:
-                chat = self.get_chat_by_id(id)
-                db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
-
-                return self.get_chat_by_id(id)
+            chat = Session.get(Chat, id)
+            chat.archived = not chat.archived
+            Session.commit()
+            Session.refresh(chat)
+            return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
     def archive_all_chats_by_user_id(self, user_id: str) -> bool:
     def archive_all_chats_by_user_id(self, user_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
-
+            Session.query(Chat).filter_by(user_id=user_id).update({"archived": True})
             return True
             return True
         except:
         except:
             return False
             return False
@@ -206,9 +197,8 @@ class ChatTable:
     def get_archived_chat_list_by_user_id(
     def get_archived_chat_list_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        with get_session() as db:
             all_chats = (
             all_chats = (
-                db.query(Chat)
+                Session.query(Chat)
                 .filter_by(user_id=user_id, archived=True)
                 .filter_by(user_id=user_id, archived=True)
                 .order_by(Chat.updated_at.desc())
                 .order_by(Chat.updated_at.desc())
                 # .limit(limit).offset(skip)
                 # .limit(limit).offset(skip)
@@ -223,120 +213,108 @@ class ChatTable:
         skip: int = 0,
         skip: int = 0,
         limit: int = 50,
         limit: int = 50,
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        with get_session() as db:
-            query = db.query(Chat).filter_by(user_id=user_id)
-            if not include_archived:
-                query = query.filter_by(archived=False)
-            all_chats = (
-                query.order_by(Chat.updated_at.desc())
-                # .limit(limit).offset(skip)
-                .all()
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        query = Session.query(Chat).filter_by(user_id=user_id)
+        if not include_archived:
+            query = query.filter_by(archived=False)
+        all_chats = (
+            query.order_by(Chat.updated_at.desc())
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chat_list_by_chat_ids(
     def get_chat_list_by_chat_ids(
         self, chat_ids: List[str], skip: int = 0, limit: int = 50
         self, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                .filter(Chat.id.in_(chat_ids))
-                .filter_by(archived=False)
-                .order_by(Chat.updated_at.desc())
-                .all()
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter(Chat.id.in_(chat_ids))
+            .filter_by(archived=False)
+            .order_by(Chat.updated_at.desc())
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
         try:
-            with get_session() as db:
-                chat = db.get(Chat, id)
-                return ChatModel.model_validate(chat)
+            chat = Session.get(Chat, id)
+            return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
     def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
     def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
         try:
         try:
-            with get_session() as db:
-                chat = db.query(Chat).filter_by(share_id=id).first()
+            chat = Session.query(Chat).filter_by(share_id=id).first()
 
 
-                if chat:
-                    return self.get_chat_by_id(id)
-                else:
-                    return None
+            if chat:
+                return self.get_chat_by_id(id)
+            else:
+                return None
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
         try:
-            with get_session() as db:
-                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
-                return ChatModel.model_validate(chat)
+            chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first()
+            return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
     def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
     def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                # .limit(limit).offset(skip)
-                .order_by(Chat.updated_at.desc())
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            # .limit(limit).offset(skip)
+            .order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
     def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                .filter_by(user_id=user_id)
-                .order_by(Chat.updated_at.desc())
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter_by(user_id=user_id)
+            .order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
     def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                .filter_by(user_id=user_id, archived=True)
-                .order_by(Chat.updated_at.desc())
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
+            .order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def delete_chat_by_id(self, id: str) -> bool:
     def delete_chat_by_id(self, id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(id=id).delete()
+            Session.query(Chat).filter_by(id=id).delete()
 
 
-                return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(id)
         except:
         except:
             return False
             return False
 
 
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
+            Session.query(Chat).filter_by(id=id, user_id=user_id).delete()
 
 
-                return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(id)
         except:
         except:
             return False
             return False
 
 
     def delete_chats_by_user_id(self, user_id: str) -> bool:
     def delete_chats_by_user_id(self, user_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                self.delete_shared_chats_by_user_id(user_id)
+            self.delete_shared_chats_by_user_id(user_id)
 
 
-                db.query(Chat).filter_by(user_id=user_id).delete()
+            Session.query(Chat).filter_by(user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
-                shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
+            chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
+            shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 
 
-                db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
+            Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
 
 
             return True
             return True
         except:
         except:

+ 36 - 43
backend/apps/webui/models/documents.py

@@ -4,9 +4,8 @@ import time
 import logging
 import logging
 
 
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 
 import json
 import json
 
 
@@ -84,46 +83,42 @@ class DocumentsTable:
         )
         )
 
 
         try:
         try:
-            with get_session() as db:
-                result = Document(**document.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return DocumentModel.model_validate(result)
-                else:
-                    return None
+            result = Document(**document.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return DocumentModel.model_validate(result)
+            else:
+                return None
         except:
         except:
             return None
             return None
 
 
     def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
     def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
         try:
         try:
-            with get_session() as db:
-                document = db.query(Document).filter_by(name=name).first()
-                return DocumentModel.model_validate(document) if document else None
+            document = Session.query(Document).filter_by(name=name).first()
+            return DocumentModel.model_validate(document) if document else None
         except:
         except:
             return None
             return None
 
 
     def get_docs(self) -> List[DocumentModel]:
     def get_docs(self) -> List[DocumentModel]:
-        with get_session() as db:
-            return [
-                DocumentModel.model_validate(doc) for doc in db.query(Document).all()
-            ]
+        return [
+            DocumentModel.model_validate(doc) for doc in Session.query(Document).all()
+        ]
 
 
     def update_doc_by_name(
     def update_doc_by_name(
         self, name: str, form_data: DocumentUpdateForm
         self, name: str, form_data: DocumentUpdateForm
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         try:
         try:
-            with get_session() as db:
-                db.query(Document).filter_by(name=name).update(
-                    {
-                        "title": form_data.title,
-                        "name": form_data.name,
-                        "timestamp": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_doc_by_name(form_data.name)
+            Session.query(Document).filter_by(name=name).update(
+                {
+                    "title": form_data.title,
+                    "name": form_data.name,
+                    "timestamp": int(time.time()),
+                }
+            )
+            Session.commit()
+            return self.get_doc_by_name(form_data.name)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             return None
             return None
@@ -132,27 +127,25 @@ class DocumentsTable:
         self, name: str, updated: dict
         self, name: str, updated: dict
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         try:
         try:
-            with get_session() as db:
-                doc = self.get_doc_by_name(name)
-                doc_content = json.loads(doc.content if doc.content else "{}")
-                doc_content = {**doc_content, **updated}
-
-                db.query(Document).filter_by(name=name).update(
-                    {
-                        "content": json.dumps(doc_content),
-                        "timestamp": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_doc_by_name(name)
+            doc = self.get_doc_by_name(name)
+            doc_content = json.loads(doc.content if doc.content else "{}")
+            doc_content = {**doc_content, **updated}
+
+            Session.query(Document).filter_by(name=name).update(
+                {
+                    "content": json.dumps(doc_content),
+                    "timestamp": int(time.time()),
+                }
+            )
+            Session.commit()
+            return self.get_doc_by_name(name)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             return None
             return None
 
 
     def delete_doc_by_name(self, name: str) -> bool:
     def delete_doc_by_name(self, name: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Document).filter_by(name=name).delete()
+            Session.query(Document).filter_by(name=name).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 14 - 22
backend/apps/webui/models/files.py

@@ -4,9 +4,8 @@ import time
 import logging
 import logging
 
 
 from sqlalchemy import Column, String, BigInteger
 from sqlalchemy import Column, String, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import JSONField, Base, get_session
+from apps.webui.internal.db import JSONField, Base, Session
 
 
 import json
 import json
 
 
@@ -71,45 +70,38 @@ class FilesTable:
         )
         )
 
 
         try:
         try:
-            with get_session() as db:
-                result = File(**file.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return FileModel.model_validate(result)
-                else:
-                    return None
+            result = File(**file.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return FileModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
     def get_file_by_id(self, id: str) -> Optional[FileModel]:
     def get_file_by_id(self, id: str) -> Optional[FileModel]:
         try:
         try:
-            with get_session() as db:
-                file = db.get(File, id)
-                return FileModel.model_validate(file)
+            file = Session.get(File, id)
+            return FileModel.model_validate(file)
         except:
         except:
             return None
             return None
 
 
     def get_files(self) -> List[FileModel]:
     def get_files(self) -> List[FileModel]:
-        with get_session() as db:
-            return [FileModel.model_validate(file) for file in db.query(File).all()]
+        return [FileModel.model_validate(file) for file in Session.query(File).all()]
 
 
     def delete_file_by_id(self, id: str) -> bool:
     def delete_file_by_id(self, id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(File).filter_by(id=id).delete()
-                db.commit()
+            Session.query(File).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def delete_all_files(self) -> bool:
     def delete_all_files(self) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(File).delete()
-                db.commit()
+            Session.query(File).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 53 - 64
backend/apps/webui/models/functions.py

@@ -4,9 +4,8 @@ import time
 import logging
 import logging
 
 
 from sqlalchemy import Column, String, Text, BigInteger, Boolean
 from sqlalchemy import Column, String, Text, BigInteger, Boolean
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import JSONField, Base, get_session
+from apps.webui.internal.db import JSONField, Base, Session
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 import json
 import json
@@ -100,64 +99,57 @@ class FunctionsTable:
         )
         )
 
 
         try:
         try:
-            with get_session() as db:
-                result = Function(**function.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return FunctionModel.model_validate(result)
-                else:
-                    return None
+            result = Function(**function.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return FunctionModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
         try:
         try:
-            with get_session() as db:
-                function = db.get(Function, id)
-                return FunctionModel.model_validate(function)
+            function = Session.get(Function, id)
+            return FunctionModel.model_validate(function)
         except:
         except:
             return None
             return None
 
 
     def get_functions(self, active_only=False) -> List[FunctionModel]:
     def get_functions(self, active_only=False) -> List[FunctionModel]:
         if active_only:
         if active_only:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function).filter_by(is_active=True).all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function).filter_by(is_active=True).all()
+            ]
         else:
         else:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function).all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function).all()
+            ]
 
 
     def get_functions_by_type(
     def get_functions_by_type(
         self, type: str, active_only=False
         self, type: str, active_only=False
     ) -> List[FunctionModel]:
     ) -> List[FunctionModel]:
         if active_only:
         if active_only:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function)
-                    .filter_by(type=type, is_active=True)
-                    .all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function)
+                .filter_by(type=type, is_active=True)
+                .all()
+            ]
         else:
         else:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function).filter_by(type=type).all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function).filter_by(type=type).all()
+            ]
 
 
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         try:
         try:
-            with get_session() as db:
-                function = db.get(Function, id)
-                return function.valves if function.valves else {}
+            function = Session.get(Function, id)
+            return function.valves if function.valves else {}
         except Exception as e:
         except Exception as e:
             print(f"An error occurred: {e}")
             print(f"An error occurred: {e}")
             return None
             return None
@@ -166,12 +158,12 @@ class FunctionsTable:
         self, id: str, valves: dict
         self, id: str, valves: dict
     ) -> Optional[FunctionValves]:
     ) -> Optional[FunctionValves]:
         try:
         try:
-            with get_session() as db:
-                db.query(Function).filter_by(id=id).update(
-                    {"valves": valves, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_function_by_id(id)
+            function = Session.get(Function, id)
+            function.valves = valves
+            function.updated_at = int(time.time())
+            Session.commit()
+            Session.refresh(function)
+            return self.get_function_by_id(id)
         except:
         except:
             return None
             return None
 
 
@@ -219,36 +211,33 @@ class FunctionsTable:
 
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
         try:
-            with get_session() as db:
-                db.query(Function).filter_by(id=id).update(
-                    {
-                        **updated,
-                        "updated_at": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_function_by_id(id)
+            Session.query(Function).filter_by(id=id).update(
+                {
+                    **updated,
+                    "updated_at": int(time.time()),
+                }
+            )
+            Session.commit()
+            return self.get_function_by_id(id)
         except:
         except:
             return None
             return None
 
 
     def deactivate_all_functions(self) -> Optional[bool]:
     def deactivate_all_functions(self) -> Optional[bool]:
         try:
         try:
-            with get_session() as db:
-                db.query(Function).update(
-                    {
-                        "is_active": False,
-                        "updated_at": int(time.time()),
-                    }
-                )
-                db.commit()
+            Session.query(Function).update(
+                {
+                    "is_active": False,
+                    "updated_at": int(time.time()),
+                }
+            )
+            Session.commit()
             return True
             return True
         except:
         except:
             return None
             return None
 
 
     def delete_function_by_id(self, id: str) -> bool:
     def delete_function_by_id(self, id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Function).filter_by(id=id).delete()
+            Session.query(Function).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 25 - 35
backend/apps/webui/models/memories.py

@@ -2,10 +2,8 @@ from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 
 
 from sqlalchemy import Column, String, BigInteger
 from sqlalchemy import Column, String, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, get_session
-from apps.webui.models.chats import Chats
+from apps.webui.internal.db import Base, Session
 
 
 import time
 import time
 import uuid
 import uuid
@@ -58,15 +56,14 @@ class MemoriesTable:
                 "updated_at": int(time.time()),
                 "updated_at": int(time.time()),
             }
             }
         )
         )
-        with get_session() as db:
-            result = Memory(**memory.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return MemoryModel.model_validate(result)
-            else:
-                return None
+        result = Memory(**memory.model_dump())
+        Session.add(result)
+        Session.commit()
+        Session.refresh(result)
+        if result:
+            return MemoryModel.model_validate(result)
+        else:
+            return None
 
 
     def update_memory_by_id(
     def update_memory_by_id(
         self,
         self,
@@ -74,62 +71,55 @@ class MemoriesTable:
         content: str,
         content: str,
     ) -> Optional[MemoryModel]:
     ) -> Optional[MemoryModel]:
         try:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(id=id).update(
-                    {"content": content, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_memory_by_id(id)
+            Session.query(Memory).filter_by(id=id).update(
+                {"content": content, "updated_at": int(time.time())}
+            )
+            Session.commit()
+            return self.get_memory_by_id(id)
         except:
         except:
             return None
             return None
 
 
     def get_memories(self) -> List[MemoryModel]:
     def get_memories(self) -> List[MemoryModel]:
         try:
         try:
-            with get_session() as db:
-                memories = db.query(Memory).all()
-                return [MemoryModel.model_validate(memory) for memory in memories]
+            memories = Session.query(Memory).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
         except:
             return None
             return None
 
 
     def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
     def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
         try:
         try:
-            with get_session() as db:
-                memories = db.query(Memory).filter_by(user_id=user_id).all()
-                return [MemoryModel.model_validate(memory) for memory in memories]
+            memories = Session.query(Memory).filter_by(user_id=user_id).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
         except:
             return None
             return None
 
 
     def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
     def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
         try:
         try:
-            with get_session() as db:
-                memory = db.get(Memory, id)
-                return MemoryModel.model_validate(memory)
+            memory = Session.get(Memory, id)
+            return MemoryModel.model_validate(memory)
         except:
         except:
             return None
             return None
 
 
     def delete_memory_by_id(self, id: str) -> bool:
     def delete_memory_by_id(self, id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(id=id).delete()
+            Session.query(Memory).filter_by(id=id).delete()
             return True
             return True
 
 
         except:
         except:
             return False
             return False
 
 
-    def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
+    def delete_memories_by_user_id(self, user_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(user_id=user_id).delete()
+            Session.query(Memory).filter_by(user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def delete_memory_by_id_and_user_id(
     def delete_memory_by_id_and_user_id(
-        self, db: Session, id: str, user_id: str
+        self, id: str, user_id: str
     ) -> bool:
     ) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(id=id, user_id=user_id).delete()
+            Session.query(Memory).filter_by(id=id, user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 19 - 25
backend/apps/webui/models/models.py

@@ -4,9 +4,8 @@ from typing import Optional
 
 
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, JSONField, get_session
+from apps.webui.internal.db import Base, JSONField, Session
 
 
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
@@ -127,41 +126,37 @@ class ModelsTable:
             }
             }
         )
         )
         try:
         try:
-            with get_session() as db:
-                result = Model(**model.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-
-                if result:
-                    return ModelModel.model_validate(result)
-                else:
-                    return None
+            result = Model(**model.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+
+            if result:
+                return ModelModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
             return None
             return None
 
 
     def get_all_models(self) -> List[ModelModel]:
     def get_all_models(self) -> List[ModelModel]:
-        with get_session() as db:
-            return [ModelModel.model_validate(model) for model in db.query(Model).all()]
+        return [ModelModel.model_validate(model) for model in Session.query(Model).all()]
 
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
         try:
-            with get_session() as db:
-                model = db.get(Model, id)
-                return ModelModel.model_validate(model)
+            model = Session.get(Model, id)
+            return ModelModel.model_validate(model)
         except:
         except:
             return None
             return None
 
 
     def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
     def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
         try:
         try:
             # update only the fields that are present in the model
             # update only the fields that are present in the model
-            with get_session() as db:
-                model = db.query(Model).get(id)
-                model.update(**model.model_dump())
-                db.commit()
-                db.refresh(model)
-                return ModelModel.model_validate(model)
+            model = Session.query(Model).get(id)
+            model.update(**model.model_dump())
+            Session.commit()
+            Session.refresh(model)
+            return ModelModel.model_validate(model)
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
 
 
@@ -169,8 +164,7 @@ class ModelsTable:
 
 
     def delete_model_by_id(self, id: str) -> bool:
     def delete_model_by_id(self, id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Model).filter_by(id=id).delete()
+            Session.query(Model).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 43 - 50
backend/apps/webui/models/prompts.py

@@ -3,9 +3,8 @@ from typing import List, Optional
 import time
 import time
 
 
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 
 import json
 import json
 
 
@@ -50,65 +49,59 @@ class PromptsTable:
     def insert_new_prompt(
     def insert_new_prompt(
         self, user_id: str, form_data: PromptForm
         self, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:
-        with get_session() as db:
-            prompt = PromptModel(
-                **{
-                    "user_id": user_id,
-                    "command": form_data.command,
-                    "title": form_data.title,
-                    "content": form_data.content,
-                    "timestamp": int(time.time()),
-                }
-            )
-
-            try:
-                result = Prompt(**prompt.dict())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return PromptModel.model_validate(result)
-                else:
-                    return None
-            except Exception as e:
+        prompt = PromptModel(
+            **{
+                "user_id": user_id,
+                "command": form_data.command,
+                "title": form_data.title,
+                "content": form_data.content,
+                "timestamp": int(time.time()),
+            }
+        )
+
+        try:
+            result = Prompt(**prompt.dict())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return PromptModel.model_validate(result)
+            else:
                 return None
                 return None
+        except Exception as e:
+            return None
 
 
     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
-        with get_session() as db:
-            try:
-                prompt = db.query(Prompt).filter_by(command=command).first()
-                return PromptModel.model_validate(prompt)
-            except:
-                return None
+        try:
+            prompt = Session.query(Prompt).filter_by(command=command).first()
+            return PromptModel.model_validate(prompt)
+        except:
+            return None
 
 
     def get_prompts(self) -> List[PromptModel]:
     def get_prompts(self) -> List[PromptModel]:
-        with get_session() as db:
-            return [
-                PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
-            ]
+        return [
+            PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all()
+        ]
 
 
     def update_prompt_by_command(
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
         self, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:
-        with get_session() as db:
-            try:
-                prompt = db.query(Prompt).filter_by(command=command).first()
-                prompt.title = form_data.title
-                prompt.content = form_data.content
-                prompt.timestamp = int(time.time())
-                db.commit()
-                return prompt
-                # return self.get_prompt_by_command(command)
-            except:
-                return None
+        try:
+            prompt = Session.query(Prompt).filter_by(command=command).first()
+            prompt.title = form_data.title
+            prompt.content = form_data.content
+            prompt.timestamp = int(time.time())
+            Session.commit()
+            return PromptModel.model_validate(prompt)
+        except:
+            return None
 
 
     def delete_prompt_by_command(self, command: str) -> bool:
     def delete_prompt_by_command(self, command: str) -> bool:
-        with get_session() as db:
-            try:
-                db.query(Prompt).filter_by(command=command).delete()
-                return True
-            except:
-                return False
+        try:
+            Session.query(Prompt).filter_by(command=command).delete()
+            return True
+        except:
+            return False
 
 
 
 
 Prompts = PromptsTable()
 Prompts = PromptsTable()

+ 99 - 109
backend/apps/webui/models/tags.py

@@ -7,9 +7,8 @@ import time
 import logging
 import logging
 
 
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -83,15 +82,14 @@ class TagTable:
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
         try:
-            with get_session() as db:
-                result = Tag(**tag.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return TagModel.model_validate(result)
-                else:
-                    return None
+            result = Tag(**tag.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return TagModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
@@ -99,9 +97,8 @@ class TagTable:
         self, name: str, user_id: str
         self, name: str, user_id: str
     ) -> Optional[TagModel]:
     ) -> Optional[TagModel]:
         try:
         try:
-            with get_session() as db:
-                tag = db.query(Tag).filter(name=name, user_id=user_id).first()
-                return TagModel.model_validate(tag)
+            tag = Session.query(Tag).filter(name=name, user_id=user_id).first()
+            return TagModel.model_validate(tag)
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
@@ -123,105 +120,99 @@ class TagTable:
             }
             }
         )
         )
         try:
         try:
-            with get_session() as db:
-                result = ChatIdTag(**chatIdTag.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return ChatIdTagModel.model_validate(result)
-                else:
-                    return None
+            result = ChatIdTag(**chatIdTag.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return ChatIdTagModel.model_validate(result)
+            else:
+                return None
         except:
         except:
             return None
             return None
 
 
     def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
     def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
-        with get_session() as db:
-            tag_names = [
-                chat_id_tag.tag_name
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
-            return [
-                TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag)
-                    .filter_by(user_id=user_id)
-                    .filter(Tag.name.in_(tag_names))
-                    .all()
-                )
-            ]
+        tag_names = [
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                Session.query(ChatIdTag)
+                .filter_by(user_id=user_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
+        ]
+
+        return [
+            TagModel.model_validate(tag)
+            for tag in (
+                Session.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
+        ]
 
 
     def get_tags_by_chat_id_and_user_id(
     def get_tags_by_chat_id_and_user_id(
         self, chat_id: str, user_id: str
         self, chat_id: str, user_id: str
     ) -> List[TagModel]:
     ) -> List[TagModel]:
-        with get_session() as db:
-            tag_names = [
-                chat_id_tag.tag_name
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id, chat_id=chat_id)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
-            return [
-                TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag)
-                    .filter_by(user_id=user_id)
-                    .filter(Tag.name.in_(tag_names))
-                    .all()
-                )
-            ]
+        tag_names = [
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                Session.query(ChatIdTag)
+                .filter_by(user_id=user_id, chat_id=chat_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
+        ]
+
+        return [
+            TagModel.model_validate(tag)
+            for tag in (
+                Session.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
+        ]
 
 
     def get_chat_ids_by_tag_name_and_user_id(
     def get_chat_ids_by_tag_name_and_user_id(
         self, tag_name: str, user_id: str
         self, tag_name: str, user_id: str
     ) -> List[ChatIdTagModel]:
     ) -> List[ChatIdTagModel]:
-        with get_session() as db:
-            return [
-                ChatIdTagModel.model_validate(chat_id_tag)
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id, tag_name=tag_name)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
+        return [
+            ChatIdTagModel.model_validate(chat_id_tag)
+            for chat_id_tag in (
+                Session.query(ChatIdTag)
+                .filter_by(user_id=user_id, tag_name=tag_name)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
+        ]
 
 
     def count_chat_ids_by_tag_name_and_user_id(
     def count_chat_ids_by_tag_name_and_user_id(
         self, tag_name: str, user_id: str
         self, tag_name: str, user_id: str
     ) -> int:
     ) -> int:
-        with get_session() as db:
-            return (
-                db.query(ChatIdTag)
-                .filter_by(tag_name=tag_name, user_id=user_id)
-                .count()
-            )
+        return (
+            Session.query(ChatIdTag)
+            .filter_by(tag_name=tag_name, user_id=user_id)
+            .count()
+        )
 
 
     def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
     def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                res = (
-                    db.query(ChatIdTag)
-                    .filter_by(tag_name=tag_name, user_id=user_id)
-                    .delete()
-                )
-                log.debug(f"res: {res}")
-                db.commit()
-
-                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                    tag_name, user_id
-                )
-                if tag_count == 0:
-                    # Remove tag item from Tag col as well
-                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+            res = (
+                Session.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, user_id=user_id)
+                .delete()
+            )
+            log.debug(f"res: {res}")
+            Session.commit()
+
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                tag_name, user_id
+            )
+            if tag_count == 0:
+                # Remove tag item from Tag col as well
+                Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
             return True
             return True
         except Exception as e:
         except Exception as e:
             log.error(f"delete_tag: {e}")
             log.error(f"delete_tag: {e}")
@@ -231,21 +222,20 @@ class TagTable:
         self, tag_name: str, chat_id: str, user_id: str
         self, tag_name: str, chat_id: str, user_id: str
     ) -> bool:
     ) -> bool:
         try:
         try:
-            with get_session() as db:
-                res = (
-                    db.query(ChatIdTag)
-                    .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
-                    .delete()
-                )
-                log.debug(f"res: {res}")
-                db.commit()
-
-                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                    tag_name, user_id
-                )
-                if tag_count == 0:
-                    # Remove tag item from Tag col as well
-                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+            res = (
+                Session.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
+                .delete()
+            )
+            log.debug(f"res: {res}")
+            Session.commit()
+
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                tag_name, user_id
+            )
+            if tag_count == 0:
+                # Remove tag item from Tag col as well
+                Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
 
 
             return True
             return True
         except Exception as e:
         except Exception as e:

+ 26 - 33
backend/apps/webui/models/tools.py

@@ -3,9 +3,8 @@ from typing import List, Optional
 import time
 import time
 import logging
 import logging
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, JSONField, get_session
+from apps.webui.internal.db import Base, JSONField, Session
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 import json
 import json
@@ -95,48 +94,43 @@ class ToolsTable:
         )
         )
 
 
         try:
         try:
-            with get_session() as db:
-                result = Tool(**tool.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return ToolModel.model_validate(result)
-                else:
-                    return None
+            result = Tool(**tool.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return ToolModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
         try:
         try:
-            with get_session() as db:
-                tool = db.get(Tool, id)
-                return ToolModel.model_validate(tool)
+            tool = Session.get(Tool, id)
+            return ToolModel.model_validate(tool)
         except:
         except:
             return None
             return None
 
 
     def get_tools(self) -> List[ToolModel]:
     def get_tools(self) -> List[ToolModel]:
-        with get_session() as db:
-            return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
+        return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()]
 
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
         try:
-            with get_session() as db:
-                tool = db.get(Tool, id)
-                return tool.valves if tool.valves else {}
+            tool = Session.get(Tool, id)
+            return tool.valves if tool.valves else {}
         except Exception as e:
         except Exception as e:
             print(f"An error occurred: {e}")
             print(f"An error occurred: {e}")
             return None
             return None
 
 
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
         try:
         try:
-            with get_session() as db:
-                db.query(Tool).filter_by(id=id).update(
-                    {"valves": valves, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_tool_by_id(id)
+            Session.query(Tool).filter_by(id=id).update(
+                {"valves": valves, "updated_at": int(time.time())}
+            )
+            Session.commit()
+            return self.get_tool_by_id(id)
         except:
         except:
             return None
             return None
 
 
@@ -183,19 +177,18 @@ class ToolsTable:
 
 
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
         try:
-            with get_session() as db:
-                db.query(Tool).filter_by(id=id).update(
-                    {**updated, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_tool_by_id(id)
+            tool = Session.get(Tool, id)
+            tool.update(**updated)
+            tool.updated_at = int(time.time())
+            Session.commit()
+            Session.refresh(tool)
+            return ToolModel.model_validate(tool)
         except:
         except:
             return None
             return None
 
 
     def delete_tool_by_id(self, id: str) -> bool:
     def delete_tool_by_id(self, id: str) -> bool:
         try:
         try:
-            with get_session() as db:
-                db.query(Tool).filter_by(id=id).delete()
+            Session.query(Tool).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 117 - 134
backend/apps/webui/models/users.py

@@ -3,11 +3,10 @@ from typing import List, Union, Optional
 import time
 import time
 
 
 from sqlalchemy import String, Column, BigInteger, Text
 from sqlalchemy import String, Column, BigInteger, Text
-from sqlalchemy.orm import Session
 
 
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 
 
-from apps.webui.internal.db import Base, JSONField, get_session
+from apps.webui.internal.db import Base, JSONField, Session
 from apps.webui.models.chats import Chats
 from apps.webui.models.chats import Chats
 
 
 ####################
 ####################
@@ -89,177 +88,161 @@ class UsersTable:
         role: str = "pending",
         role: str = "pending",
         oauth_sub: Optional[str] = None,
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            user = UserModel(
-                **{
-                    "id": id,
-                    "name": name,
-                    "email": email,
-                    "role": role,
-                    "profile_image_url": profile_image_url,
-                    "last_active_at": int(time.time()),
-                    "created_at": int(time.time()),
-                    "updated_at": int(time.time()),
-                    "oauth_sub": oauth_sub,
-                }
-            )
-            result = User(**user.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return user
-            else:
-                return None
+        user = UserModel(
+            **{
+                "id": id,
+                "name": name,
+                "email": email,
+                "role": role,
+                "profile_image_url": profile_image_url,
+                "last_active_at": int(time.time()),
+                "created_at": int(time.time()),
+                "updated_at": int(time.time()),
+                "oauth_sub": oauth_sub,
+            }
+        )
+        result = User(**user.model_dump())
+        Session.add(result)
+        Session.commit()
+        Session.refresh(result)
+        if result:
+            return user
+        else:
+            return None
 
 
     def get_user_by_id(self, id: str) -> Optional[UserModel]:
     def get_user_by_id(self, id: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except Exception as e:
-                return None
+        try:
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except Exception as e:
+            return None
 
 
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(api_key=api_key).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).filter_by(api_key=api_key).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(email=email).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).filter_by(email=email).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(oauth_sub=sub).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).filter_by(oauth_sub=sub).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
-        with get_session() as db:
-            users = (
-                db.query(User)
-                # .offset(skip).limit(limit)
-                .all()
-            )
-            return [UserModel.model_validate(user) for user in users]
+        users = (
+            Session.query(User)
+            # .offset(skip).limit(limit)
+            .all()
+        )
+        return [UserModel.model_validate(user) for user in users]
 
 
     def get_num_users(self) -> Optional[int]:
     def get_num_users(self) -> Optional[int]:
-        with get_session() as db:
-            return db.query(User).count()
+        return Session.query(User).count()
 
 
     def get_first_user(self) -> UserModel:
     def get_first_user(self) -> UserModel:
-        with get_session() as db:
-            try:
-                user = db.query(User).order_by(User.created_at).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).order_by(User.created_at).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update({"role": role})
-                db.commit()
+        try:
+            Session.query(User).filter_by(id=id).update({"role": role})
+            Session.commit()
 
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def update_user_profile_image_url_by_id(
     def update_user_profile_image_url_by_id(
         self, id: str, profile_image_url: str
         self, id: str, profile_image_url: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update(
-                    {"profile_image_url": profile_image_url}
-                )
-                db.commit()
-
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            Session.query(User).filter_by(id=id).update(
+                {"profile_image_url": profile_image_url}
+            )
+            Session.commit()
+
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update(
-                    {"last_active_at": int(time.time())}
-                )
+        try:
+            Session.query(User).filter_by(id=id).update(
+                {"last_active_at": int(time.time())}
+            )
 
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def update_user_oauth_sub_by_id(
     def update_user_oauth_sub_by_id(
         self, id: str, oauth_sub: str
         self, id: str, oauth_sub: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
+        try:
+            Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
 
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
 
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update(updated)
-                db.commit()
+        try:
+            Session.query(User).filter_by(id=id).update(updated)
+            Session.commit()
 
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-                # return UserModel(**user.dict())
-            except Exception as e:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+            # return UserModel(**user.dict())
+        except Exception as e:
+            return None
 
 
     def delete_user_by_id(self, id: str) -> bool:
     def delete_user_by_id(self, id: str) -> bool:
-        with get_session() as db:
-            try:
-                # Delete User Chats
-                result = Chats.delete_chats_by_user_id(id)
-
-                if result:
-                    # Delete User
-                    db.query(User).filter_by(id=id).delete()
-                    db.commit()
-
-                    return True
-                else:
-                    return False
-            except:
+        try:
+            # Delete User Chats
+            result = Chats.delete_chats_by_user_id(id)
+
+            if result:
+                # Delete User
+                Session.query(User).filter_by(id=id).delete()
+                Session.commit()
+
+                return True
+            else:
                 return False
                 return False
+        except:
+            return False
 
 
     def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
     def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
-        with get_session() as db:
-            try:
-                result = db.query(User).filter_by(id=id).update({"api_key": api_key})
-                db.commit()
-                return True if result == 1 else False
-            except:
-                return False
+        try:
+            result = Session.query(User).filter_by(id=id).update({"api_key": api_key})
+            Session.commit()
+            return True if result == 1 else False
+        except:
+            return False
 
 
     def get_user_api_key_by_id(self, id: str) -> Optional[str]:
     def get_user_api_key_by_id(self, id: str) -> Optional[str]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(id=id).first()
-                return user.api_key
-            except Exception as e:
-                return None
+        try:
+            user = Session.query(User).filter_by(id=id).first()
+            return user.api_key
+        except Exception as e:
+            return None
 
 
 
 
 Users = UsersTable()
 Users = UsersTable()

+ 10 - 4
backend/main.py

@@ -29,7 +29,6 @@ from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from sqlalchemy import text
 from sqlalchemy import text
-from sqlalchemy.orm import Session
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.middleware.sessions import SessionMiddleware
@@ -57,7 +56,7 @@ from apps.webui.main import (
     get_pipe_models,
     get_pipe_models,
     generate_function_chat_completion,
     generate_function_chat_completion,
 )
 )
-from apps.webui.internal.db import get_session, SessionLocal
+from apps.webui.internal.db import Session, SessionLocal
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -794,6 +793,14 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
+@app.middleware("http")
+async def remove_session_after_request(request: Request, call_next):
+    response = await call_next(request)
+    log.debug("Removing session after request")
+    Session.commit()
+    Session.remove()
+    return response
+
 
 
 @app.middleware("http")
 @app.middleware("http")
 async def check_url(request: Request, call_next):
 async def check_url(request: Request, call_next):
@@ -2034,8 +2041,7 @@ async def healthcheck():
 
 
 @app.get("/health/db")
 @app.get("/health/db")
 async def healthcheck_with_db():
 async def healthcheck_with_db():
-    with get_session() as db:
-        result = db.execute(text("SELECT 1;")).all()
+    Session.execute(text("SELECT 1;")).all()
     return {"status": True}
     return {"status": True}
 
 
 
 

+ 2 - 0
backend/test/apps/webui/routers/test_chats.py

@@ -90,6 +90,8 @@ class TestChats(AbstractPostgresTest):
 
 
     def test_get_user_archived_chats(self):
     def test_get_user_archived_chats(self):
         self.chats.archive_all_chats_by_user_id("2")
         self.chats.archive_all_chats_by_user_id("2")
+        from apps.webui.internal.db import Session
+        Session.commit()
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url("/all/archived"))
             response = self.fast_api_client.get(self.create_url("/all/archived"))
         assert response.status_code == 200
         assert response.status_code == 200

+ 9 - 12
backend/test/util/abstract_integration_test.py

@@ -9,6 +9,7 @@ from pytest_docker.plugin import get_docker_ip
 from fastapi.testclient import TestClient
 from fastapi.testclient import TestClient
 from sqlalchemy import text, create_engine
 from sqlalchemy import text, create_engine
 
 
+
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 
 
 
 
@@ -50,11 +51,6 @@ class AbstractPostgresTest(AbstractIntegrationTest):
     DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
     DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
     docker_client: DockerClient
     docker_client: DockerClient
 
 
-    def get_db(self):
-        from apps.webui.internal.db import SessionLocal
-
-        return SessionLocal()
-
     @classmethod
     @classmethod
     def _create_db_url(cls, env_vars_postgres: dict) -> str:
     def _create_db_url(cls, env_vars_postgres: dict) -> str:
         host = get_docker_ip()
         host = get_docker_ip()
@@ -113,21 +109,21 @@ class AbstractPostgresTest(AbstractIntegrationTest):
             pytest.fail(f"Could not setup test environment: {ex}")
             pytest.fail(f"Could not setup test environment: {ex}")
 
 
     def _check_db_connection(self):
     def _check_db_connection(self):
+        from apps.webui.internal.db import Session
         retries = 10
         retries = 10
         while retries > 0:
         while retries > 0:
             try:
             try:
-                self.db_session.execute(text("SELECT 1"))
-                self.db_session.commit()
+                Session.execute(text("SELECT 1"))
+                Session.commit()
                 break
                 break
             except Exception as e:
             except Exception as e:
-                self.db_session.rollback()
+                Session.rollback()
                 log.warning(e)
                 log.warning(e)
                 time.sleep(3)
                 time.sleep(3)
                 retries -= 1
                 retries -= 1
 
 
     def setup_method(self):
     def setup_method(self):
         super().setup_method()
         super().setup_method()
-        self.db_session = self.get_db()
         self._check_db_connection()
         self._check_db_connection()
 
 
     @classmethod
     @classmethod
@@ -136,8 +132,9 @@ class AbstractPostgresTest(AbstractIntegrationTest):
         cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
         cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
 
 
     def teardown_method(self):
     def teardown_method(self):
+        from apps.webui.internal.db import Session
         # rollback everything not yet committed
         # rollback everything not yet committed
-        self.db_session.commit()
+        Session.commit()
 
 
         # truncate all tables
         # truncate all tables
         tables = [
         tables = [
@@ -152,5 +149,5 @@ class AbstractPostgresTest(AbstractIntegrationTest):
             '"user"',
             '"user"',
         ]
         ]
         for table in tables:
         for table in tables:
-            self.db_session.execute(text(f"TRUNCATE TABLE {table}"))
-        self.db_session.commit()
+            Session.execute(text(f"TRUNCATE TABLE {table}"))
+        Session.commit()