浏览代码

feat(sqlalchemy): use session factory instead of context manager

Jonathan Rohde 10 月之前
父节点
当前提交
da403f3e3c

+ 1 - 11
backend/apps/webui/internal/db.py

@@ -57,14 +57,4 @@ SessionLocal = sessionmaker(
     autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
 )
 Base = declarative_base()
-
-
-@contextmanager
-def get_session():
-    session = scoped_session(SessionLocal)
-    try:
-        yield session
-        session.commit()
-    except Exception as e:
-        session.rollback()
-        raise e
+Session = scoped_session(SessionLocal)

+ 61 - 70
backend/apps/webui/models/auths.py

@@ -3,12 +3,11 @@ from typing import Optional
 import uuid
 import logging
 from sqlalchemy import String, Column, Boolean
-from sqlalchemy.orm import Session
 
 from apps.webui.models.users import UserModel, Users
 from utils.utils import verify_password
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 from config import SRC_LOG_LEVELS
 
@@ -103,101 +102,93 @@ class AuthsTable:
         role: str = "pending",
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            log.info("insert_new_auth")
+        log.info("insert_new_auth")
 
-            id = str(uuid.uuid4())
+        id = str(uuid.uuid4())
 
-            auth = AuthModel(
-                **{"id": id, "email": email, "password": password, "active": True}
-            )
-            result = Auth(**auth.model_dump())
-            db.add(result)
+        auth = AuthModel(
+            **{"id": id, "email": email, "password": password, "active": True}
+        )
+        result = Auth(**auth.model_dump())
+        Session.add(result)
 
-            user = Users.insert_new_user(
-                id, name, email, profile_image_url, role, oauth_sub
-            )
+        user = Users.insert_new_user(
+            id, name, email, profile_image_url, role, oauth_sub)
 
-            db.commit()
-            db.refresh(result)
+        Session.commit()
+        Session.refresh(result)
 
-            if result and user:
-                return user
-            else:
-                return None
+        if result and user:
+            return user
+        else:
+            return None
 
     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
-        with get_session() as db:
-            try:
-                auth = db.query(Auth).filter_by(email=email, active=True).first()
-                if auth:
-                    if verify_password(password, auth.password):
-                        user = Users.get_user_by_id(auth.id)
-                        return user
-                    else:
-                        return None
+        try:
+            auth = Session.query(Auth).filter_by(email=email, active=True).first()
+            if auth:
+                if verify_password(password, auth.password):
+                    user = Users.get_user_by_id(auth.id)
+                    return user
                 else:
                     return None
-            except:
+            else:
                 return None
+        except:
+            return None
 
     def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_api_key: {api_key}")
-        with get_session() as db:
-            # if no api_key, return None
-            if not api_key:
-                return None
+        # if no api_key, return None
+        if not api_key:
+            return None
 
-            try:
-                user = Users.get_user_by_api_key(api_key)
-                return user if user else None
-            except:
-                return False
+        try:
+            user = Users.get_user_by_api_key(api_key)
+            return user if user else None
+        except:
+            return False
 
     def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
-        with get_session() as db:
-            try:
-                auth = db.query(Auth).filter(email=email, active=True).first()
-                if auth:
-                    user = Users.get_user_by_id(auth.id)
-                    return user
-            except:
-                return None
+        try:
+            auth = Session.query(Auth).filter(email=email, active=True).first()
+            if auth:
+                user = Users.get_user_by_id(auth.id)
+                return user
+        except:
+            return None
 
     def update_user_password_by_id(self, id: str, new_password: str) -> bool:
-        with get_session() as db:
-            try:
-                result = (
-                    db.query(Auth).filter_by(id=id).update({"password": new_password})
-                )
-                return True if result == 1 else False
-            except:
-                return False
+        try:
+            result = (
+                Session.query(Auth).filter_by(id=id).update({"password": new_password})
+            )
+            return True if result == 1 else False
+        except:
+            return False
 
     def update_email_by_id(self, id: str, email: str) -> bool:
-        with get_session() as db:
-            try:
-                result = db.query(Auth).filter_by(id=id).update({"email": email})
-                return True if result == 1 else False
-            except:
-                return False
+        try:
+            result = Session.query(Auth).filter_by(id=id).update({"email": email})
+            return True if result == 1 else False
+        except:
+            return False
 
     def delete_auth_by_id(self, id: str) -> bool:
-        with get_session() as db:
-            try:
-                # Delete User
-                result = Users.delete_user_by_id(id)
+        try:
+            # Delete User
+            result = Users.delete_user_by_id(id)
 
-                if result:
-                    db.query(Auth).filter_by(id=id).delete()
+            if result:
+                Session.query(Auth).filter_by(id=id).delete()
 
-                    return True
-                else:
-                    return False
-            except:
+                return True
+            else:
                 return False
+        except:
+            return False
 
 
 Auths = AuthsTable()

+ 139 - 161
backend/apps/webui/models/chats.py

@@ -6,9 +6,8 @@ import uuid
 import time
 
 from sqlalchemy import Column, String, BigInteger, Boolean
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 
 ####################
@@ -80,93 +79,88 @@ class ChatTitleIdResponse(BaseModel):
 class ChatTable:
 
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
-        with get_session() as db:
-            id = str(uuid.uuid4())
-            chat = ChatModel(
-                **{
-                    "id": id,
-                    "user_id": user_id,
-                    "title": (
-                        form_data.chat["title"]
-                        if "title" in form_data.chat
-                        else "New Chat"
-                    ),
-                    "chat": json.dumps(form_data.chat),
-                    "created_at": int(time.time()),
-                    "updated_at": int(time.time()),
-                }
-            )
-
-            result = Chat(**chat.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            return ChatModel.model_validate(result) if result else None
+        id = str(uuid.uuid4())
+        chat = ChatModel(
+            **{
+                "id": id,
+                "user_id": user_id,
+                "title": (
+                    form_data.chat["title"]
+                    if "title" in form_data.chat
+                    else "New Chat"
+                ),
+                "chat": json.dumps(form_data.chat),
+                "created_at": int(time.time()),
+                "updated_at": int(time.time()),
+            }
+        )
+
+        result = Chat(**chat.model_dump())
+        Session.add(result)
+        Session.commit()
+        Session.refresh(result)
+        return ChatModel.model_validate(result) if result else None
 
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
-        with get_session() as db:
-            try:
-                chat_obj = db.get(Chat, id)
-                chat_obj.chat = json.dumps(chat)
-                chat_obj.title = chat["title"] if "title" in chat else "New Chat"
-                chat_obj.updated_at = int(time.time())
-                db.commit()
-                db.refresh(chat_obj)
-
-                return ChatModel.model_validate(chat_obj)
-            except Exception as e:
-                return None
+        try:
+            chat_obj = Session.get(Chat, id)
+            chat_obj.chat = json.dumps(chat)
+            chat_obj.title = chat["title"] if "title" in chat else "New Chat"
+            chat_obj.updated_at = int(time.time())
+            Session.commit()
+            Session.refresh(chat_obj)
+
+            return ChatModel.model_validate(chat_obj)
+        except Exception as e:
+            return None
 
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
-        with get_session() as db:
-            # Get the existing chat to share
-            chat = db.get(Chat, chat_id)
-            # Check if the chat is already shared
-            if chat.share_id:
-                return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
-            # Create a new chat with the same data, but with a new ID
-            shared_chat = ChatModel(
-                **{
-                    "id": str(uuid.uuid4()),
-                    "user_id": f"shared-{chat_id}",
-                    "title": chat.title,
-                    "chat": chat.chat,
-                    "created_at": chat.created_at,
-                    "updated_at": int(time.time()),
-                }
-            )
-            shared_result = Chat(**shared_chat.model_dump())
-            db.add(shared_result)
-            db.commit()
-            db.refresh(shared_result)
-            # Update the original chat with the share_id
-            result = (
-                db.query(Chat)
-                .filter_by(id=chat_id)
-                .update({"share_id": shared_chat.id})
-            )
-
-            return shared_chat if (shared_result and result) else None
+        # Get the existing chat to share
+        chat = Session.get(Chat, chat_id)
+        # Check if the chat is already shared
+        if chat.share_id:
+            return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+        # Create a new chat with the same data, but with a new ID
+        shared_chat = ChatModel(
+            **{
+                "id": str(uuid.uuid4()),
+                "user_id": f"shared-{chat_id}",
+                "title": chat.title,
+                "chat": chat.chat,
+                "created_at": chat.created_at,
+                "updated_at": int(time.time()),
+            }
+        )
+        shared_result = Chat(**shared_chat.model_dump())
+        Session.add(shared_result)
+        Session.commit()
+        Session.refresh(shared_result)
+        # Update the original chat with the share_id
+        result = (
+            Session.query(Chat)
+            .filter_by(id=chat_id)
+            .update({"share_id": shared_chat.id})
+        )
+
+        return shared_chat if (shared_result and result) else None
 
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
-        with get_session() as db:
-            try:
-                print("update_shared_chat_by_id")
-                chat = db.get(Chat, chat_id)
-                print(chat)
-                chat.title = chat.title
-                chat.chat = chat.chat
-                db.commit()
-                db.refresh(chat)
-
-                return self.get_chat_by_id(chat.share_id)
-            except:
-                return None
+        try:
+            print("update_shared_chat_by_id")
+            chat = Session.get(Chat, chat_id)
+            print(chat)
+            chat.title = chat.title
+            chat.chat = chat.chat
+            Session.commit()
+            Session.refresh(chat)
+
+            return self.get_chat_by_id(chat.share_id)
+        except:
+            return None
 
     def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
+            Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
             return True
         except:
             return False
@@ -175,30 +169,27 @@ class ChatTable:
         self, id: str, share_id: Optional[str]
     ) -> Optional[ChatModel]:
         try:
-            with get_session() as db:
-                chat = db.get(Chat, id)
-                chat.share_id = share_id
-                db.commit()
-                db.refresh(chat)
-                return chat
+            chat = Session.get(Chat, id)
+            chat.share_id = share_id
+            Session.commit()
+            Session.refresh(chat)
+            return ChatModel.model_validate(chat)
         except:
             return None
 
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
-            with get_session() as db:
-                chat = self.get_chat_by_id(id)
-                db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
-
-                return self.get_chat_by_id(id)
+            chat = Session.get(Chat, id)
+            chat.archived = not chat.archived
+            Session.commit()
+            Session.refresh(chat)
+            return ChatModel.model_validate(chat)
         except:
             return None
 
     def archive_all_chats_by_user_id(self, user_id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
-
+            Session.query(Chat).filter_by(user_id=user_id).update({"archived": True})
             return True
         except:
             return False
@@ -206,9 +197,8 @@ class ChatTable:
     def get_archived_chat_list_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-        with get_session() as db:
             all_chats = (
-                db.query(Chat)
+                Session.query(Chat)
                 .filter_by(user_id=user_id, archived=True)
                 .order_by(Chat.updated_at.desc())
                 # .limit(limit).offset(skip)
@@ -223,120 +213,108 @@ class ChatTable:
         skip: int = 0,
         limit: int = 50,
     ) -> List[ChatModel]:
-        with get_session() as db:
-            query = db.query(Chat).filter_by(user_id=user_id)
-            if not include_archived:
-                query = query.filter_by(archived=False)
-            all_chats = (
-                query.order_by(Chat.updated_at.desc())
-                # .limit(limit).offset(skip)
-                .all()
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        query = Session.query(Chat).filter_by(user_id=user_id)
+        if not include_archived:
+            query = query.filter_by(archived=False)
+        all_chats = (
+            query.order_by(Chat.updated_at.desc())
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_list_by_chat_ids(
         self, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                .filter(Chat.id.in_(chat_ids))
-                .filter_by(archived=False)
-                .order_by(Chat.updated_at.desc())
-                .all()
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter(Chat.id.in_(chat_ids))
+            .filter_by(archived=False)
+            .order_by(Chat.updated_at.desc())
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
-            with get_session() as db:
-                chat = db.get(Chat, id)
-                return ChatModel.model_validate(chat)
+            chat = Session.get(Chat, id)
+            return ChatModel.model_validate(chat)
         except:
             return None
 
     def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
         try:
-            with get_session() as db:
-                chat = db.query(Chat).filter_by(share_id=id).first()
+            chat = Session.query(Chat).filter_by(share_id=id).first()
 
-                if chat:
-                    return self.get_chat_by_id(id)
-                else:
-                    return None
+            if chat:
+                return self.get_chat_by_id(id)
+            else:
+                return None
         except Exception as e:
             return None
 
     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
-            with get_session() as db:
-                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
-                return ChatModel.model_validate(chat)
+            chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first()
+            return ChatModel.model_validate(chat)
         except:
             return None
 
     def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                # .limit(limit).offset(skip)
-                .order_by(Chat.updated_at.desc())
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            # .limit(limit).offset(skip)
+            .order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                .filter_by(user_id=user_id)
-                .order_by(Chat.updated_at.desc())
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter_by(user_id=user_id)
+            .order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        with get_session() as db:
-            all_chats = (
-                db.query(Chat)
-                .filter_by(user_id=user_id, archived=True)
-                .order_by(Chat.updated_at.desc())
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
+            .order_by(Chat.updated_at.desc())
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def delete_chat_by_id(self, id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(id=id).delete()
+            Session.query(Chat).filter_by(id=id).delete()
 
-                return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(id)
         except:
             return False
 
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
+            Session.query(Chat).filter_by(id=id, user_id=user_id).delete()
 
-                return True and self.delete_shared_chat_by_chat_id(id)
+            return True and self.delete_shared_chat_by_chat_id(id)
         except:
             return False
 
     def delete_chats_by_user_id(self, user_id: str) -> bool:
         try:
-            with get_session() as db:
-                self.delete_shared_chats_by_user_id(user_id)
+            self.delete_shared_chats_by_user_id(user_id)
 
-                db.query(Chat).filter_by(user_id=user_id).delete()
+            Session.query(Chat).filter_by(user_id=user_id).delete()
             return True
         except:
             return False
 
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
-            with get_session() as db:
-                chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
-                shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
+            chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
+            shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 
-                db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
+            Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
 
             return True
         except:

+ 36 - 43
backend/apps/webui/models/documents.py

@@ -4,9 +4,8 @@ import time
 import logging
 
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 import json
 
@@ -84,46 +83,42 @@ class DocumentsTable:
         )
 
         try:
-            with get_session() as db:
-                result = Document(**document.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return DocumentModel.model_validate(result)
-                else:
-                    return None
+            result = Document(**document.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return DocumentModel.model_validate(result)
+            else:
+                return None
         except:
             return None
 
     def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
         try:
-            with get_session() as db:
-                document = db.query(Document).filter_by(name=name).first()
-                return DocumentModel.model_validate(document) if document else None
+            document = Session.query(Document).filter_by(name=name).first()
+            return DocumentModel.model_validate(document) if document else None
         except:
             return None
 
     def get_docs(self) -> List[DocumentModel]:
-        with get_session() as db:
-            return [
-                DocumentModel.model_validate(doc) for doc in db.query(Document).all()
-            ]
+        return [
+            DocumentModel.model_validate(doc) for doc in Session.query(Document).all()
+        ]
 
     def update_doc_by_name(
         self, name: str, form_data: DocumentUpdateForm
     ) -> Optional[DocumentModel]:
         try:
-            with get_session() as db:
-                db.query(Document).filter_by(name=name).update(
-                    {
-                        "title": form_data.title,
-                        "name": form_data.name,
-                        "timestamp": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_doc_by_name(form_data.name)
+            Session.query(Document).filter_by(name=name).update(
+                {
+                    "title": form_data.title,
+                    "name": form_data.name,
+                    "timestamp": int(time.time()),
+                }
+            )
+            Session.commit()
+            return self.get_doc_by_name(form_data.name)
         except Exception as e:
             log.exception(e)
             return None
@@ -132,27 +127,25 @@ class DocumentsTable:
         self, name: str, updated: dict
     ) -> Optional[DocumentModel]:
         try:
-            with get_session() as db:
-                doc = self.get_doc_by_name(name)
-                doc_content = json.loads(doc.content if doc.content else "{}")
-                doc_content = {**doc_content, **updated}
-
-                db.query(Document).filter_by(name=name).update(
-                    {
-                        "content": json.dumps(doc_content),
-                        "timestamp": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_doc_by_name(name)
+            doc = self.get_doc_by_name(name)
+            doc_content = json.loads(doc.content if doc.content else "{}")
+            doc_content = {**doc_content, **updated}
+
+            Session.query(Document).filter_by(name=name).update(
+                {
+                    "content": json.dumps(doc_content),
+                    "timestamp": int(time.time()),
+                }
+            )
+            Session.commit()
+            return self.get_doc_by_name(name)
         except Exception as e:
             log.exception(e)
             return None
 
     def delete_doc_by_name(self, name: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Document).filter_by(name=name).delete()
+            Session.query(Document).filter_by(name=name).delete()
             return True
         except:
             return False

+ 14 - 22
backend/apps/webui/models/files.py

@@ -4,9 +4,8 @@ import time
 import logging
 
 from sqlalchemy import Column, String, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import JSONField, Base, get_session
+from apps.webui.internal.db import JSONField, Base, Session
 
 import json
 
@@ -71,45 +70,38 @@ class FilesTable:
         )
 
         try:
-            with get_session() as db:
-                result = File(**file.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return FileModel.model_validate(result)
-                else:
-                    return None
+            result = File(**file.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return FileModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
     def get_file_by_id(self, id: str) -> Optional[FileModel]:
         try:
-            with get_session() as db:
-                file = db.get(File, id)
-                return FileModel.model_validate(file)
+            file = Session.get(File, id)
+            return FileModel.model_validate(file)
         except:
             return None
 
     def get_files(self) -> List[FileModel]:
-        with get_session() as db:
-            return [FileModel.model_validate(file) for file in db.query(File).all()]
+        return [FileModel.model_validate(file) for file in Session.query(File).all()]
 
     def delete_file_by_id(self, id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(File).filter_by(id=id).delete()
-                db.commit()
+            Session.query(File).filter_by(id=id).delete()
             return True
         except:
             return False
 
     def delete_all_files(self) -> bool:
         try:
-            with get_session() as db:
-                db.query(File).delete()
-                db.commit()
+            Session.query(File).delete()
             return True
         except:
             return False

+ 53 - 64
backend/apps/webui/models/functions.py

@@ -4,9 +4,8 @@ import time
 import logging
 
 from sqlalchemy import Column, String, Text, BigInteger, Boolean
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import JSONField, Base, get_session
+from apps.webui.internal.db import JSONField, Base, Session
 from apps.webui.models.users import Users
 
 import json
@@ -100,64 +99,57 @@ class FunctionsTable:
         )
 
         try:
-            with get_session() as db:
-                result = Function(**function.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return FunctionModel.model_validate(result)
-                else:
-                    return None
+            result = Function(**function.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return FunctionModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
         try:
-            with get_session() as db:
-                function = db.get(Function, id)
-                return FunctionModel.model_validate(function)
+            function = Session.get(Function, id)
+            return FunctionModel.model_validate(function)
         except:
             return None
 
     def get_functions(self, active_only=False) -> List[FunctionModel]:
         if active_only:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function).filter_by(is_active=True).all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function).filter_by(is_active=True).all()
+            ]
         else:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function).all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function).all()
+            ]
 
     def get_functions_by_type(
         self, type: str, active_only=False
     ) -> List[FunctionModel]:
         if active_only:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function)
-                    .filter_by(type=type, is_active=True)
-                    .all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function)
+                .filter_by(type=type, is_active=True)
+                .all()
+            ]
         else:
-            with get_session() as db:
-                return [
-                    FunctionModel.model_validate(function)
-                    for function in db.query(Function).filter_by(type=type).all()
-                ]
+            return [
+                FunctionModel.model_validate(function)
+                for function in Session.query(Function).filter_by(type=type).all()
+            ]
 
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         try:
-            with get_session() as db:
-                function = db.get(Function, id)
-                return function.valves if function.valves else {}
+            function = Session.get(Function, id)
+            return function.valves if function.valves else {}
         except Exception as e:
             print(f"An error occurred: {e}")
             return None
@@ -166,12 +158,12 @@ class FunctionsTable:
         self, id: str, valves: dict
     ) -> Optional[FunctionValves]:
         try:
-            with get_session() as db:
-                db.query(Function).filter_by(id=id).update(
-                    {"valves": valves, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_function_by_id(id)
+            function = Session.get(Function, id)
+            function.valves = valves
+            function.updated_at = int(time.time())
+            Session.commit()
+            Session.refresh(function)
+            return self.get_function_by_id(id)
         except:
             return None
 
@@ -219,36 +211,33 @@ class FunctionsTable:
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
-            with get_session() as db:
-                db.query(Function).filter_by(id=id).update(
-                    {
-                        **updated,
-                        "updated_at": int(time.time()),
-                    }
-                )
-                db.commit()
-                return self.get_function_by_id(id)
+            Session.query(Function).filter_by(id=id).update(
+                {
+                    **updated,
+                    "updated_at": int(time.time()),
+                }
+            )
+            Session.commit()
+            return self.get_function_by_id(id)
         except:
             return None
 
     def deactivate_all_functions(self) -> Optional[bool]:
         try:
-            with get_session() as db:
-                db.query(Function).update(
-                    {
-                        "is_active": False,
-                        "updated_at": int(time.time()),
-                    }
-                )
-                db.commit()
+            Session.query(Function).update(
+                {
+                    "is_active": False,
+                    "updated_at": int(time.time()),
+                }
+            )
+            Session.commit()
             return True
         except:
             return None
 
     def delete_function_by_id(self, id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Function).filter_by(id=id).delete()
+            Session.query(Function).filter_by(id=id).delete()
             return True
         except:
             return False

+ 25 - 35
backend/apps/webui/models/memories.py

@@ -2,10 +2,8 @@ from pydantic import BaseModel, ConfigDict
 from typing import List, Union, Optional
 
 from sqlalchemy import Column, String, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, get_session
-from apps.webui.models.chats import Chats
+from apps.webui.internal.db import Base, Session
 
 import time
 import uuid
@@ -58,15 +56,14 @@ class MemoriesTable:
                 "updated_at": int(time.time()),
             }
         )
-        with get_session() as db:
-            result = Memory(**memory.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return MemoryModel.model_validate(result)
-            else:
-                return None
+        result = Memory(**memory.model_dump())
+        Session.add(result)
+        Session.commit()
+        Session.refresh(result)
+        if result:
+            return MemoryModel.model_validate(result)
+        else:
+            return None
 
     def update_memory_by_id(
         self,
@@ -74,62 +71,55 @@ class MemoriesTable:
         content: str,
     ) -> Optional[MemoryModel]:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(id=id).update(
-                    {"content": content, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_memory_by_id(id)
+            Session.query(Memory).filter_by(id=id).update(
+                {"content": content, "updated_at": int(time.time())}
+            )
+            Session.commit()
+            return self.get_memory_by_id(id)
         except:
             return None
 
     def get_memories(self) -> List[MemoryModel]:
         try:
-            with get_session() as db:
-                memories = db.query(Memory).all()
-                return [MemoryModel.model_validate(memory) for memory in memories]
+            memories = Session.query(Memory).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
             return None
 
     def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
         try:
-            with get_session() as db:
-                memories = db.query(Memory).filter_by(user_id=user_id).all()
-                return [MemoryModel.model_validate(memory) for memory in memories]
+            memories = Session.query(Memory).filter_by(user_id=user_id).all()
+            return [MemoryModel.model_validate(memory) for memory in memories]
         except:
             return None
 
     def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
         try:
-            with get_session() as db:
-                memory = db.get(Memory, id)
-                return MemoryModel.model_validate(memory)
+            memory = Session.get(Memory, id)
+            return MemoryModel.model_validate(memory)
         except:
             return None
 
     def delete_memory_by_id(self, id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(id=id).delete()
+            Session.query(Memory).filter_by(id=id).delete()
             return True
 
         except:
             return False
 
-    def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
+    def delete_memories_by_user_id(self, user_id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(user_id=user_id).delete()
+            Session.query(Memory).filter_by(user_id=user_id).delete()
             return True
         except:
             return False
 
     def delete_memory_by_id_and_user_id(
-        self, db: Session, id: str, user_id: str
+        self, id: str, user_id: str
     ) -> bool:
         try:
-            with get_session() as db:
-                db.query(Memory).filter_by(id=id, user_id=user_id).delete()
+            Session.query(Memory).filter_by(id=id, user_id=user_id).delete()
             return True
         except:
             return False

+ 19 - 25
backend/apps/webui/models/models.py

@@ -4,9 +4,8 @@ from typing import Optional
 
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, JSONField, get_session
+from apps.webui.internal.db import Base, JSONField, Session
 
 from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
@@ -127,41 +126,37 @@ class ModelsTable:
             }
         )
         try:
-            with get_session() as db:
-                result = Model(**model.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-
-                if result:
-                    return ModelModel.model_validate(result)
-                else:
-                    return None
+            result = Model(**model.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+
+            if result:
+                return ModelModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
             print(e)
             return None
 
     def get_all_models(self) -> List[ModelModel]:
-        with get_session() as db:
-            return [ModelModel.model_validate(model) for model in db.query(Model).all()]
+        return [ModelModel.model_validate(model) for model in Session.query(Model).all()]
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
-            with get_session() as db:
-                model = db.get(Model, id)
-                return ModelModel.model_validate(model)
+            model = Session.get(Model, id)
+            return ModelModel.model_validate(model)
         except:
             return None
 
     def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
         try:
             # update only the fields that are present in the model
-            with get_session() as db:
-                model = db.query(Model).get(id)
-                model.update(**model.model_dump())
-                db.commit()
-                db.refresh(model)
-                return ModelModel.model_validate(model)
+            model = Session.query(Model).get(id)
+            model.update(**model.model_dump())
+            Session.commit()
+            Session.refresh(model)
+            return ModelModel.model_validate(model)
         except Exception as e:
             print(e)
 
@@ -169,8 +164,7 @@ class ModelsTable:
 
     def delete_model_by_id(self, id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Model).filter_by(id=id).delete()
+            Session.query(Model).filter_by(id=id).delete()
             return True
         except:
             return False

+ 43 - 50
backend/apps/webui/models/prompts.py

@@ -3,9 +3,8 @@ from typing import List, Optional
 import time
 
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 import json
 
@@ -50,65 +49,59 @@ class PromptsTable:
     def insert_new_prompt(
         self, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
-        with get_session() as db:
-            prompt = PromptModel(
-                **{
-                    "user_id": user_id,
-                    "command": form_data.command,
-                    "title": form_data.title,
-                    "content": form_data.content,
-                    "timestamp": int(time.time()),
-                }
-            )
-
-            try:
-                result = Prompt(**prompt.dict())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return PromptModel.model_validate(result)
-                else:
-                    return None
-            except Exception as e:
+        prompt = PromptModel(
+            **{
+                "user_id": user_id,
+                "command": form_data.command,
+                "title": form_data.title,
+                "content": form_data.content,
+                "timestamp": int(time.time()),
+            }
+        )
+
+        try:
+            result = Prompt(**prompt.dict())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return PromptModel.model_validate(result)
+            else:
                 return None
+        except Exception as e:
+            return None
 
     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
-        with get_session() as db:
-            try:
-                prompt = db.query(Prompt).filter_by(command=command).first()
-                return PromptModel.model_validate(prompt)
-            except:
-                return None
+        try:
+            prompt = Session.query(Prompt).filter_by(command=command).first()
+            return PromptModel.model_validate(prompt)
+        except:
+            return None
 
     def get_prompts(self) -> List[PromptModel]:
-        with get_session() as db:
-            return [
-                PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
-            ]
+        return [
+            PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all()
+        ]
 
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
-        with get_session() as db:
-            try:
-                prompt = db.query(Prompt).filter_by(command=command).first()
-                prompt.title = form_data.title
-                prompt.content = form_data.content
-                prompt.timestamp = int(time.time())
-                db.commit()
-                return prompt
-                # return self.get_prompt_by_command(command)
-            except:
-                return None
+        try:
+            prompt = Session.query(Prompt).filter_by(command=command).first()
+            prompt.title = form_data.title
+            prompt.content = form_data.content
+            prompt.timestamp = int(time.time())
+            Session.commit()
+            return PromptModel.model_validate(prompt)
+        except:
+            return None
 
     def delete_prompt_by_command(self, command: str) -> bool:
-        with get_session() as db:
-            try:
-                db.query(Prompt).filter_by(command=command).delete()
-                return True
-            except:
-                return False
+        try:
+            Session.query(Prompt).filter_by(command=command).delete()
+            return True
+        except:
+            return False
 
 
 Prompts = PromptsTable()

+ 99 - 109
backend/apps/webui/models/tags.py

@@ -7,9 +7,8 @@ import time
 import logging
 
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, get_session
+from apps.webui.internal.db import Base, Session
 
 from config import SRC_LOG_LEVELS
 
@@ -83,15 +82,14 @@ class TagTable:
         id = str(uuid.uuid4())
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
-            with get_session() as db:
-                result = Tag(**tag.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return TagModel.model_validate(result)
-                else:
-                    return None
+            result = Tag(**tag.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return TagModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
             return None
 
@@ -99,9 +97,8 @@ class TagTable:
         self, name: str, user_id: str
     ) -> Optional[TagModel]:
         try:
-            with get_session() as db:
-                tag = db.query(Tag).filter(name=name, user_id=user_id).first()
-                return TagModel.model_validate(tag)
+            tag = Session.query(Tag).filter(name=name, user_id=user_id).first()
+            return TagModel.model_validate(tag)
         except Exception as e:
             return None
 
@@ -123,105 +120,99 @@ class TagTable:
             }
         )
         try:
-            with get_session() as db:
-                result = ChatIdTag(**chatIdTag.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return ChatIdTagModel.model_validate(result)
-                else:
-                    return None
+            result = ChatIdTag(**chatIdTag.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return ChatIdTagModel.model_validate(result)
+            else:
+                return None
         except:
             return None
 
     def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
-        with get_session() as db:
-            tag_names = [
-                chat_id_tag.tag_name
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
-            return [
-                TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag)
-                    .filter_by(user_id=user_id)
-                    .filter(Tag.name.in_(tag_names))
-                    .all()
-                )
-            ]
+        tag_names = [
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                Session.query(ChatIdTag)
+                .filter_by(user_id=user_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
+        ]
+
+        return [
+            TagModel.model_validate(tag)
+            for tag in (
+                Session.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
+        ]
 
     def get_tags_by_chat_id_and_user_id(
         self, chat_id: str, user_id: str
     ) -> List[TagModel]:
-        with get_session() as db:
-            tag_names = [
-                chat_id_tag.tag_name
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id, chat_id=chat_id)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
-            return [
-                TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag)
-                    .filter_by(user_id=user_id)
-                    .filter(Tag.name.in_(tag_names))
-                    .all()
-                )
-            ]
+        tag_names = [
+            chat_id_tag.tag_name
+            for chat_id_tag in (
+                Session.query(ChatIdTag)
+                .filter_by(user_id=user_id, chat_id=chat_id)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
+        ]
+
+        return [
+            TagModel.model_validate(tag)
+            for tag in (
+                Session.query(Tag)
+                .filter_by(user_id=user_id)
+                .filter(Tag.name.in_(tag_names))
+                .all()
+            )
+        ]
 
     def get_chat_ids_by_tag_name_and_user_id(
         self, tag_name: str, user_id: str
     ) -> List[ChatIdTagModel]:
-        with get_session() as db:
-            return [
-                ChatIdTagModel.model_validate(chat_id_tag)
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id, tag_name=tag_name)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
+        return [
+            ChatIdTagModel.model_validate(chat_id_tag)
+            for chat_id_tag in (
+                Session.query(ChatIdTag)
+                .filter_by(user_id=user_id, tag_name=tag_name)
+                .order_by(ChatIdTag.timestamp.desc())
+                .all()
+            )
+        ]
 
     def count_chat_ids_by_tag_name_and_user_id(
         self, tag_name: str, user_id: str
     ) -> int:
-        with get_session() as db:
-            return (
-                db.query(ChatIdTag)
-                .filter_by(tag_name=tag_name, user_id=user_id)
-                .count()
-            )
+        return (
+            Session.query(ChatIdTag)
+            .filter_by(tag_name=tag_name, user_id=user_id)
+            .count()
+        )
 
     def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
         try:
-            with get_session() as db:
-                res = (
-                    db.query(ChatIdTag)
-                    .filter_by(tag_name=tag_name, user_id=user_id)
-                    .delete()
-                )
-                log.debug(f"res: {res}")
-                db.commit()
-
-                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                    tag_name, user_id
-                )
-                if tag_count == 0:
-                    # Remove tag item from Tag col as well
-                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+            res = (
+                Session.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, user_id=user_id)
+                .delete()
+            )
+            log.debug(f"res: {res}")
+            Session.commit()
+
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                tag_name, user_id
+            )
+            if tag_count == 0:
+                # Remove tag item from Tag col as well
+                Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
             return True
         except Exception as e:
             log.error(f"delete_tag: {e}")
@@ -231,21 +222,20 @@ class TagTable:
         self, tag_name: str, chat_id: str, user_id: str
     ) -> bool:
         try:
-            with get_session() as db:
-                res = (
-                    db.query(ChatIdTag)
-                    .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
-                    .delete()
-                )
-                log.debug(f"res: {res}")
-                db.commit()
-
-                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                    tag_name, user_id
-                )
-                if tag_count == 0:
-                    # Remove tag item from Tag col as well
-                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+            res = (
+                Session.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
+                .delete()
+            )
+            log.debug(f"res: {res}")
+            Session.commit()
+
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                tag_name, user_id
+            )
+            if tag_count == 0:
+                # Remove tag item from Tag col as well
+                Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
 
             return True
         except Exception as e:

+ 26 - 33
backend/apps/webui/models/tools.py

@@ -3,9 +3,8 @@ from typing import List, Optional
 import time
 import logging
 from sqlalchemy import String, Column, BigInteger
-from sqlalchemy.orm import Session
 
-from apps.webui.internal.db import Base, JSONField, get_session
+from apps.webui.internal.db import Base, JSONField, Session
 from apps.webui.models.users import Users
 
 import json
@@ -95,48 +94,43 @@ class ToolsTable:
         )
 
         try:
-            with get_session() as db:
-                result = Tool(**tool.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return ToolModel.model_validate(result)
-                else:
-                    return None
+            result = Tool(**tool.model_dump())
+            Session.add(result)
+            Session.commit()
+            Session.refresh(result)
+            if result:
+                return ToolModel.model_validate(result)
+            else:
+                return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
         try:
-            with get_session() as db:
-                tool = db.get(Tool, id)
-                return ToolModel.model_validate(tool)
+            tool = Session.get(Tool, id)
+            return ToolModel.model_validate(tool)
         except:
             return None
 
     def get_tools(self) -> List[ToolModel]:
-        with get_session() as db:
-            return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
+        return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()]
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
-            with get_session() as db:
-                tool = db.get(Tool, id)
-                return tool.valves if tool.valves else {}
+            tool = Session.get(Tool, id)
+            return tool.valves if tool.valves else {}
         except Exception as e:
             print(f"An error occurred: {e}")
             return None
 
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
         try:
-            with get_session() as db:
-                db.query(Tool).filter_by(id=id).update(
-                    {"valves": valves, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_tool_by_id(id)
+            Session.query(Tool).filter_by(id=id).update(
+                {"valves": valves, "updated_at": int(time.time())}
+            )
+            Session.commit()
+            return self.get_tool_by_id(id)
         except:
             return None
 
@@ -183,19 +177,18 @@ class ToolsTable:
 
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
-            with get_session() as db:
-                db.query(Tool).filter_by(id=id).update(
-                    {**updated, "updated_at": int(time.time())}
-                )
-                db.commit()
-                return self.get_tool_by_id(id)
+            tool = Session.get(Tool, id)
+            tool.update(**updated)
+            tool.updated_at = int(time.time())
+            Session.commit()
+            Session.refresh(tool)
+            return ToolModel.model_validate(tool)
         except:
             return None
 
     def delete_tool_by_id(self, id: str) -> bool:
         try:
-            with get_session() as db:
-                db.query(Tool).filter_by(id=id).delete()
+            Session.query(Tool).filter_by(id=id).delete()
             return True
         except:
             return False

+ 117 - 134
backend/apps/webui/models/users.py

@@ -3,11 +3,10 @@ from typing import List, Union, Optional
 import time
 
 from sqlalchemy import String, Column, BigInteger, Text
-from sqlalchemy.orm import Session
 
 from utils.misc import get_gravatar_url
 
-from apps.webui.internal.db import Base, JSONField, get_session
+from apps.webui.internal.db import Base, JSONField, Session
 from apps.webui.models.chats import Chats
 
 ####################
@@ -89,177 +88,161 @@ class UsersTable:
         role: str = "pending",
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            user = UserModel(
-                **{
-                    "id": id,
-                    "name": name,
-                    "email": email,
-                    "role": role,
-                    "profile_image_url": profile_image_url,
-                    "last_active_at": int(time.time()),
-                    "created_at": int(time.time()),
-                    "updated_at": int(time.time()),
-                    "oauth_sub": oauth_sub,
-                }
-            )
-            result = User(**user.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return user
-            else:
-                return None
+        user = UserModel(
+            **{
+                "id": id,
+                "name": name,
+                "email": email,
+                "role": role,
+                "profile_image_url": profile_image_url,
+                "last_active_at": int(time.time()),
+                "created_at": int(time.time()),
+                "updated_at": int(time.time()),
+                "oauth_sub": oauth_sub,
+            }
+        )
+        result = User(**user.model_dump())
+        Session.add(result)
+        Session.commit()
+        Session.refresh(result)
+        if result:
+            return user
+        else:
+            return None
 
     def get_user_by_id(self, id: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except Exception as e:
-                return None
+        try:
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except Exception as e:
+            return None
 
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(api_key=api_key).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).filter_by(api_key=api_key).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(email=email).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).filter_by(email=email).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(oauth_sub=sub).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).filter_by(oauth_sub=sub).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
-        with get_session() as db:
-            users = (
-                db.query(User)
-                # .offset(skip).limit(limit)
-                .all()
-            )
-            return [UserModel.model_validate(user) for user in users]
+        users = (
+            Session.query(User)
+            # .offset(skip).limit(limit)
+            .all()
+        )
+        return [UserModel.model_validate(user) for user in users]
 
     def get_num_users(self) -> Optional[int]:
-        with get_session() as db:
-            return db.query(User).count()
+        return Session.query(User).count()
 
     def get_first_user(self) -> UserModel:
-        with get_session() as db:
-            try:
-                user = db.query(User).order_by(User.created_at).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            user = Session.query(User).order_by(User.created_at).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update({"role": role})
-                db.commit()
+        try:
+            Session.query(User).filter_by(id=id).update({"role": role})
+            Session.commit()
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def update_user_profile_image_url_by_id(
         self, id: str, profile_image_url: str
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update(
-                    {"profile_image_url": profile_image_url}
-                )
-                db.commit()
-
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+        try:
+            Session.query(User).filter_by(id=id).update(
+                {"profile_image_url": profile_image_url}
+            )
+            Session.commit()
+
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update(
-                    {"last_active_at": int(time.time())}
-                )
+        try:
+            Session.query(User).filter_by(id=id).update(
+                {"last_active_at": int(time.time())}
+            )
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def update_user_oauth_sub_by_id(
         self, id: str, oauth_sub: str
     ) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
+        try:
+            Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-            except:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+        except:
+            return None
 
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
-        with get_session() as db:
-            try:
-                db.query(User).filter_by(id=id).update(updated)
-                db.commit()
+        try:
+            Session.query(User).filter_by(id=id).update(updated)
+            Session.commit()
 
-                user = db.query(User).filter_by(id=id).first()
-                return UserModel.model_validate(user)
-                # return UserModel(**user.dict())
-            except Exception as e:
-                return None
+            user = Session.query(User).filter_by(id=id).first()
+            return UserModel.model_validate(user)
+            # return UserModel(**user.dict())
+        except Exception as e:
+            return None
 
     def delete_user_by_id(self, id: str) -> bool:
-        with get_session() as db:
-            try:
-                # Delete User Chats
-                result = Chats.delete_chats_by_user_id(id)
-
-                if result:
-                    # Delete User
-                    db.query(User).filter_by(id=id).delete()
-                    db.commit()
-
-                    return True
-                else:
-                    return False
-            except:
+        try:
+            # Delete User Chats
+            result = Chats.delete_chats_by_user_id(id)
+
+            if result:
+                # Delete User
+                Session.query(User).filter_by(id=id).delete()
+                Session.commit()
+
+                return True
+            else:
                 return False
+        except:
+            return False
 
     def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
-        with get_session() as db:
-            try:
-                result = db.query(User).filter_by(id=id).update({"api_key": api_key})
-                db.commit()
-                return True if result == 1 else False
-            except:
-                return False
+        try:
+            result = Session.query(User).filter_by(id=id).update({"api_key": api_key})
+            Session.commit()
+            return True if result == 1 else False
+        except:
+            return False
 
     def get_user_api_key_by_id(self, id: str) -> Optional[str]:
-        with get_session() as db:
-            try:
-                user = db.query(User).filter_by(id=id).first()
-                return user.api_key
-            except Exception as e:
-                return None
+        try:
+            user = Session.query(User).filter_by(id=id).first()
+            return user.api_key
+        except Exception as e:
+            return None
 
 
 Users = UsersTable()

+ 10 - 4
backend/main.py

@@ -29,7 +29,6 @@ from fastapi import HTTPException
 from fastapi.middleware.wsgi import WSGIMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from sqlalchemy import text
-from sqlalchemy.orm import Session
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
@@ -57,7 +56,7 @@ from apps.webui.main import (
     get_pipe_models,
     generate_function_chat_completion,
 )
-from apps.webui.internal.db import get_session, SessionLocal
+from apps.webui.internal.db import Session, SessionLocal
 
 
 from pydantic import BaseModel
@@ -794,6 +793,14 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+@app.middleware("http")
+async def remove_session_after_request(request: Request, call_next):
+    response = await call_next(request)
+    log.debug("Removing session after request")
+    Session.commit()
+    Session.remove()
+    return response
+
 
 @app.middleware("http")
 async def check_url(request: Request, call_next):
@@ -2034,8 +2041,7 @@ async def healthcheck():
 
 @app.get("/health/db")
 async def healthcheck_with_db():
-    with get_session() as db:
-        result = db.execute(text("SELECT 1;")).all()
+    Session.execute(text("SELECT 1;")).all()
     return {"status": True}
 
 

+ 2 - 0
backend/test/apps/webui/routers/test_chats.py

@@ -90,6 +90,8 @@ class TestChats(AbstractPostgresTest):
 
     def test_get_user_archived_chats(self):
         self.chats.archive_all_chats_by_user_id("2")
+        from apps.webui.internal.db import Session
+        Session.commit()
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url("/all/archived"))
         assert response.status_code == 200

+ 9 - 12
backend/test/util/abstract_integration_test.py

@@ -9,6 +9,7 @@ from pytest_docker.plugin import get_docker_ip
 from fastapi.testclient import TestClient
 from sqlalchemy import text, create_engine
 
+
 log = logging.getLogger(__name__)
 
 
@@ -50,11 +51,6 @@ class AbstractPostgresTest(AbstractIntegrationTest):
     DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
     docker_client: DockerClient
 
-    def get_db(self):
-        from apps.webui.internal.db import SessionLocal
-
-        return SessionLocal()
-
     @classmethod
     def _create_db_url(cls, env_vars_postgres: dict) -> str:
         host = get_docker_ip()
@@ -113,21 +109,21 @@ class AbstractPostgresTest(AbstractIntegrationTest):
             pytest.fail(f"Could not setup test environment: {ex}")
 
     def _check_db_connection(self):
+        from apps.webui.internal.db import Session
         retries = 10
         while retries > 0:
             try:
-                self.db_session.execute(text("SELECT 1"))
-                self.db_session.commit()
+                Session.execute(text("SELECT 1"))
+                Session.commit()
                 break
             except Exception as e:
-                self.db_session.rollback()
+                Session.rollback()
                 log.warning(e)
                 time.sleep(3)
                 retries -= 1
 
     def setup_method(self):
         super().setup_method()
-        self.db_session = self.get_db()
         self._check_db_connection()
 
     @classmethod
@@ -136,8 +132,9 @@ class AbstractPostgresTest(AbstractIntegrationTest):
         cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
 
     def teardown_method(self):
+        from apps.webui.internal.db import Session
         # rollback everything not yet committed
-        self.db_session.commit()
+        Session.commit()
 
         # truncate all tables
         tables = [
@@ -152,5 +149,5 @@ class AbstractPostgresTest(AbstractIntegrationTest):
             '"user"',
         ]
         for table in tables:
-            self.db_session.execute(text(f"TRUNCATE TABLE {table}"))
-        self.db_session.commit()
+            Session.execute(text(f"TRUNCATE TABLE {table}"))
+        Session.commit()