Timothy J. Baek hai 10 meses
pai
achega
864646094e

+ 9 - 11
backend/apps/webui/models/auths.py

@@ -7,7 +7,7 @@ from sqlalchemy import String, Column, Boolean, Text
 from apps.webui.models.users import UserModel, Users
 from utils.utils import verify_password
 
-from apps.webui.internal.db import Base, Session
+from apps.webui.internal.db import Base, get_db
 
 from config import SRC_LOG_LEVELS
 
@@ -110,14 +110,14 @@ class AuthsTable:
             **{"id": id, "email": email, "password": password, "active": True}
         )
         result = Auth(**auth.model_dump())
-        Session.add(result)
+        db.add(result)
 
         user = Users.insert_new_user(
             id, name, email, profile_image_url, role, oauth_sub
         )
 
-        Session.commit()
-        Session.refresh(result)
+        db.commit()
+        db.refresh(result)
 
         if result and user:
             return user
@@ -127,7 +127,7 @@ class AuthsTable:
     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
         try:
-            auth = Session.query(Auth).filter_by(email=email, active=True).first()
+            auth = db.query(Auth).filter_by(email=email, active=True).first()
             if auth:
                 if verify_password(password, auth.password):
                     user = Users.get_user_by_id(auth.id)
@@ -154,7 +154,7 @@ class AuthsTable:
     def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
         try:
-            auth = Session.query(Auth).filter(email=email, active=True).first()
+            auth = db.query(Auth).filter(email=email, active=True).first()
             if auth:
                 user = Users.get_user_by_id(auth.id)
                 return user
@@ -163,16 +163,14 @@ class AuthsTable:
 
     def update_user_password_by_id(self, id: str, new_password: str) -> bool:
         try:
-            result = (
-                Session.query(Auth).filter_by(id=id).update({"password": new_password})
-            )
+            result = db.query(Auth).filter_by(id=id).update({"password": new_password})
             return True if result == 1 else False
         except:
             return False
 
     def update_email_by_id(self, id: str, email: str) -> bool:
         try:
-            result = Session.query(Auth).filter_by(id=id).update({"email": email})
+            result = db.query(Auth).filter_by(id=id).update({"email": email})
             return True if result == 1 else False
         except:
             return False
@@ -183,7 +181,7 @@ class AuthsTable:
             result = Users.delete_user_by_id(id)
 
             if result:
-                Session.query(Auth).filter_by(id=id).delete()
+                db.query(Auth).filter_by(id=id).delete()
 
                 return True
             else:

+ 188 - 142
backend/apps/webui/models/chats.py

@@ -7,7 +7,7 @@ import time
 
 from sqlalchemy import Column, String, BigInteger, Boolean, Text
 
-from apps.webui.internal.db import Base, Session
+from apps.webui.internal.db import Base, get_db
 
 
 ####################
@@ -79,87 +79,99 @@ class ChatTitleIdResponse(BaseModel):
 class ChatTable:
 
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
-        id = str(uuid.uuid4())
-        chat = ChatModel(
-            **{
-                "id": id,
-                "user_id": user_id,
-                "title": (
-                    form_data.chat["title"] if "title" in form_data.chat else "New Chat"
-                ),
-                "chat": json.dumps(form_data.chat),
-                "created_at": int(time.time()),
-                "updated_at": int(time.time()),
-            }
-        )
-
-        result = Chat(**chat.model_dump())
-        Session.add(result)
-        Session.commit()
-        Session.refresh(result)
-        return ChatModel.model_validate(result) if result else None
+        with get_db() as db:
+
+            id = str(uuid.uuid4())
+            chat = ChatModel(
+                **{
+                    "id": id,
+                    "user_id": user_id,
+                    "title": (
+                        form_data.chat["title"]
+                        if "title" in form_data.chat
+                        else "New Chat"
+                    ),
+                    "chat": json.dumps(form_data.chat),
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+
+            result = Chat(**chat.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            return ChatModel.model_validate(result) if result else None
 
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
         try:
-            chat_obj = Session.get(Chat, id)
-            chat_obj.chat = json.dumps(chat)
-            chat_obj.title = chat["title"] if "title" in chat else "New Chat"
-            chat_obj.updated_at = int(time.time())
-            Session.commit()
-            Session.refresh(chat_obj)
-
-            return ChatModel.model_validate(chat_obj)
+            with get_db() as db:
+
+                chat_obj = db.get(Chat, id)
+                chat_obj.chat = json.dumps(chat)
+                chat_obj.title = chat["title"] if "title" in chat else "New Chat"
+                chat_obj.updated_at = int(time.time())
+                db.commit()
+                db.refresh(chat_obj)
+
+                return ChatModel.model_validate(chat_obj)
         except Exception as e:
             return None
 
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
-        # Get the existing chat to share
-        chat = Session.get(Chat, chat_id)
-        # Check if the chat is already shared
-        if chat.share_id:
-            return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
-        # Create a new chat with the same data, but with a new ID
-        shared_chat = ChatModel(
-            **{
-                "id": str(uuid.uuid4()),
-                "user_id": f"shared-{chat_id}",
-                "title": chat.title,
-                "chat": chat.chat,
-                "created_at": chat.created_at,
-                "updated_at": int(time.time()),
-            }
-        )
-        shared_result = Chat(**shared_chat.model_dump())
-        Session.add(shared_result)
-        Session.commit()
-        Session.refresh(shared_result)
-        # Update the original chat with the share_id
-        result = (
-            Session.query(Chat)
-            .filter_by(id=chat_id)
-            .update({"share_id": shared_chat.id})
-        )
-
-        return shared_chat if (shared_result and result) else None
+        with get_db() as db:
+
+            # Get the existing chat to share
+            chat = db.get(Chat, chat_id)
+            # Check if the chat is already shared
+            if chat.share_id:
+                return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+            # Create a new chat with the same data, but with a new ID
+            shared_chat = ChatModel(
+                **{
+                    "id": str(uuid.uuid4()),
+                    "user_id": f"shared-{chat_id}",
+                    "title": chat.title,
+                    "chat": chat.chat,
+                    "created_at": chat.created_at,
+                    "updated_at": int(time.time()),
+                }
+            )
+            shared_result = Chat(**shared_chat.model_dump())
+            db.add(shared_result)
+            db.commit()
+            db.refresh(shared_result)
+            # Update the original chat with the share_id
+            result = (
+                db.query(Chat)
+                .filter_by(id=chat_id)
+                .update({"share_id": shared_chat.id})
+            )
+
+            return shared_chat if (shared_result and result) else None
 
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         try:
-            print("update_shared_chat_by_id")
-            chat = Session.get(Chat, chat_id)
-            print(chat)
-            chat.title = chat.title
-            chat.chat = chat.chat
-            Session.commit()
-            Session.refresh(chat)
-
-            return self.get_chat_by_id(chat.share_id)
+            with get_db() as db:
+
+                print("update_shared_chat_by_id")
+                chat = db.get(Chat, chat_id)
+                print(chat)
+                chat.title = chat.title
+                chat.chat = chat.chat
+                db.commit()
+                db.refresh(chat)
+
+                return self.get_chat_by_id(chat.share_id)
         except:
             return None
 
     def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
         try:
-            Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
-            return True
+            with get_db() as db:
+
+                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
+                return True
         except:
             return False
 
@@ -167,42 +179,50 @@ class ChatTable:
         self, id: str, share_id: Optional[str]
     ) -> Optional[ChatModel]:
         try:
-            chat = Session.get(Chat, id)
-            chat.share_id = share_id
-            Session.commit()
-            Session.refresh(chat)
-            return ChatModel.model_validate(chat)
+            with get_db() as db:
+
+                chat = db.get(Chat, id)
+                chat.share_id = share_id
+                db.commit()
+                db.refresh(chat)
+                return ChatModel.model_validate(chat)
         except:
             return None
 
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
-            chat = Session.get(Chat, id)
-            chat.archived = not chat.archived
-            Session.commit()
-            Session.refresh(chat)
-            return ChatModel.model_validate(chat)
+            with get_db() as db:
+
+                chat = db.get(Chat, id)
+                chat.archived = not chat.archived
+                db.commit()
+                db.refresh(chat)
+                return ChatModel.model_validate(chat)
         except:
             return None
 
     def archive_all_chats_by_user_id(self, user_id: str) -> bool:
         try:
-            Session.query(Chat).filter_by(user_id=user_id).update({"archived": True})
-            return True
+            with get_db() as db:
+
+                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
+                return True
         except:
             return False
 
     def get_archived_chat_list_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-        all_chats = (
-            Session.query(Chat)
-            .filter_by(user_id=user_id, archived=True)
-            .order_by(Chat.updated_at.desc())
-            # .limit(limit).offset(skip)
-            .all()
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_db() as db:
+
+            all_chats = (
+                db.query(Chat)
+                .filter_by(user_id=user_id, archived=True)
+                .order_by(Chat.updated_at.desc())
+                # .limit(limit).offset(skip)
+                .all()
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_list_by_user_id(
         self,
@@ -211,110 +231,136 @@ class ChatTable:
         skip: int = 0,
         limit: int = 50,
     ) -> List[ChatModel]:
-        query = Session.query(Chat).filter_by(user_id=user_id)
-        if not include_archived:
-            query = query.filter_by(archived=False)
-        all_chats = (
-            query.order_by(Chat.updated_at.desc())
-            # .limit(limit).offset(skip)
-            .all()
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_db() as db:
+            query = db.query(Chat).filter_by(user_id=user_id)
+            if not include_archived:
+                query = query.filter_by(archived=False)
+            all_chats = (
+                query.order_by(Chat.updated_at.desc())
+                # .limit(limit).offset(skip)
+                .all()
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_list_by_chat_ids(
         self, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-        all_chats = (
-            Session.query(Chat)
-            .filter(Chat.id.in_(chat_ids))
-            .filter_by(archived=False)
-            .order_by(Chat.updated_at.desc())
-            .all()
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+
+        with get_db() as db:
+
+            all_chats = (
+                db.query(Chat)
+                .filter(Chat.id.in_(chat_ids))
+                .filter_by(archived=False)
+                .order_by(Chat.updated_at.desc())
+                .all()
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
-            chat = Session.get(Chat, id)
-            return ChatModel.model_validate(chat)
+            with get_db() as db:
+
+                chat = db.get(Chat, id)
+                return ChatModel.model_validate(chat)
         except:
             return None
 
     def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
         try:
-            chat = Session.query(Chat).filter_by(share_id=id).first()
+            with get_db() as db:
 
-            if chat:
-                return self.get_chat_by_id(id)
-            else:
-                return None
+                chat = db.query(Chat).filter_by(share_id=id).first()
+
+                if chat:
+                    return self.get_chat_by_id(id)
+                else:
+                    return None
         except Exception as e:
             return None
 
     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
-            chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first()
-            return ChatModel.model_validate(chat)
+            with get_db() as db:
+
+                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
+                return ChatModel.model_validate(chat)
         except:
             return None
 
     def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
-        all_chats = (
-            Session.query(Chat)
-            # .limit(limit).offset(skip)
-            .order_by(Chat.updated_at.desc())
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_db() as db:
+
+            all_chats = (
+                db.query(Chat)
+                # .limit(limit).offset(skip)
+                .order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        all_chats = (
-            Session.query(Chat)
-            .filter_by(user_id=user_id)
-            .order_by(Chat.updated_at.desc())
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_db() as db:
+
+            all_chats = (
+                db.query(Chat)
+                .filter_by(user_id=user_id)
+                .order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
-        all_chats = (
-            Session.query(Chat)
-            .filter_by(user_id=user_id, archived=True)
-            .order_by(Chat.updated_at.desc())
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_db() as db:
+
+            all_chats = (
+                db.query(Chat)
+                .filter_by(user_id=user_id, archived=True)
+                .order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def delete_chat_by_id(self, id: str) -> bool:
         try:
-            Session.query(Chat).filter_by(id=id).delete()
+            with get_db() as db:
+
+                db.query(Chat).filter_by(id=id).delete()
 
-            return True and self.delete_shared_chat_by_chat_id(id)
+                return True and self.delete_shared_chat_by_chat_id(id)
         except:
             return False
 
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
-            Session.query(Chat).filter_by(id=id, user_id=user_id).delete()
+            with get_db() as db:
 
-            return True and self.delete_shared_chat_by_chat_id(id)
+                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
+
+                return True and self.delete_shared_chat_by_chat_id(id)
         except:
             return False
 
     def delete_chats_by_user_id(self, user_id: str) -> bool:
         try:
-            self.delete_shared_chats_by_user_id(user_id)
 
-            Session.query(Chat).filter_by(user_id=user_id).delete()
-            return True
+            with get_db() as db:
+
+                self.delete_shared_chats_by_user_id(user_id)
+
+                db.query(Chat).filter_by(user_id=user_id).delete()
+                return True
         except:
             return False
 
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
-            chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all()
-            shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 
-            Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
+            with get_db() as db:
+
+                chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
+                shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
+
+                db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
 
-            return True
+                return True
         except:
             return False
 

+ 54 - 42
backend/apps/webui/models/documents.py

@@ -5,7 +5,7 @@ import logging
 
 from sqlalchemy import String, Column, BigInteger, Text
 
-from apps.webui.internal.db import Base, Session
+from apps.webui.internal.db import Base, get_db
 
 import json
 
@@ -74,51 +74,59 @@ class DocumentsTable:
     def insert_new_doc(
         self, user_id: str, form_data: DocumentForm
     ) -> Optional[DocumentModel]:
-        document = DocumentModel(
-            **{
-                **form_data.model_dump(),
-                "user_id": user_id,
-                "timestamp": int(time.time()),
-            }
-        )
+        with get_db() as db:
 
-        try:
-            result = Document(**document.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return DocumentModel.model_validate(result)
-            else:
+            document = DocumentModel(
+                **{
+                    **form_data.model_dump(),
+                    "user_id": user_id,
+                    "timestamp": int(time.time()),
+                }
+            )
+
+            try:
+                result = Document(**document.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return DocumentModel.model_validate(result)
+                else:
+                    return None
+            except:
                 return None
-        except:
-            return None
 
     def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
         try:
-            document = Session.query(Document).filter_by(name=name).first()
-            return DocumentModel.model_validate(document) if document else None
+            with get_db() as db:
+
+                document = db.query(Document).filter_by(name=name).first()
+                return DocumentModel.model_validate(document) if document else None
         except:
             return None
 
     def get_docs(self) -> List[DocumentModel]:
-        return [
-            DocumentModel.model_validate(doc) for doc in Session.query(Document).all()
-        ]
+        with get_db() as db:
+
+            return [
+                DocumentModel.model_validate(doc) for doc in db.query(Document).all()
+            ]
 
     def update_doc_by_name(
         self, name: str, form_data: DocumentUpdateForm
     ) -> Optional[DocumentModel]:
         try:
-            Session.query(Document).filter_by(name=name).update(
-                {
-                    "title": form_data.title,
-                    "name": form_data.name,
-                    "timestamp": int(time.time()),
-                }
-            )
-            Session.commit()
-            return self.get_doc_by_name(form_data.name)
+            with get_db() as db:
+
+                db.query(Document).filter_by(name=name).update(
+                    {
+                        "title": form_data.title,
+                        "name": form_data.name,
+                        "timestamp": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_doc_by_name(form_data.name)
         except Exception as e:
             log.exception(e)
             return None
@@ -131,22 +139,26 @@ class DocumentsTable:
             doc_content = json.loads(doc.content if doc.content else "{}")
             doc_content = {**doc_content, **updated}
 
-            Session.query(Document).filter_by(name=name).update(
-                {
-                    "content": json.dumps(doc_content),
-                    "timestamp": int(time.time()),
-                }
-            )
-            Session.commit()
-            return self.get_doc_by_name(name)
+            with get_db() as db:
+
+                db.query(Document).filter_by(name=name).update(
+                    {
+                        "content": json.dumps(doc_content),
+                        "timestamp": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_doc_by_name(name)
         except Exception as e:
             log.exception(e)
             return None
 
     def delete_doc_by_name(self, name: str) -> bool:
         try:
-            Session.query(Document).filter_by(name=name).delete()
-            return True
+            with get_db() as db:
+
+                db.query(Document).filter_by(name=name).delete()
+                return True
         except:
             return False
 

+ 48 - 36
backend/apps/webui/models/files.py

@@ -5,7 +5,7 @@ import logging
 
 from sqlalchemy import Column, String, BigInteger, Text
 
-from apps.webui.internal.db import JSONField, Base, Session
+from apps.webui.internal.db import JSONField, Base, get_db
 
 import json
 
@@ -61,50 +61,62 @@ class FileForm(BaseModel):
 class FilesTable:
 
     def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
-        file = FileModel(
-            **{
-                **form_data.model_dump(),
-                "user_id": user_id,
-                "created_at": int(time.time()),
-            }
-        )
-
-        try:
-            result = File(**file.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return FileModel.model_validate(result)
-            else:
+        with get_db() as db:
+
+            file = FileModel(
+                **{
+                    **form_data.model_dump(),
+                    "user_id": user_id,
+                    "created_at": int(time.time()),
+                }
+            )
+
+            try:
+                result = File(**file.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return FileModel.model_validate(result)
+                else:
+                    return None
+            except Exception as e:
+                print(f"Error creating tool: {e}")
                 return None
-        except Exception as e:
-            print(f"Error creating tool: {e}")
-            return None
 
     def get_file_by_id(self, id: str) -> Optional[FileModel]:
-        try:
-            file = Session.get(File, id)
-            return FileModel.model_validate(file)
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                file = db.get(File, id)
+                return FileModel.model_validate(file)
+            except:
+                return None
 
     def get_files(self) -> List[FileModel]:
-        return [FileModel.model_validate(file) for file in Session.query(File).all()]
+        with get_db() as db:
+
+            return [FileModel.model_validate(file) for file in db.query(File).all()]
 
     def delete_file_by_id(self, id: str) -> bool:
-        try:
-            Session.query(File).filter_by(id=id).delete()
-            return True
-        except:
-            return False
+
+        with get_db() as db:
+
+            try:
+                db.query(File).filter_by(id=id).delete()
+                return True
+            except:
+                return False
 
     def delete_all_files(self) -> bool:
-        try:
-            Session.query(File).delete()
-            return True
-        except:
-            return False
+
+        with get_db() as db:
+
+            try:
+                db.query(File).delete()
+                return True
+            except:
+                return False
 
 
 Files = FilesTable()

+ 101 - 79
backend/apps/webui/models/functions.py

@@ -5,7 +5,7 @@ import logging
 
 from sqlalchemy import Column, String, Text, BigInteger, Boolean
 
-from apps.webui.internal.db import JSONField, Base, Session
+from apps.webui.internal.db import JSONField, Base, get_db
 from apps.webui.models.users import Users
 
 import json
@@ -91,6 +91,7 @@ class FunctionsTable:
     def insert_new_function(
         self, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
+
         function = FunctionModel(
             **{
                 **form_data.model_dump(),
@@ -102,85 +103,99 @@ class FunctionsTable:
         )
 
         try:
-            result = Function(**function.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return FunctionModel.model_validate(result)
-            else:
-                return None
+            with get_db() as db:
+                result = Function(**function.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return FunctionModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
             print(f"Error creating tool: {e}")
             return None
 
     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
         try:
-            function = Session.get(Function, id)
-            return FunctionModel.model_validate(function)
+            with get_db() as db:
+
+                function = db.get(Function, id)
+                return FunctionModel.model_validate(function)
         except:
             return None
 
     def get_functions(self, active_only=False) -> List[FunctionModel]:
-        if active_only:
-            return [
-                FunctionModel.model_validate(function)
-                for function in Session.query(Function).filter_by(is_active=True).all()
-            ]
-        else:
-            return [
-                FunctionModel.model_validate(function)
-                for function in Session.query(Function).all()
-            ]
+        with get_db() as db:
+
+            if active_only:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).filter_by(is_active=True).all()
+                ]
+            else:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).all()
+                ]
 
     def get_functions_by_type(
         self, type: str, active_only=False
     ) -> List[FunctionModel]:
-        if active_only:
+        with get_db() as db:
+
+            if active_only:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function)
+                    .filter_by(type=type, is_active=True)
+                    .all()
+                ]
+            else:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).filter_by(type=type).all()
+                ]
+
+    def get_global_filter_functions(self) -> List[FunctionModel]:
+        with get_db() as db:
+
             return [
                 FunctionModel.model_validate(function)
-                for function in Session.query(Function)
-                .filter_by(type=type, is_active=True)
+                for function in db.query(Function)
+                .filter_by(type="filter", is_active=True, is_global=True)
                 .all()
             ]
-        else:
-            return [
-                FunctionModel.model_validate(function)
-                for function in Session.query(Function).filter_by(type=type).all()
-            ]
-
-    def get_global_filter_functions(self) -> List[FunctionModel]:
-        return [
-            FunctionModel.model_validate(function)
-            for function in Session.query(Function)
-            .filter_by(type="filter", is_active=True, is_global=True)
-            .all()
-        ]
 
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
-        try:
-            function = Session.get(Function, id)
-            return function.valves if function.valves else {}
-        except Exception as e:
-            print(f"An error occurred: {e}")
-            return None
+        with get_db() as db:
+
+            try:
+                function = db.get(Function, id)
+                return function.valves if function.valves else {}
+            except Exception as e:
+                print(f"An error occurred: {e}")
+                return None
 
     def update_function_valves_by_id(
         self, id: str, valves: dict
     ) -> Optional[FunctionValves]:
-        try:
-            function = Session.get(Function, id)
-            function.valves = valves
-            function.updated_at = int(time.time())
-            Session.commit()
-            Session.refresh(function)
-            return self.get_function_by_id(id)
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                function = db.get(Function, id)
+                function.valves = valves
+                function.updated_at = int(time.time())
+                db.commit()
+                db.refresh(function)
+                return self.get_function_by_id(id)
+            except:
+                return None
 
     def get_user_valves_by_id_and_user_id(
         self, id: str, user_id: str
     ) -> Optional[dict]:
+
         try:
             user = Users.get_user_by_id(user_id)
             user_settings = user.settings.model_dump()
@@ -199,6 +214,7 @@ class FunctionsTable:
     def update_user_valves_by_id_and_user_id(
         self, id: str, user_id: str, valves: dict
     ) -> Optional[dict]:
+
         try:
             user = Users.get_user_by_id(user_id)
             user_settings = user.settings.model_dump()
@@ -220,37 +236,43 @@ class FunctionsTable:
             return None
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
-        try:
-            Session.query(Function).filter_by(id=id).update(
-                {
-                    **updated,
-                    "updated_at": int(time.time()),
-                }
-            )
-            Session.commit()
-            return self.get_function_by_id(id)
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                db.query(Function).filter_by(id=id).update(
+                    {
+                        **updated,
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_function_by_id(id)
+            except:
+                return None
 
     def deactivate_all_functions(self) -> Optional[bool]:
-        try:
-            Session.query(Function).update(
-                {
-                    "is_active": False,
-                    "updated_at": int(time.time()),
-                }
-            )
-            Session.commit()
-            return True
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                db.query(Function).update(
+                    {
+                        "is_active": False,
+                        "updated_at": int(time.time()),
+                    }
+                )
+                db.commit()
+                return True
+            except:
+                return None
 
     def delete_function_by_id(self, id: str) -> bool:
-        try:
-            Session.query(Function).filter_by(id=id).delete()
-            return True
-        except:
-            return False
+        with get_db() as db:
+
+            try:
+                db.query(Function).filter_by(id=id).delete()
+                return True
+            except:
+                return False
 
 
 Functions = FunctionsTable()

+ 74 - 58
backend/apps/webui/models/memories.py

@@ -3,7 +3,7 @@ from typing import List, Union, Optional
 
 from sqlalchemy import Column, String, BigInteger, Text
 
-from apps.webui.internal.db import Base, Session
+from apps.webui.internal.db import Base, get_db
 
 import time
 import uuid
@@ -45,82 +45,98 @@ class MemoriesTable:
         user_id: str,
         content: str,
     ) -> Optional[MemoryModel]:
-        id = str(uuid.uuid4())
-
-        memory = MemoryModel(
-            **{
-                "id": id,
-                "user_id": user_id,
-                "content": content,
-                "created_at": int(time.time()),
-                "updated_at": int(time.time()),
-            }
-        )
-        result = Memory(**memory.model_dump())
-        Session.add(result)
-        Session.commit()
-        Session.refresh(result)
-        if result:
-            return MemoryModel.model_validate(result)
-        else:
-            return None
+
+        with get_db() as db:
+            id = str(uuid.uuid4())
+
+            memory = MemoryModel(
+                **{
+                    "id": id,
+                    "user_id": user_id,
+                    "content": content,
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                }
+            )
+            result = Memory(**memory.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            if result:
+                return MemoryModel.model_validate(result)
+            else:
+                return None
 
     def update_memory_by_id(
         self,
         id: str,
         content: str,
     ) -> Optional[MemoryModel]:
-        try:
-            Session.query(Memory).filter_by(id=id).update(
-                {"content": content, "updated_at": int(time.time())}
-            )
-            Session.commit()
-            return self.get_memory_by_id(id)
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                db.query(Memory).filter_by(id=id).update(
+                    {"content": content, "updated_at": int(time.time())}
+                )
+                db.commit()
+                return self.get_memory_by_id(id)
+            except:
+                return None
 
     def get_memories(self) -> List[MemoryModel]:
-        try:
-            memories = Session.query(Memory).all()
-            return [MemoryModel.model_validate(memory) for memory in memories]
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                memories = db.query(Memory).all()
+                return [MemoryModel.model_validate(memory) for memory in memories]
+            except:
+                return None
 
     def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
-        try:
-            memories = Session.query(Memory).filter_by(user_id=user_id).all()
-            return [MemoryModel.model_validate(memory) for memory in memories]
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                memories = db.query(Memory).filter_by(user_id=user_id).all()
+                return [MemoryModel.model_validate(memory) for memory in memories]
+            except:
+                return None
 
     def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
-        try:
-            memory = Session.get(Memory, id)
-            return MemoryModel.model_validate(memory)
-        except:
-            return None
+        with get_db() as db:
+
+            try:
+                memory = db.get(Memory, id)
+                return MemoryModel.model_validate(memory)
+            except:
+                return None
 
     def delete_memory_by_id(self, id: str) -> bool:
-        try:
-            Session.query(Memory).filter_by(id=id).delete()
-            return True
+        with get_db() as db:
+
+            try:
+                db.query(Memory).filter_by(id=id).delete()
+                return True
 
-        except:
-            return False
+            except:
+                return False
 
     def delete_memories_by_user_id(self, user_id: str) -> bool:
-        try:
-            Session.query(Memory).filter_by(user_id=user_id).delete()
-            return True
-        except:
-            return False
+        with get_db() as db:
+
+            try:
+                db.query(Memory).filter_by(user_id=user_id).delete()
+                return True
+            except:
+                return False
 
     def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
-        try:
-            Session.query(Memory).filter_by(id=id, user_id=user_id).delete()
-            return True
-        except:
-            return False
+        with get_db() as db:
+
+            try:
+                db.query(Memory).filter_by(id=id, user_id=user_id).delete()
+                return True
+            except:
+                return False
 
 
 Memories = MemoriesTable()

+ 32 - 23
backend/apps/webui/models/models.py

@@ -5,7 +5,7 @@ from typing import Optional
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import String, Column, BigInteger, Text
 
-from apps.webui.internal.db import Base, JSONField, Session
+from apps.webui.internal.db import Base, JSONField, get_db
 
 from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
@@ -126,39 +126,46 @@ class ModelsTable:
             }
         )
         try:
-            result = Model(**model.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-
-            if result:
-                return ModelModel.model_validate(result)
-            else:
-                return None
+
+            with get_db() as db:
+
+                result = Model(**model.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+
+                if result:
+                    return ModelModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
             print(e)
             return None
 
     def get_all_models(self) -> List[ModelModel]:
-        return [
-            ModelModel.model_validate(model) for model in Session.query(Model).all()
-        ]
+        with get_db() as db:
+
+            return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
-            model = Session.get(Model, id)
-            return ModelModel.model_validate(model)
+            with get_db() as db:
+
+                model = db.get(Model, id)
+                return ModelModel.model_validate(model)
         except:
             return None
 
     def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
         try:
-            # update only the fields that are present in the model
-            model = Session.query(Model).get(id)
-            model.update(**model.model_dump())
-            Session.commit()
-            Session.refresh(model)
-            return ModelModel.model_validate(model)
+            with get_db() as db:
+
+                # update only the fields that are present in the model
+                model = db.query(Model).get(id)
+                model.update(**model.model_dump())
+                db.commit()
+                db.refresh(model)
+                return ModelModel.model_validate(model)
         except Exception as e:
             print(e)
 
@@ -166,8 +173,10 @@ class ModelsTable:
 
     def delete_model_by_id(self, id: str) -> bool:
         try:
-            Session.query(Model).filter_by(id=id).delete()
-            return True
+            with get_db() as db:
+
+                db.query(Model).filter_by(id=id).delete()
+                return True
         except:
             return False
 

+ 32 - 22
backend/apps/webui/models/prompts.py

@@ -4,7 +4,7 @@ import time
 
 from sqlalchemy import String, Column, BigInteger, Text
 
-from apps.webui.internal.db import Base, Session
+from apps.webui.internal.db import Base, get_db
 
 import json
 
@@ -60,46 +60,56 @@ class PromptsTable:
         )
 
         try:
-            result = Prompt(**prompt.dict())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return PromptModel.model_validate(result)
-            else:
-                return None
+            with get_db() as db:
+
+                result = Prompt(**prompt.dict())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return PromptModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
             return None
 
     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
         try:
-            prompt = Session.query(Prompt).filter_by(command=command).first()
-            return PromptModel.model_validate(prompt)
+            with get_db() as db:
+
+                prompt = db.query(Prompt).filter_by(command=command).first()
+                return PromptModel.model_validate(prompt)
         except:
             return None
 
     def get_prompts(self) -> List[PromptModel]:
-        return [
-            PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all()
-        ]
+        with get_db() as db:
+
+            return [
+                PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
+            ]
 
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
         try:
-            prompt = Session.query(Prompt).filter_by(command=command).first()
-            prompt.title = form_data.title
-            prompt.content = form_data.content
-            prompt.timestamp = int(time.time())
-            Session.commit()
-            return PromptModel.model_validate(prompt)
+            with get_db() as db:
+
+                prompt = db.query(Prompt).filter_by(command=command).first()
+                prompt.title = form_data.title
+                prompt.content = form_data.content
+                prompt.timestamp = int(time.time())
+                db.commit()
+                return PromptModel.model_validate(prompt)
         except:
             return None
 
     def delete_prompt_by_command(self, command: str) -> bool:
         try:
-            Session.query(Prompt).filter_by(command=command).delete()
-            return True
+            with get_db() as db:
+
+                db.query(Prompt).filter_by(command=command).delete()
+                return True
         except:
             return False
 

+ 120 - 102
backend/apps/webui/models/tags.py

@@ -8,7 +8,7 @@ import logging
 
 from sqlalchemy import String, Column, BigInteger, Text
 
-from apps.webui.internal.db import Base, Session
+from apps.webui.internal.db import Base, get_db
 
 from config import SRC_LOG_LEVELS
 
@@ -79,26 +79,29 @@ class ChatTagsResponse(BaseModel):
 class TagTable:
 
     def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
-        id = str(uuid.uuid4())
-        tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
-        try:
-            result = Tag(**tag.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return TagModel.model_validate(result)
-            else:
+        with get_db() as db:
+
+            id = str(uuid.uuid4())
+            tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
+            try:
+                result = Tag(**tag.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return TagModel.model_validate(result)
+                else:
+                    return None
+            except Exception as e:
                 return None
-        except Exception as e:
-            return None
 
     def get_tag_by_name_and_user_id(
         self, name: str, user_id: str
     ) -> Optional[TagModel]:
         try:
-            tag = Session.query(Tag).filter(name=name, user_id=user_id).first()
-            return TagModel.model_validate(tag)
+            with get_db() as db:
+                tag = db.query(Tag).filter(name=name, user_id=user_id).first()
+                return TagModel.model_validate(tag)
         except Exception as e:
             return None
 
@@ -120,98 +123,109 @@ class TagTable:
             }
         )
         try:
-            result = ChatIdTag(**chatIdTag.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return ChatIdTagModel.model_validate(result)
-            else:
-                return None
+            with get_db() as db:
+                result = ChatIdTag(**chatIdTag.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return ChatIdTagModel.model_validate(result)
+                else:
+                    return None
         except:
             return None
 
     def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
-        tag_names = [
-            chat_id_tag.tag_name
-            for chat_id_tag in (
-                Session.query(ChatIdTag)
-                .filter_by(user_id=user_id)
-                .order_by(ChatIdTag.timestamp.desc())
-                .all()
-            )
-        ]
-
-        return [
-            TagModel.model_validate(tag)
-            for tag in (
-                Session.query(Tag)
-                .filter_by(user_id=user_id)
-                .filter(Tag.name.in_(tag_names))
-                .all()
-            )
-        ]
+        with get_db() as db:
+            tag_names = [
+                chat_id_tag.tag_name
+                for chat_id_tag in (
+                    db.query(ChatIdTag)
+                    .filter_by(user_id=user_id)
+                    .order_by(ChatIdTag.timestamp.desc())
+                    .all()
+                )
+            ]
+
+            return [
+                TagModel.model_validate(tag)
+                for tag in (
+                    db.query(Tag)
+                    .filter_by(user_id=user_id)
+                    .filter(Tag.name.in_(tag_names))
+                    .all()
+                )
+            ]
 
     def get_tags_by_chat_id_and_user_id(
         self, chat_id: str, user_id: str
     ) -> List[TagModel]:
-        tag_names = [
-            chat_id_tag.tag_name
-            for chat_id_tag in (
-                Session.query(ChatIdTag)
-                .filter_by(user_id=user_id, chat_id=chat_id)
-                .order_by(ChatIdTag.timestamp.desc())
-                .all()
-            )
-        ]
-
-        return [
-            TagModel.model_validate(tag)
-            for tag in (
-                Session.query(Tag)
-                .filter_by(user_id=user_id)
-                .filter(Tag.name.in_(tag_names))
-                .all()
-            )
-        ]
+        with get_db() as db:
+
+            tag_names = [
+                chat_id_tag.tag_name
+                for chat_id_tag in (
+                    db.query(ChatIdTag)
+                    .filter_by(user_id=user_id, chat_id=chat_id)
+                    .order_by(ChatIdTag.timestamp.desc())
+                    .all()
+                )
+            ]
+
+            return [
+                TagModel.model_validate(tag)
+                for tag in (
+                    db.query(Tag)
+                    .filter_by(user_id=user_id)
+                    .filter(Tag.name.in_(tag_names))
+                    .all()
+                )
+            ]
 
     def get_chat_ids_by_tag_name_and_user_id(
         self, tag_name: str, user_id: str
     ) -> List[ChatIdTagModel]:
-        return [
-            ChatIdTagModel.model_validate(chat_id_tag)
-            for chat_id_tag in (
-                Session.query(ChatIdTag)
-                .filter_by(user_id=user_id, tag_name=tag_name)
-                .order_by(ChatIdTag.timestamp.desc())
-                .all()
-            )
-        ]
+        with get_db() as db:
+
+            return [
+                ChatIdTagModel.model_validate(chat_id_tag)
+                for chat_id_tag in (
+                    db.query(ChatIdTag)
+                    .filter_by(user_id=user_id, tag_name=tag_name)
+                    .order_by(ChatIdTag.timestamp.desc())
+                    .all()
+                )
+            ]
 
     def count_chat_ids_by_tag_name_and_user_id(
         self, tag_name: str, user_id: str
     ) -> int:
-        return (
-            Session.query(ChatIdTag)
-            .filter_by(tag_name=tag_name, user_id=user_id)
-            .count()
-        )
+        with get_db() as db:
 
-    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
-        try:
-            res = (
-                Session.query(ChatIdTag)
+            return (
+                db.query(ChatIdTag)
                 .filter_by(tag_name=tag_name, user_id=user_id)
-                .delete()
+                .count()
             )
-            log.debug(f"res: {res}")
-            Session.commit()
-
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
-            if tag_count == 0:
-                # Remove tag item from Tag col as well
-                Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
-            return True
+
+    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
+        try:
+            with get_db() as db:
+                res = (
+                    db.query(ChatIdTag)
+                    .filter_by(tag_name=tag_name, user_id=user_id)
+                    .delete()
+                )
+                log.debug(f"res: {res}")
+                db.commit()
+
+                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                    tag_name, user_id
+                )
+                if tag_count == 0:
+                    # Remove tag item from Tag col as well
+                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+                return True
         except Exception as e:
             log.error(f"delete_tag: {e}")
             return False
@@ -220,20 +234,24 @@ class TagTable:
         self, tag_name: str, chat_id: str, user_id: str
     ) -> bool:
         try:
-            res = (
-                Session.query(ChatIdTag)
-                .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
-                .delete()
-            )
-            log.debug(f"res: {res}")
-            Session.commit()
-
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
-            if tag_count == 0:
-                # Remove tag item from Tag col as well
-                Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
-
-            return True
+            with get_db() as db:
+
+                res = (
+                    db.query(ChatIdTag)
+                    .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
+                    .delete()
+                )
+                log.debug(f"res: {res}")
+                db.commit()
+
+                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                    tag_name, user_id
+                )
+                if tag_count == 0:
+                    # Remove tag item from Tag col as well
+                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+
+                return True
         except Exception as e:
             log.error(f"delete_tag: {e}")
             return False

+ 50 - 39
backend/apps/webui/models/tools.py

@@ -4,7 +4,7 @@ import time
 import logging
 from sqlalchemy import String, Column, BigInteger, Text
 
-from apps.webui.internal.db import Base, JSONField, Session
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.users import Users
 
 import json
@@ -83,54 +83,63 @@ class ToolsTable:
     def insert_new_tool(
         self, user_id: str, form_data: ToolForm, specs: List[dict]
     ) -> Optional[ToolModel]:
-        tool = ToolModel(
-            **{
-                **form_data.model_dump(),
-                "specs": specs,
-                "user_id": user_id,
-                "updated_at": int(time.time()),
-                "created_at": int(time.time()),
-            }
-        )
 
-        try:
-            result = Tool(**tool.model_dump())
-            Session.add(result)
-            Session.commit()
-            Session.refresh(result)
-            if result:
-                return ToolModel.model_validate(result)
-            else:
+        with get_db() as db:
+
+            tool = ToolModel(
+                **{
+                    **form_data.model_dump(),
+                    "specs": specs,
+                    "user_id": user_id,
+                    "updated_at": int(time.time()),
+                    "created_at": int(time.time()),
+                }
+            )
+
+            try:
+                result = Tool(**tool.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return ToolModel.model_validate(result)
+                else:
+                    return None
+            except Exception as e:
+                print(f"Error creating tool: {e}")
                 return None
-        except Exception as e:
-            print(f"Error creating tool: {e}")
-            return None
 
     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
         try:
-            tool = Session.get(Tool, id)
-            return ToolModel.model_validate(tool)
+            with get_db() as db:
+
+                tool = db.get(Tool, id)
+                return ToolModel.model_validate(tool)
         except:
             return None
 
     def get_tools(self) -> List[ToolModel]:
-        return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()]
+        return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
-            tool = Session.get(Tool, id)
-            return tool.valves if tool.valves else {}
+            with get_db() as db:
+
+                tool = db.get(Tool, id)
+                return tool.valves if tool.valves else {}
         except Exception as e:
             print(f"An error occurred: {e}")
             return None
 
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
         try:
-            Session.query(Tool).filter_by(id=id).update(
-                {"valves": valves, "updated_at": int(time.time())}
-            )
-            Session.commit()
-            return self.get_tool_by_id(id)
+            with get_db() as db:
+
+                db.query(Tool).filter_by(id=id).update(
+                    {"valves": valves, "updated_at": int(time.time())}
+                )
+                db.commit()
+                return self.get_tool_by_id(id)
         except:
             return None
 
@@ -177,19 +186,21 @@ class ToolsTable:
 
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
-            tool = Session.get(Tool, id)
-            tool.update(**updated)
-            tool.updated_at = int(time.time())
-            Session.commit()
-            Session.refresh(tool)
-            return ToolModel.model_validate(tool)
+            with get_db() as db:
+                tool = db.get(Tool, id)
+                tool.update(**updated)
+                tool.updated_at = int(time.time())
+                db.commit()
+                db.refresh(tool)
+                return ToolModel.model_validate(tool)
         except:
             return None
 
     def delete_tool_by_id(self, id: str) -> bool:
         try:
-            Session.query(Tool).filter_by(id=id).delete()
-            return True
+            with get_db() as db:
+                db.query(Tool).filter_by(id=id).delete()
+                return True
         except:
             return False
 

+ 92 - 73
backend/apps/webui/models/users.py

@@ -6,7 +6,7 @@ from sqlalchemy import String, Column, BigInteger, Text
 
 from utils.misc import get_gravatar_url
 
-from apps.webui.internal.db import Base, JSONField, Session
+from apps.webui.internal.db import Base, JSONField, Session, get_db
 from apps.webui.models.chats import Chats
 
 ####################
@@ -88,81 +88,92 @@ class UsersTable:
         role: str = "pending",
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
-        user = UserModel(
-            **{
-                "id": id,
-                "name": name,
-                "email": email,
-                "role": role,
-                "profile_image_url": profile_image_url,
-                "last_active_at": int(time.time()),
-                "created_at": int(time.time()),
-                "updated_at": int(time.time()),
-                "oauth_sub": oauth_sub,
-            }
-        )
-        result = User(**user.model_dump())
-        Session.add(result)
-        Session.commit()
-        Session.refresh(result)
-        if result:
-            return user
-        else:
-            return None
+        with get_db() as db:
+            user = UserModel(
+                **{
+                    "id": id,
+                    "name": name,
+                    "email": email,
+                    "role": role,
+                    "profile_image_url": profile_image_url,
+                    "last_active_at": int(time.time()),
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                    "oauth_sub": oauth_sub,
+                }
+            )
+            result = User(**user.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            if result:
+                return user
+            else:
+                return None
 
     def get_user_by_id(self, id: str) -> Optional[UserModel]:
         try:
-            user = Session.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
         except Exception as e:
             return None
 
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         try:
-            user = Session.query(User).filter_by(api_key=api_key).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+
+                user = db.query(User).filter_by(api_key=api_key).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
     def get_user_by_email(self, email: str) -> Optional[UserModel]:
         try:
-            user = Session.query(User).filter_by(email=email).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+
+                user = db.query(User).filter_by(email=email).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
         try:
-            user = Session.query(User).filter_by(oauth_sub=sub).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+
+                user = db.query(User).filter_by(oauth_sub=sub).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
-        users = (
-            Session.query(User)
-            # .offset(skip).limit(limit)
-            .all()
-        )
-        return [UserModel.model_validate(user) for user in users]
+        with get_db() as db:
+            users = (
+                db.query(User)
+                # .offset(skip).limit(limit)
+                .all()
+            )
+            return [UserModel.model_validate(user) for user in users]
 
     def get_num_users(self) -> Optional[int]:
-        return Session.query(User).count()
+        with get_db() as db:
+            return db.query(User).count()
 
     def get_first_user(self) -> UserModel:
         try:
-            user = Session.query(User).order_by(User.created_at).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+                user = db.query(User).order_by(User.created_at).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
         try:
-            Session.query(User).filter_by(id=id).update({"role": role})
-            Session.commit()
-
-            user = Session.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+                db.query(User).filter_by(id=id).update({"role": role})
+                db.commit()
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
@@ -170,25 +181,28 @@ class UsersTable:
         self, id: str, profile_image_url: str
     ) -> Optional[UserModel]:
         try:
-            Session.query(User).filter_by(id=id).update(
-                {"profile_image_url": profile_image_url}
-            )
-            Session.commit()
-
-            user = Session.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
+            with get_db() as db:
+                db.query(User).filter_by(id=id).update(
+                    {"profile_image_url": profile_image_url}
+                )
+                db.commit()
+
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
     def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         try:
-            Session.query(User).filter_by(id=id).update(
-                {"last_active_at": int(time.time())}
-            )
-            Session.commit()
+            with get_db() as db:
+
+                db.query(User).filter_by(id=id).update(
+                    {"last_active_at": int(time.time())}
+                )
+                db.commit()
 
-            user = Session.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
@@ -196,21 +210,23 @@ class UsersTable:
         self, id: str, oauth_sub: str
     ) -> Optional[UserModel]:
         try:
-            Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
+            with get_db() as db:
+                db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
 
-            user = Session.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
         except:
             return None
 
     def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
         try:
-            Session.query(User).filter_by(id=id).update(updated)
-            Session.commit()
+            with get_db() as db:
+                db.query(User).filter_by(id=id).update(updated)
+                db.commit()
 
-            user = Session.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-            # return UserModel(**user.dict())
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+                # return UserModel(**user.dict())
         except Exception as e:
             return None
 
@@ -220,9 +236,10 @@ class UsersTable:
             result = Chats.delete_chats_by_user_id(id)
 
             if result:
-                # Delete User
-                Session.query(User).filter_by(id=id).delete()
-                Session.commit()
+                with get_db() as db:
+                    # Delete User
+                    db.query(User).filter_by(id=id).delete()
+                    db.commit()
 
                 return True
             else:
@@ -232,16 +249,18 @@ class UsersTable:
 
     def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
         try:
-            result = Session.query(User).filter_by(id=id).update({"api_key": api_key})
-            Session.commit()
-            return True if result == 1 else False
+            with get_db() as db:
+                result = db.query(User).filter_by(id=id).update({"api_key": api_key})
+                db.commit()
+                return True if result == 1 else False
         except:
             return False
 
     def get_user_api_key_by_id(self, id: str) -> Optional[str]:
         try:
-            user = Session.query(User).filter_by(id=id).first()
-            return user.api_key
+            with get_db() as db:
+                user = db.query(User).filter_by(id=id).first()
+                return user.api_key
         except Exception as e:
             return None