浏览代码

feat(sqlalchemy): remove session reference from router

Jonathan Rohde 10 月之前
父节点
当前提交
bee835cb65
共有 34 个文件被更改,包括 1223 次插入1203 次删除
  1. 2 5
      backend/apps/ollama/main.py
  2. 1 3
      backend/apps/openai/main.py
  3. 5 4
      backend/apps/webui/internal/db.py
  4. 2 2
      backend/apps/webui/main.py
  5. 77 71
      backend/apps/webui/models/auths.py
  6. 179 159
      backend/apps/webui/models/chats.py
  7. 46 39
      backend/apps/webui/models/documents.py
  8. 26 19
      backend/apps/webui/models/files.py
  9. 62 55
      backend/apps/webui/models/functions.py
  10. 35 28
      backend/apps/webui/models/memories.py
  11. 31 26
      backend/apps/webui/models/models.py
  12. 53 49
      backend/apps/webui/models/prompts.py
  13. 118 107
      backend/apps/webui/models/tags.py
  14. 37 33
      backend/apps/webui/models/tools.py
  15. 152 137
      backend/apps/webui/models/users.py
  16. 26 34
      backend/apps/webui/routers/auths.py
  17. 56 60
      backend/apps/webui/routers/chats.py
  18. 12 14
      backend/apps/webui/routers/documents.py
  19. 12 15
      backend/apps/webui/routers/files.py
  20. 13 14
      backend/apps/webui/routers/functions.py
  21. 10 13
      backend/apps/webui/routers/memories.py
  22. 10 13
      backend/apps/webui/routers/models.py
  23. 10 12
      backend/apps/webui/routers/prompts.py
  24. 10 13
      backend/apps/webui/routers/tools.py
  25. 23 26
      backend/apps/webui/routers/users.py
  26. 13 18
      backend/main.py
  27. 0 188
      backend/migrations/versions/22b5ab2667b8_init.py
  28. 161 0
      backend/migrations/versions/ba76b0bae648_init.py
  29. 6 13
      backend/test/apps/webui/routers/test_auths.py
  30. 17 21
      backend/test/apps/webui/routers/test_chats.py
  31. 5 5
      backend/test/apps/webui/routers/test_documents.py
  32. 10 0
      backend/test/apps/webui/routers/test_prompts.py
  33. 0 2
      backend/test/apps/webui/routers/test_users.py
  34. 3 5
      backend/utils/utils.py

+ 2 - 5
backend/apps/ollama/main.py

@@ -31,7 +31,6 @@ from typing import Optional, List, Union
 
 
 from starlette.background import BackgroundTask
 from starlette.background import BackgroundTask
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -712,7 +711,6 @@ async def generate_chat_completion(
     form_data: GenerateChatCompletionForm,
     form_data: GenerateChatCompletionForm,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
-    db=Depends(get_db),
 ):
 ):
 
 
     log.debug(
     log.debug(
@@ -726,7 +724,7 @@ async def generate_chat_completion(
     }
     }
 
 
     model_id = form_data.model
     model_id = form_data.model
-    model_info = Models.get_model_by_id(db, model_id)
+    model_info = Models.get_model_by_id(model_id)
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:
@@ -885,7 +883,6 @@ async def generate_openai_chat_completion(
     form_data: dict,
     form_data: dict,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
-    db=Depends(get_db),
 ):
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
     form_data = OpenAIChatCompletionForm(**form_data)
 
 
@@ -894,7 +891,7 @@ async def generate_openai_chat_completion(
     }
     }
 
 
     model_id = form_data.model
     model_id = form_data.model
-    model_info = Models.get_model_by_id(db, model_id)
+    model_info = Models.get_model_by_id(model_id)
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:

+ 1 - 3
backend/apps/openai/main.py

@@ -11,7 +11,6 @@ import logging
 from pydantic import BaseModel
 from pydantic import BaseModel
 from starlette.background import BackgroundTask
 from starlette.background import BackgroundTask
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -354,13 +353,12 @@ async def generate_chat_completion(
     form_data: dict,
     form_data: dict,
     url_idx: Optional[int] = None,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
-    db=Depends(get_db),
 ):
 ):
     idx = 0
     idx = 0
     payload = {**form_data}
     payload = {**form_data}
 
 
     model_id = form_data.get("model")
     model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(db, model_id)
+    model_info = Models.get_model_by_id(model_id)
 
 
     if model_info:
     if model_info:
         if model_info.base_model_id:
         if model_info.base_model_id:

+ 5 - 4
backend/apps/webui/internal/db.py

@@ -1,6 +1,7 @@
 import os
 import os
 import logging
 import logging
 import json
 import json
+from contextlib import contextmanager
 from typing import Optional, Any
 from typing import Optional, Any
 from typing_extensions import Self
 from typing_extensions import Self
 
 
@@ -52,11 +53,12 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
     )
     )
 else:
 else:
     engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
     engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
 Base = declarative_base()
 Base = declarative_base()
 
 
 
 
-def get_db():
+@contextmanager
+def get_session():
     db = SessionLocal()
     db = SessionLocal()
     try:
     try:
         yield db
         yield db
@@ -64,5 +66,4 @@ def get_db():
     except Exception as e:
     except Exception as e:
         db.rollback()
         db.rollback()
         raise e
         raise e
-    finally:
-        db.close()
+

+ 2 - 2
backend/apps/webui/main.py

@@ -114,8 +114,8 @@ async def get_status():
     }
     }
 
 
 
 
-async def get_pipe_models(db: Session):
-    pipes = Functions.get_functions_by_type(db, "pipe", active_only=True)
+async def get_pipe_models():
+    pipes = Functions.get_functions_by_type("pipe", active_only=True)
     pipe_models = []
     pipe_models = []
 
 
     for pipe in pipes:
     for pipe in pipes:

+ 77 - 71
backend/apps/webui/models/auths.py

@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
 from apps.webui.models.users import UserModel, Users
 from apps.webui.models.users import UserModel, Users
 from utils.utils import verify_password
 from utils.utils import verify_password
 
 
-from apps.webui.internal.db import Base
+from apps.webui.internal.db import Base, get_session
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -96,7 +96,6 @@ class AuthsTable:
 
 
     def insert_new_auth(
     def insert_new_auth(
         self,
         self,
-        db: Session,
         email: str,
         email: str,
         password: str,
         password: str,
         name: str,
         name: str,
@@ -104,100 +103,107 @@ class AuthsTable:
         role: str = "pending",
         role: str = "pending",
         oauth_sub: Optional[str] = None,
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        log.info("insert_new_auth")
+        with get_session() as db:
+            log.info("insert_new_auth")
 
 
-        id = str(uuid.uuid4())
+            id = str(uuid.uuid4())
 
 
-        auth = AuthModel(
-            **{"id": id, "email": email, "password": password, "active": True}
-        )
-        result = Auth(**auth.model_dump())
-        db.add(result)
+            auth = AuthModel(
+                **{"id": id, "email": email, "password": password, "active": True}
+            )
+            result = Auth(**auth.model_dump())
+            db.add(result)
 
 
-        user = Users.insert_new_user(
-            db, id, name, email, profile_image_url, role, oauth_sub
-        )
+            user = Users.insert_new_user(
+                id, name, email, profile_image_url, role, oauth_sub
+            )
 
 
-        db.commit()
-        db.refresh(result)
+            db.commit()
+            db.refresh(result)
 
 
-        if result and user:
-            return user
-        else:
-            return None
+            if result and user:
+                return user
+            else:
+                return None
 
 
     def authenticate_user(
     def authenticate_user(
-        self, db: Session, email: str, password: str
+        self, email: str, password: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
         log.info(f"authenticate_user: {email}")
-        try:
-            auth = db.query(Auth).filter_by(email=email, active=True).first()
-            if auth:
-                if verify_password(password, auth.password):
-                    user = Users.get_user_by_id(db, auth.id)
-                    return user
+        with get_session() as db:
+            try:
+                auth = db.query(Auth).filter_by(email=email, active=True).first()
+                if auth:
+                    if verify_password(password, auth.password):
+                        user = Users.get_user_by_id(auth.id)
+                        return user
+                    else:
+                        return None
                 else:
                 else:
                     return None
                     return None
-            else:
+            except:
                 return None
                 return None
-        except:
-            return None
 
 
     def authenticate_user_by_api_key(
     def authenticate_user_by_api_key(
-        self, db: Session, api_key: str
+        self, api_key: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_api_key: {api_key}")
         log.info(f"authenticate_user_by_api_key: {api_key}")
-        # if no api_key, return None
-        if not api_key:
-            return None
+        with get_session() as db:
+            # if no api_key, return None
+            if not api_key:
+                return None
 
 
-        try:
-            user = Users.get_user_by_api_key(db, api_key)
-            return user if user else None
-        except:
-            return False
+            try:
+                user = Users.get_user_by_api_key(api_key)
+                return user if user else None
+            except:
+                return False
 
 
     def authenticate_user_by_trusted_header(
     def authenticate_user_by_trusted_header(
-        self, db: Session, email: str
+        self, email: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
         log.info(f"authenticate_user_by_trusted_header: {email}")
-        try:
-            auth = db.query(Auth).filter(email=email, active=True).first()
-            if auth:
-                user = Users.get_user_by_id(auth.id)
-                return user
-        except:
-            return None
+        with get_session() as db:
+            try:
+                auth = db.query(Auth).filter(email=email, active=True).first()
+                if auth:
+                    user = Users.get_user_by_id(auth.id)
+                    return user
+            except:
+                return None
 
 
     def update_user_password_by_id(
     def update_user_password_by_id(
-        self, db: Session, id: str, new_password: str
+        self, id: str, new_password: str
     ) -> bool:
     ) -> bool:
-        try:
-            result = db.query(Auth).filter_by(id=id).update({"password": new_password})
-            return True if result == 1 else False
-        except:
-            return False
-
-    def update_email_by_id(self, db: Session, id: str, email: str) -> bool:
-        try:
-            result = db.query(Auth).filter_by(id=id).update({"email": email})
-            return True if result == 1 else False
-        except:
-            return False
-
-    def delete_auth_by_id(self, db: Session, id: str) -> bool:
-        try:
-            # Delete User
-            result = Users.delete_user_by_id(db, id)
-
-            if result:
-                db.query(Auth).filter_by(id=id).delete()
-
-                return True
-            else:
+        with get_session() as db:
+            try:
+                result = db.query(Auth).filter_by(id=id).update({"password": new_password})
+                return True if result == 1 else False
+            except:
+                return False
+
+    def update_email_by_id(self, id: str, email: str) -> bool:
+        with get_session() as db:
+            try:
+                result = db.query(Auth).filter_by(id=id).update({"email": email})
+                return True if result == 1 else False
+            except:
+                return False
+
+    def delete_auth_by_id(self, id: str) -> bool:
+        with get_session() as db:
+            try:
+                # Delete User
+                result = Users.delete_user_by_id(id)
+
+                if result:
+                    db.query(Auth).filter_by(id=id).delete()
+
+                    return True
+                else:
+                    return False
+            except:
                 return False
                 return False
-        except:
-            return False
 
 
 
 
 Auths = AuthsTable()
 Auths = AuthsTable()

+ 179 - 159
backend/apps/webui/models/chats.py

@@ -8,7 +8,7 @@ import time
 from sqlalchemy import Column, String, BigInteger, Boolean
 from sqlalchemy import Column, String, BigInteger, Boolean
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base
+from apps.webui.internal.db import Base, get_session
 
 
 
 
 ####################
 ####################
@@ -80,249 +80,269 @@ class ChatTitleIdResponse(BaseModel):
 class ChatTable:
 class ChatTable:
 
 
     def insert_new_chat(
     def insert_new_chat(
-        self, db: Session, user_id: str, form_data: ChatForm
+        self, user_id: str, form_data: ChatForm
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
-        id = str(uuid.uuid4())
-        chat = ChatModel(
-            **{
-                "id": id,
-                "user_id": user_id,
-                "title": (
-                    form_data.chat["title"] if "title" in form_data.chat else "New Chat"
-                ),
-                "chat": json.dumps(form_data.chat),
-                "created_at": int(time.time()),
-                "updated_at": int(time.time()),
-            }
-        )
-
-        result = Chat(**chat.model_dump())
-        db.add(result)
-        db.commit()
-        db.refresh(result)
-        return ChatModel.model_validate(result) if result else None
-
-    def update_chat_by_id(
-        self, db: Session, id: str, chat: dict
-    ) -> Optional[ChatModel]:
-        try:
-            db.query(Chat).filter_by(id=id).update(
-                {
-                    "chat": json.dumps(chat),
-                    "title": chat["title"] if "title" in chat else "New Chat",
+        with get_session() as db:
+            id = str(uuid.uuid4())
+            chat = ChatModel(
+                **{
+                    "id": id,
+                    "user_id": user_id,
+                    "title": (
+                        form_data.chat["title"] if "title" in form_data.chat else "New Chat"
+                    ),
+                    "chat": json.dumps(form_data.chat),
+                    "created_at": int(time.time()),
                     "updated_at": int(time.time()),
                     "updated_at": int(time.time()),
                 }
                 }
             )
             )
 
 
-            return self.get_chat_by_id(db, id)
-        except:
-            return None
+            result = Chat(**chat.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            return ChatModel.model_validate(result) if result else None
 
 
-    def insert_shared_chat_by_chat_id(
-        self, db: Session, chat_id: str
+    def update_chat_by_id(
+        self, id: str, chat: dict
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
-        # Get the existing chat to share
-        chat = db.get(Chat, chat_id)
-        # Check if the chat is already shared
-        if chat.share_id:
-            return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared")
-        # Create a new chat with the same data, but with a new ID
-        shared_chat = ChatModel(
-            **{
-                "id": str(uuid.uuid4()),
-                "user_id": f"shared-{chat_id}",
-                "title": chat.title,
-                "chat": chat.chat,
-                "created_at": chat.created_at,
-                "updated_at": int(time.time()),
-            }
-        )
-        shared_result = Chat(**shared_chat.model_dump())
-        db.add(shared_result)
-        db.commit()
-        db.refresh(shared_result)
-        # Update the original chat with the share_id
-        result = (
-            db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
-        )
-
-        return shared_chat if (shared_result and result) else None
+        with get_session() as db:
+            try:
+                chat_obj = db.get(Chat, id)
+                chat_obj.chat = json.dumps(chat)
+                chat_obj.title = chat["title"] if "title" in chat else "New Chat"
+                chat_obj.updated_at = int(time.time())
+                db.commit()
+                db.refresh(chat_obj)
+
+                return ChatModel.model_validate(chat_obj)
+            except Exception as e:
+                return None
 
 
-    def update_shared_chat_by_chat_id(
-        self, db: Session, chat_id: str
+    def insert_shared_chat_by_chat_id(
+        self, chat_id: str
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
-        try:
-            print("update_shared_chat_by_id")
+        with get_session() as db:
+            # Get the existing chat to share
             chat = db.get(Chat, chat_id)
             chat = db.get(Chat, chat_id)
-            print(chat)
-
-            db.query(Chat).filter_by(id=chat.share_id).update(
-                {"title": chat.title, "chat": chat.chat}
+            # Check if the chat is already shared
+            if chat.share_id:
+                return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
+            # Create a new chat with the same data, but with a new ID
+            shared_chat = ChatModel(
+                **{
+                    "id": str(uuid.uuid4()),
+                    "user_id": f"shared-{chat_id}",
+                    "title": chat.title,
+                    "chat": chat.chat,
+                    "created_at": chat.created_at,
+                    "updated_at": int(time.time()),
+                }
+            )
+            shared_result = Chat(**shared_chat.model_dump())
+            db.add(shared_result)
+            db.commit()
+            db.refresh(shared_result)
+            # Update the original chat with the share_id
+            result = (
+                db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
             )
             )
 
 
-            return self.get_chat_by_id(db, chat.share_id)
-        except:
-            return None
+            return shared_chat if (shared_result and result) else None
+
+    def update_shared_chat_by_chat_id(
+        self, chat_id: str
+    ) -> Optional[ChatModel]:
+        with get_session() as db:
+            try:
+                print("update_shared_chat_by_id")
+                chat = db.get(Chat, chat_id)
+                print(chat)
+                chat.title = chat.title
+                chat.chat = chat.chat
+                db.commit()
+                db.refresh(chat)
+
+                return self.get_chat_by_id(chat.share_id)
+            except:
+                return None
 
 
-    def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool:
+    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
         try:
         try:
-            db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
+            with get_session() as db:
+                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def update_chat_share_id_by_id(
     def update_chat_share_id_by_id(
-        self, db: Session, id: str, share_id: Optional[str]
+        self, id: str, share_id: Optional[str]
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
         try:
         try:
-            db.query(Chat).filter_by(id=id).update({"share_id": share_id})
-
-            return self.get_chat_by_id(db, id)
+            with get_session() as db:
+                chat = db.get(Chat, id)
+                chat.share_id = share_id
+                db.commit()
+                db.refresh(chat)
+                return chat
         except:
         except:
             return None
             return None
 
 
-    def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
+    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
         try:
-            chat = self.get_chat_by_id(db, id)
-            db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
+            with get_session() as db:
+                chat = self.get_chat_by_id(id)
+                db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
 
 
-            return self.get_chat_by_id(db, id)
+                return self.get_chat_by_id(id)
         except:
         except:
             return None
             return None
 
 
-    def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool:
+    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
         try:
         try:
-            db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
+            with get_session() as db:
+                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
 
 
             return True
             return True
         except:
         except:
             return False
             return False
 
 
     def get_archived_chat_list_by_user_id(
     def get_archived_chat_list_by_user_id(
-        self, db: Session, user_id: str, skip: int = 0, limit: int = 50
+        self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        all_chats = (
-            db.query(Chat)
-            .filter_by(user_id=user_id, archived=True)
-            .order_by(Chat.updated_at.desc())
-            # .limit(limit).offset(skip)
-            .all()
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_session() as db:
+            all_chats = (
+                db.query(Chat)
+                .filter_by(user_id=user_id, archived=True)
+                .order_by(Chat.updated_at.desc())
+                # .limit(limit).offset(skip)
+                .all()
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chat_list_by_user_id(
     def get_chat_list_by_user_id(
         self,
         self,
-        db: Session,
         user_id: str,
         user_id: str,
         include_archived: bool = False,
         include_archived: bool = False,
         skip: int = 0,
         skip: int = 0,
         limit: int = 50,
         limit: int = 50,
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        query = db.query(Chat).filter_by(user_id=user_id)
-        if not include_archived:
-            query = query.filter_by(archived=False)
-        all_chats = (
-            query.order_by(Chat.updated_at.desc())
-            # .limit(limit).offset(skip)
-            .all()
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+        with get_session() as db:
+            query = db.query(Chat).filter_by(user_id=user_id)
+            if not include_archived:
+                query = query.filter_by(archived=False)
+            all_chats = (
+                query.order_by(Chat.updated_at.desc())
+                # .limit(limit).offset(skip)
+                .all()
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_chat_list_by_chat_ids(
     def get_chat_list_by_chat_ids(
-        self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50
+        self, chat_ids: List[str], skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        all_chats = (
-            db.query(Chat)
-            .filter(Chat.id.in_(chat_ids))
-            .filter_by(archived=False)
-            .order_by(Chat.updated_at.desc())
-            .all()
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
-
-    def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
+        with get_session() as db:
+            all_chats = (
+                db.query(Chat)
+                .filter(Chat.id.in_(chat_ids))
+                .filter_by(archived=False)
+                .order_by(Chat.updated_at.desc())
+                .all()
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
         try:
-            chat = db.get(Chat, id)
-            return ChatModel.model_validate(chat)
+            with get_session() as db:
+                chat = db.get(Chat, id)
+                return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
-    def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]:
+    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
         try:
         try:
-            chat = db.query(Chat).filter_by(share_id=id).first()
+            with get_session() as db:
+                chat = db.query(Chat).filter_by(share_id=id).first()
 
 
-            if chat:
-                return self.get_chat_by_id(db, id)
-            else:
-                return None
+                if chat:
+                    return self.get_chat_by_id(id)
+                else:
+                    return None
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
     def get_chat_by_id_and_user_id(
     def get_chat_by_id_and_user_id(
-        self, db: Session, id: str, user_id: str
+        self, id: str, user_id: str
     ) -> Optional[ChatModel]:
     ) -> Optional[ChatModel]:
         try:
         try:
-            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
-            return ChatModel.model_validate(chat)
+            with get_session() as db:
+                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
+                return ChatModel.model_validate(chat)
         except:
         except:
             return None
             return None
 
 
-    def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]:
-        all_chats = (
-            db.query(Chat)
-            # .limit(limit).offset(skip)
-            .order_by(Chat.updated_at.desc())
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+    def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
+        with get_session() as db:
+            all_chats = (
+                db.query(Chat)
+                # .limit(limit).offset(skip)
+                .order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
-    def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]:
-        all_chats = (
-            db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
+    def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
+        with get_session() as db:
+            all_chats = (
+                db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
     def get_archived_chats_by_user_id(
     def get_archived_chats_by_user_id(
-        self, db: Session, user_id: str
+        self, user_id: str
     ) -> List[ChatModel]:
     ) -> List[ChatModel]:
-        all_chats = (
-            db.query(Chat)
-            .filter_by(user_id=user_id, archived=True)
-            .order_by(Chat.updated_at.desc())
-        )
-        return [ChatModel.model_validate(chat) for chat in all_chats]
-
-    def delete_chat_by_id(self, db: Session, id: str) -> bool:
+        with get_session() as db:
+            all_chats = (
+                db.query(Chat)
+                .filter_by(user_id=user_id, archived=True)
+                .order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def delete_chat_by_id(self, id: str) -> bool:
         try:
         try:
-            db.query(Chat).filter_by(id=id).delete()
+            with get_session() as db:
+                db.query(Chat).filter_by(id=id).delete()
 
 
-            return True and self.delete_shared_chat_by_chat_id(db, id)
+                return True and self.delete_shared_chat_by_chat_id(id)
         except:
         except:
             return False
             return False
 
 
-    def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool:
+    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
         try:
-            db.query(Chat).filter_by(id=id, user_id=user_id).delete()
+            with get_session() as db:
+                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
 
 
-            return True and self.delete_shared_chat_by_chat_id(db, id)
+                return True and self.delete_shared_chat_by_chat_id(id)
         except:
         except:
             return False
             return False
 
 
-    def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool:
+    def delete_chats_by_user_id(self, user_id: str) -> bool:
         try:
         try:
+            with get_session() as db:
+                self.delete_shared_chats_by_user_id(user_id)
 
 
-            self.delete_shared_chats_by_user_id(db, user_id)
-
-            db.query(Chat).filter_by(user_id=user_id).delete()
+                db.query(Chat).filter_by(user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
-    def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool:
+    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
         try:
-            chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
-            shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
+            with get_session() as db:
+                chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
+                shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 
 
-            db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
+                db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
 
 
             return True
             return True
         except:
         except:

+ 46 - 39
backend/apps/webui/models/documents.py

@@ -6,7 +6,7 @@ import logging
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base
+from apps.webui.internal.db import Base, get_session
 
 
 import json
 import json
 
 
@@ -73,7 +73,7 @@ class DocumentForm(DocumentUpdateForm):
 class DocumentsTable:
 class DocumentsTable:
 
 
     def insert_new_doc(
     def insert_new_doc(
-        self, db: Session, user_id: str, form_data: DocumentForm
+        self, user_id: str, form_data: DocumentForm
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         document = DocumentModel(
         document = DocumentModel(
             **{
             **{
@@ -84,66 +84,73 @@ class DocumentsTable:
         )
         )
 
 
         try:
         try:
-            result = Document(**document.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return DocumentModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = Document(**document.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return DocumentModel.model_validate(result)
+                else:
+                    return None
         except:
         except:
             return None
             return None
 
 
-    def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]:
+    def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
         try:
         try:
-            document = db.query(Document).filter_by(name=name).first()
-            return DocumentModel.model_validate(document) if document else None
+            with get_session() as db:
+                document = db.query(Document).filter_by(name=name).first()
+                return DocumentModel.model_validate(document) if document else None
         except:
         except:
             return None
             return None
 
 
-    def get_docs(self, db: Session) -> List[DocumentModel]:
-        return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
+    def get_docs(self) -> List[DocumentModel]:
+        with get_session() as db:
+            return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
 
 
     def update_doc_by_name(
     def update_doc_by_name(
-        self, db: Session, name: str, form_data: DocumentUpdateForm
+        self, name: str, form_data: DocumentUpdateForm
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         try:
         try:
-            db.query(Document).filter_by(name=name).update(
-                {
-                    "title": form_data.title,
-                    "name": form_data.name,
-                    "timestamp": int(time.time()),
-                }
-            )
-            return self.get_doc_by_name(db, form_data.name)
+            with get_session() as db:
+                db.query(Document).filter_by(name=name).update(
+                    {
+                        "title": form_data.title,
+                        "name": form_data.name,
+                        "timestamp": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_doc_by_name(form_data.name)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             return None
             return None
 
 
     def update_doc_content_by_name(
     def update_doc_content_by_name(
-        self, db: Session, name: str, updated: dict
+        self, name: str, updated: dict
     ) -> Optional[DocumentModel]:
     ) -> Optional[DocumentModel]:
         try:
         try:
-            doc = self.get_doc_by_name(db, name)
-            doc_content = json.loads(doc.content if doc.content else "{}")
-            doc_content = {**doc_content, **updated}
-
-            db.query(Document).filter_by(name=name).update(
-                {
-                    "content": json.dumps(doc_content),
-                    "timestamp": int(time.time()),
-                }
-            )
-
-            return self.get_doc_by_name(db, name)
+            with get_session() as db:
+                doc = self.get_doc_by_name(name)
+                doc_content = json.loads(doc.content if doc.content else "{}")
+                doc_content = {**doc_content, **updated}
+
+                db.query(Document).filter_by(name=name).update(
+                    {
+                        "content": json.dumps(doc_content),
+                        "timestamp": int(time.time()),
+                    }
+                )
+                db.commit()
+                return self.get_doc_by_name(name)
         except Exception as e:
         except Exception as e:
             log.exception(e)
             log.exception(e)
             return None
             return None
 
 
-    def delete_doc_by_name(self, db: Session, name: str) -> bool:
+    def delete_doc_by_name(self, name: str) -> bool:
         try:
         try:
-            db.query(Document).filter_by(name=name).delete()
+            with get_session() as db:
+                db.query(Document).filter_by(name=name).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 26 - 19
backend/apps/webui/models/files.py

@@ -6,7 +6,7 @@ import logging
 from sqlalchemy import Column, String, BigInteger
 from sqlalchemy import Column, String, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import JSONField, Base
+from apps.webui.internal.db import JSONField, Base, get_session
 
 
 import json
 import json
 
 
@@ -60,7 +60,7 @@ class FileForm(BaseModel):
 
 
 class FilesTable:
 class FilesTable:
 
 
-    def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]:
+    def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         file = FileModel(
         file = FileModel(
             **{
             **{
                 **form_data.model_dump(),
                 **form_data.model_dump(),
@@ -70,38 +70,45 @@ class FilesTable:
         )
         )
 
 
         try:
         try:
-            result = File(**file.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return FileModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = File(**file.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return FileModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
-    def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]:
+    def get_file_by_id(self, id: str) -> Optional[FileModel]:
         try:
         try:
-            file = db.get(File, id)
-            return FileModel.model_validate(file)
+            with get_session() as db:
+                file = db.get(File, id)
+                return FileModel.model_validate(file)
         except:
         except:
             return None
             return None
 
 
-    def get_files(self, db: Session) -> List[FileModel]:
-        return [FileModel.model_validate(file) for file in db.query(File).all()]
+    def get_files(self) -> List[FileModel]:
+        with get_session() as db:
+            return [FileModel.model_validate(file) for file in db.query(File).all()]
 
 
-    def delete_file_by_id(self, db: Session, id: str) -> bool:
+    def delete_file_by_id(self, id: str) -> bool:
         try:
         try:
-            db.query(File).filter_by(id=id).delete()
+            with get_session() as db:
+                db.query(File).filter_by(id=id).delete()
+                db.commit()
             return True
             return True
         except:
         except:
             return False
             return False
 
 
-    def delete_all_files(self, db: Session) -> bool:
+    def delete_all_files(self) -> bool:
         try:
         try:
-            db.query(File).delete()
+            with get_session() as db:
+                db.query(File).delete()
+                db.commit()
             return True
             return True
         except:
         except:
             return False
             return False

+ 62 - 55
backend/apps/webui/models/functions.py

@@ -6,7 +6,7 @@ import logging
 from sqlalchemy import Column, String, Text, BigInteger, Boolean
 from sqlalchemy import Column, String, Text, BigInteger, Boolean
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import JSONField, Base
+from apps.webui.internal.db import JSONField, Base, get_session
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 import json
 import json
@@ -87,7 +87,7 @@ class FunctionValves(BaseModel):
 class FunctionsTable:
 class FunctionsTable:
 
 
     def insert_new_function(
     def insert_new_function(
-        self, db: Session, user_id: str, type: str, form_data: FunctionForm
+        self, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
     ) -> Optional[FunctionModel]:
         function = FunctionModel(
         function = FunctionModel(
             **{
             **{
@@ -100,57 +100,64 @@ class FunctionsTable:
         )
         )
 
 
         try:
         try:
-            result = Function(**function.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return FunctionModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = Function(**function.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return FunctionModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
-    def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]:
+    def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
         try:
         try:
-            function = db.get(Function, id)
-            return FunctionModel.model_validate(function)
+            with get_session() as db:
+                function = db.get(Function, id)
+                return FunctionModel.model_validate(function)
         except:
         except:
             return None
             return None
 
 
     def get_functions(self, active_only=False) -> List[FunctionModel]:
     def get_functions(self, active_only=False) -> List[FunctionModel]:
         if active_only:
         if active_only:
-            return [
-                FunctionModel(**model_to_dict(function))
-                for function in Function.select().where(Function.is_active == True)
-            ]
+            with get_session() as db:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).filter_by(is_active=True).all()
+                ]
         else:
         else:
-            return [
-                FunctionModel(**model_to_dict(function))
-                for function in Function.select()
-            ]
+            with get_session() as db:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).all()
+                ]
 
 
     def get_functions_by_type(
     def get_functions_by_type(
         self, type: str, active_only=False
         self, type: str, active_only=False
     ) -> List[FunctionModel]:
     ) -> List[FunctionModel]:
         if active_only:
         if active_only:
-            return [
-                FunctionModel(**model_to_dict(function))
-                for function in Function.select().where(
-                    Function.type == type, Function.is_active == True
-                )
-            ]
+            with get_session() as db:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).filter_by(
+                        type=type, is_active=True
+                    ).all()
+                ]
         else:
         else:
-            return [
-                FunctionModel(**model_to_dict(function))
-                for function in Function.select().where(Function.type == type)
-            ]
+            with get_session() as db:
+                return [
+                    FunctionModel.model_validate(function)
+                    for function in db.query(Function).filter_by(type=type).all()
+                ]
 
 
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         try:
         try:
-            function = Function.get(Function.id == id)
-            return function.valves if function.valves else {}
+            with get_session() as db:
+                function = db.get(Function, id)
+                return function.valves if function.valves else {}
         except Exception as e:
         except Exception as e:
             print(f"An error occurred: {e}")
             print(f"An error occurred: {e}")
             return None
             return None
@@ -159,14 +166,12 @@ class FunctionsTable:
         self, id: str, valves: dict
         self, id: str, valves: dict
     ) -> Optional[FunctionValves]:
     ) -> Optional[FunctionValves]:
         try:
         try:
-            query = Function.update(
-                **{"valves": valves},
-                updated_at=int(time.time()),
-            ).where(Function.id == id)
-            query.execute()
-
-            function = Function.get(Function.id == id)
-            return FunctionValves(**model_to_dict(function))
+            with get_session() as db:
+                db.query(Function).filter_by(id=id).update(
+                    {"valves": valves, "updated_at": int(time.time())}
+                )
+                db.commit()
+                return self.get_function_by_id(id)
         except:
         except:
             return None
             return None
 
 
@@ -214,30 +219,32 @@ class FunctionsTable:
 
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
         try:
-            db.query(Function).filter_by(id=id).update({
-                **updated,
-                "updated_at": int(time.time()),
-            })
-            return self.get_function_by_id(db, id)
+            with get_session() as db:
+                db.query(Function).filter_by(id=id).update({
+                    **updated,
+                    "updated_at": int(time.time()),
+                })
+                db.commit()
+                return self.get_function_by_id(id)
         except:
         except:
             return None
             return None
 
 
     def deactivate_all_functions(self) -> Optional[bool]:
     def deactivate_all_functions(self) -> Optional[bool]:
         try:
         try:
-            query = Function.update(
-                **{"is_active": False},
-                updated_at=int(time.time()),
-            )
-
-            query.execute()
-
+            with get_session() as db:
+                db.query(Function).update({
+                    "is_active": False,
+                    "updated_at": int(time.time()),
+                })
+                db.commit()
             return True
             return True
         except:
         except:
             return None
             return None
 
 
-    def delete_function_by_id(self, db: Session, id: str) -> bool:
+    def delete_function_by_id(self, id: str) -> bool:
         try:
         try:
-            db.query(Function).filter_by(id=id).delete()
+            with get_session() as db:
+                db.query(Function).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 35 - 28
backend/apps/webui/models/memories.py

@@ -4,7 +4,7 @@ from typing import List, Union, Optional
 from sqlalchemy import Column, String, BigInteger
 from sqlalchemy import Column, String, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base
+from apps.webui.internal.db import Base, get_session
 from apps.webui.models.chats import Chats
 from apps.webui.models.chats import Chats
 
 
 import time
 import time
@@ -44,7 +44,6 @@ class MemoriesTable:
 
 
     def insert_new_memory(
     def insert_new_memory(
         self,
         self,
-        db: Session,
         user_id: str,
         user_id: str,
         content: str,
         content: str,
     ) -> Optional[MemoryModel]:
     ) -> Optional[MemoryModel]:
@@ -59,53 +58,59 @@ class MemoriesTable:
                 "updated_at": int(time.time()),
                 "updated_at": int(time.time()),
             }
             }
         )
         )
-        result = Memory(**memory.dict())
-        db.add(result)
-        db.commit()
-        db.refresh(result)
-        if result:
-            return MemoryModel.model_validate(result)
-        else:
-            return None
+        with get_session() as db:
+            result = Memory(**memory.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            if result:
+                return MemoryModel.model_validate(result)
+            else:
+                return None
 
 
     def update_memory_by_id(
     def update_memory_by_id(
         self,
         self,
-        db: Session,
         id: str,
         id: str,
         content: str,
         content: str,
     ) -> Optional[MemoryModel]:
     ) -> Optional[MemoryModel]:
         try:
         try:
-            db.query(Memory).filter_by(id=id).update(
-                {"content": content, "updated_at": int(time.time())}
-            )
-            return self.get_memory_by_id(db, id)
+            with get_session() as db:
+                db.query(Memory).filter_by(id=id).update(
+                    {"content": content, "updated_at": int(time.time())}
+                )
+                db.commit()
+                return self.get_memory_by_id(id)
         except:
         except:
             return None
             return None
 
 
-    def get_memories(self, db: Session) -> List[MemoryModel]:
+    def get_memories(self) -> List[MemoryModel]:
         try:
         try:
-            memories = db.query(Memory).all()
-            return [MemoryModel.model_validate(memory) for memory in memories]
+            with get_session() as db:
+                memories = db.query(Memory).all()
+                return [MemoryModel.model_validate(memory) for memory in memories]
         except:
         except:
             return None
             return None
 
 
-    def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]:
+    def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
         try:
         try:
-            memories = db.query(Memory).filter_by(user_id=user_id).all()
-            return [MemoryModel.model_validate(memory) for memory in memories]
+            with get_session() as db:
+                memories = db.query(Memory).filter_by(user_id=user_id).all()
+                return [MemoryModel.model_validate(memory) for memory in memories]
         except:
         except:
             return None
             return None
 
 
-    def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]:
+    def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
         try:
         try:
-            memory = db.get(Memory, id)
-            return MemoryModel.model_validate(memory)
+            with get_session() as db:
+                memory = db.get(Memory, id)
+                return MemoryModel.model_validate(memory)
         except:
         except:
             return None
             return None
 
 
-    def delete_memory_by_id(self, db: Session, id: str) -> bool:
+    def delete_memory_by_id(self, id: str) -> bool:
         try:
         try:
-            db.query(Memory).filter_by(id=id).delete()
+            with get_session() as db:
+                db.query(Memory).filter_by(id=id).delete()
             return True
             return True
 
 
         except:
         except:
@@ -113,7 +118,8 @@ class MemoriesTable:
 
 
     def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
     def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
         try:
         try:
-            db.query(Memory).filter_by(user_id=user_id).delete()
+            with get_session() as db:
+                db.query(Memory).filter_by(user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False
@@ -122,7 +128,8 @@ class MemoriesTable:
         self, db: Session, id: str, user_id: str
         self, db: Session, id: str, user_id: str
     ) -> bool:
     ) -> bool:
         try:
         try:
-            db.query(Memory).filter_by(id=id, user_id=user_id).delete()
+            with get_session() as db:
+                db.query(Memory).filter_by(id=id, user_id=user_id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 31 - 26
backend/apps/webui/models/models.py

@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, JSONField
+from apps.webui.internal.db import Base, JSONField, get_session
 
 
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
@@ -78,8 +78,6 @@ class Model(Base):
 
 
 
 
 class ModelModel(BaseModel):
 class ModelModel(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str
     id: str
     user_id: str
     user_id: str
     base_model_id: Optional[str] = None
     base_model_id: Optional[str] = None
@@ -91,6 +89,8 @@ class ModelModel(BaseModel):
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
     created_at: int  # timestamp in epoch
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -116,7 +116,7 @@ class ModelForm(BaseModel):
 class ModelsTable:
 class ModelsTable:
 
 
     def insert_new_model(
     def insert_new_model(
-        self, db: Session, form_data: ModelForm, user_id: str
+        self, form_data: ModelForm, user_id: str
     ) -> Optional[ModelModel]:
     ) -> Optional[ModelModel]:
         model = ModelModel(
         model = ModelModel(
             **{
             **{
@@ -127,47 +127,52 @@ class ModelsTable:
             }
             }
         )
         )
         try:
         try:
-            result = Model(**model.dict())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-
-            if result:
-                return ModelModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = Model(**model.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+
+                if result:
+                    return ModelModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
             return None
             return None
 
 
-    def get_all_models(self, db: Session) -> List[ModelModel]:
-        return [ModelModel.model_validate(model) for model in db.query(Model).all()]
+    def get_all_models(self) -> List[ModelModel]:
+        with get_session() as db:
+            return [ModelModel.model_validate(model) for model in db.query(Model).all()]
 
 
-    def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]:
+    def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:
         try:
-            model = db.get(Model, id)
-            return ModelModel.model_validate(model)
+            with get_session() as db:
+                model = db.get(Model, id)
+                return ModelModel.model_validate(model)
         except:
         except:
             return None
             return None
 
 
     def update_model_by_id(
     def update_model_by_id(
-        self, db: Session, id: str, model: ModelForm
+        self, id: str, model: ModelForm
     ) -> Optional[ModelModel]:
     ) -> Optional[ModelModel]:
         try:
         try:
             # update only the fields that are present in the model
             # update only the fields that are present in the model
-            model = db.query(Model).get(id)
-            model.update(**model.model_dump())
-            db.commit()
-            db.refresh(model)
-            return ModelModel.model_validate(model)
+            with get_session() as db:
+                model = db.query(Model).get(id)
+                model.update(**model.model_dump())
+                db.commit()
+                db.refresh(model)
+                return ModelModel.model_validate(model)
         except Exception as e:
         except Exception as e:
             print(e)
             print(e)
 
 
             return None
             return None
 
 
-    def delete_model_by_id(self, db: Session, id: str) -> bool:
+    def delete_model_by_id(self, id: str) -> bool:
         try:
         try:
-            db.query(Model).filter_by(id=id).delete()
+            with get_session() as db:
+                db.query(Model).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 53 - 49
backend/apps/webui/models/prompts.py

@@ -5,7 +5,7 @@ import time
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base
+from apps.webui.internal.db import Base, get_session
 
 
 import json
 import json
 
 
@@ -48,61 +48,65 @@ class PromptForm(BaseModel):
 class PromptsTable:
 class PromptsTable:
 
 
     def insert_new_prompt(
     def insert_new_prompt(
-        self, db: Session, user_id: str, form_data: PromptForm
+        self, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:
-        prompt = PromptModel(
-            **{
-                "user_id": user_id,
-                "command": form_data.command,
-                "title": form_data.title,
-                "content": form_data.content,
-                "timestamp": int(time.time()),
-            }
-        )
-
-        try:
-            result = Prompt(**prompt.dict())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return PromptModel.model_validate(result)
-            else:
+        with get_session() as db:
+            prompt = PromptModel(
+                **{
+                    "user_id": user_id,
+                    "command": form_data.command,
+                    "title": form_data.title,
+                    "content": form_data.content,
+                    "timestamp": int(time.time()),
+                }
+            )
+
+            try:
+                result = Prompt(**prompt.dict())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return PromptModel.model_validate(result)
+                else:
+                    return None
+            except Exception as e:
                 return None
                 return None
-        except Exception as e:
-            return None
 
 
-    def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]:
-        try:
-            prompt = db.query(Prompt).filter_by(command=command).first()
-            return PromptModel.model_validate(prompt)
-        except:
-            return None
+    def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
+        with get_session() as db:
+            try:
+                prompt = db.query(Prompt).filter_by(command=command).first()
+                return PromptModel.model_validate(prompt)
+            except:
+                return None
 
 
-    def get_prompts(self, db: Session) -> List[PromptModel]:
-        return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
+    def get_prompts(self) -> List[PromptModel]:
+        with get_session() as db:
+            return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
 
 
     def update_prompt_by_command(
     def update_prompt_by_command(
-        self, db: Session, command: str, form_data: PromptForm
+        self, command: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
     ) -> Optional[PromptModel]:
-        try:
-            db.query(Prompt).filter_by(command=command).update(
-                {
-                    "title": form_data.title,
-                    "content": form_data.content,
-                    "timestamp": int(time.time()),
-                }
-            )
-            return self.get_prompt_by_command(db, command)
-        except:
-            return None
-
-    def delete_prompt_by_command(self, db: Session, command: str) -> bool:
-        try:
-            db.query(Prompt).filter_by(command=command).delete()
-            return True
-        except:
-            return False
+        with get_session() as db:
+            try:
+                prompt = db.query(Prompt).filter_by(command=command).first()
+                prompt.title = form_data.title
+                prompt.content = form_data.content
+                prompt.timestamp = int(time.time())
+                db.commit()
+                return prompt
+                # return self.get_prompt_by_command(command)
+            except:
+                return None
+
+    def delete_prompt_by_command(self, command: str) -> bool:
+        with get_session() as db:
+            try:
+                db.query(Prompt).filter_by(command=command).delete()
+                return True
+            except:
+                return False
 
 
 
 
 Prompts = PromptsTable()
 Prompts = PromptsTable()

+ 118 - 107
backend/apps/webui/models/tags.py

@@ -9,7 +9,7 @@ import logging
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base
+from apps.webui.internal.db import Base, get_session
 
 
 from config import SRC_LOG_LEVELS
 from config import SRC_LOG_LEVELS
 
 
@@ -80,37 +80,39 @@ class ChatTagsResponse(BaseModel):
 class TagTable:
 class TagTable:
 
 
     def insert_new_tag(
     def insert_new_tag(
-        self, db: Session, name: str, user_id: str
+        self, name: str, user_id: str
     ) -> Optional[TagModel]:
     ) -> Optional[TagModel]:
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
         try:
-            result = Tag(**tag.dict())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return TagModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = Tag(**tag.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return TagModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
     def get_tag_by_name_and_user_id(
     def get_tag_by_name_and_user_id(
-        self, db: Session, name: str, user_id: str
+        self, name: str, user_id: str
     ) -> Optional[TagModel]:
     ) -> Optional[TagModel]:
         try:
         try:
-            tag = db.query(Tag).filter(name=name, user_id=user_id).first()
-            return TagModel.model_validate(tag)
+            with get_session() as db:
+                tag = db.query(Tag).filter(name=name, user_id=user_id).first()
+                return TagModel.model_validate(tag)
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
     def add_tag_to_chat(
     def add_tag_to_chat(
-        self, db: Session, user_id: str, form_data: ChatIdTagForm
+        self, user_id: str, form_data: ChatIdTagForm
     ) -> Optional[ChatIdTagModel]:
     ) -> Optional[ChatIdTagModel]:
-        tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id)
+        tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
         if tag == None:
         if tag == None:
-            tag = self.insert_new_tag(db, form_data.tag_name, user_id)
+            tag = self.insert_new_tag(form_data.tag_name, user_id)
 
 
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         chatIdTag = ChatIdTagModel(
         chatIdTag = ChatIdTagModel(
@@ -123,118 +125,127 @@ class TagTable:
             }
             }
         )
         )
         try:
         try:
-            result = ChatIdTag(**chatIdTag.dict())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return ChatIdTagModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = ChatIdTag(**chatIdTag.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return ChatIdTagModel.model_validate(result)
+                else:
+                    return None
         except:
         except:
             return None
             return None
 
 
-    def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]:
-        tag_names = [
-            chat_id_tag.tag_name
-            for chat_id_tag in (
-                db.query(ChatIdTag)
-                .filter_by(user_id=user_id)
-                .order_by(ChatIdTag.timestamp.desc())
-                .all()
-            )
-        ]
-
-        return [
-            TagModel.model_validate(tag)
-            for tag in (
-                db.query(Tag)
-                .filter_by(user_id=user_id)
-                .filter(Tag.name.in_(tag_names))
-                .all()
-            )
-        ]
+    def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
+        with get_session() as db:
+            tag_names = [
+                chat_id_tag.tag_name
+                for chat_id_tag in (
+                    db.query(ChatIdTag)
+                    .filter_by(user_id=user_id)
+                    .order_by(ChatIdTag.timestamp.desc())
+                    .all()
+                )
+            ]
+
+            return [
+                TagModel.model_validate(tag)
+                for tag in (
+                    db.query(Tag)
+                    .filter_by(user_id=user_id)
+                    .filter(Tag.name.in_(tag_names))
+                    .all()
+                )
+            ]
 
 
     def get_tags_by_chat_id_and_user_id(
     def get_tags_by_chat_id_and_user_id(
-        self, db: Session, chat_id: str, user_id: str
+        self, chat_id: str, user_id: str
     ) -> List[TagModel]:
     ) -> List[TagModel]:
-        tag_names = [
-            chat_id_tag.tag_name
-            for chat_id_tag in (
-                db.query(ChatIdTag)
-                .filter_by(user_id=user_id, chat_id=chat_id)
-                .order_by(ChatIdTag.timestamp.desc())
-                .all()
-            )
-        ]
-
-        return [
-            TagModel.model_validate(tag)
-            for tag in (
-                db.query(Tag)
-                .filter_by(user_id=user_id)
-                .filter(Tag.name.in_(tag_names))
-                .all()
-            )
-        ]
+        with get_session() as db:
+            tag_names = [
+                chat_id_tag.tag_name
+                for chat_id_tag in (
+                    db.query(ChatIdTag)
+                    .filter_by(user_id=user_id, chat_id=chat_id)
+                    .order_by(ChatIdTag.timestamp.desc())
+                    .all()
+                )
+            ]
+
+            return [
+                TagModel.model_validate(tag)
+                for tag in (
+                    db.query(Tag)
+                    .filter_by(user_id=user_id)
+                    .filter(Tag.name.in_(tag_names))
+                    .all()
+                )
+            ]
 
 
     def get_chat_ids_by_tag_name_and_user_id(
     def get_chat_ids_by_tag_name_and_user_id(
-        self, db: Session, tag_name: str, user_id: str
+        self, tag_name: str, user_id: str
     ) -> List[ChatIdTagModel]:
     ) -> List[ChatIdTagModel]:
-        return [
-            ChatIdTagModel.model_validate(chat_id_tag)
-            for chat_id_tag in (
-                db.query(ChatIdTag)
-                .filter_by(user_id=user_id, tag_name=tag_name)
-                .order_by(ChatIdTag.timestamp.desc())
-                .all()
-            )
-        ]
+        with get_session() as db:
+            return [
+                ChatIdTagModel.model_validate(chat_id_tag)
+                for chat_id_tag in (
+                    db.query(ChatIdTag)
+                    .filter_by(user_id=user_id, tag_name=tag_name)
+                    .order_by(ChatIdTag.timestamp.desc())
+                    .all()
+                )
+            ]
 
 
     def count_chat_ids_by_tag_name_and_user_id(
     def count_chat_ids_by_tag_name_and_user_id(
-        self, db: Session, tag_name: str, user_id: str
+        self, tag_name: str, user_id: str
     ) -> int:
     ) -> int:
-        return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
+        with get_session() as db:
+            return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
 
 
     def delete_tag_by_tag_name_and_user_id(
     def delete_tag_by_tag_name_and_user_id(
-        self, db: Session, tag_name: str, user_id: str
+        self, tag_name: str, user_id: str
     ) -> bool:
     ) -> bool:
         try:
         try:
-            res = (
-                db.query(ChatIdTag)
-                .filter_by(tag_name=tag_name, user_id=user_id)
-                .delete()
-            )
-            log.debug(f"res: {res}")
-
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                db, tag_name, user_id
-            )
-            if tag_count == 0:
-                # Remove tag item from Tag col as well
-                db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+            with get_session() as db:
+                res = (
+                    db.query(ChatIdTag)
+                    .filter_by(tag_name=tag_name, user_id=user_id)
+                    .delete()
+                )
+                log.debug(f"res: {res}")
+                db.commit()
+
+                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                    tag_name, user_id
+                )
+                if tag_count == 0:
+                    # Remove tag item from Tag col as well
+                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
             return True
             return True
         except Exception as e:
         except Exception as e:
             log.error(f"delete_tag: {e}")
             log.error(f"delete_tag: {e}")
             return False
             return False
 
 
     def delete_tag_by_tag_name_and_chat_id_and_user_id(
     def delete_tag_by_tag_name_and_chat_id_and_user_id(
-        self, db: Session, tag_name: str, chat_id: str, user_id: str
+        self, tag_name: str, chat_id: str, user_id: str
     ) -> bool:
     ) -> bool:
         try:
         try:
-            res = (
-                db.query(ChatIdTag)
-                .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
-                .delete()
-            )
-            log.debug(f"res: {res}")
-
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                db, tag_name, user_id
-            )
-            if tag_count == 0:
-                # Remove tag item from Tag col as well
-                db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
+            with get_session() as db:
+                res = (
+                    db.query(ChatIdTag)
+                    .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
+                    .delete()
+                )
+                log.debug(f"res: {res}")
+                db.commit()
+
+                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
+                    tag_name, user_id
+                )
+                if tag_count == 0:
+                    # Remove tag item from Tag col as well
+                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
 
 
             return True
             return True
         except Exception as e:
         except Exception as e:
@@ -242,13 +253,13 @@ class TagTable:
             return False
             return False
 
 
     def delete_tags_by_chat_id_and_user_id(
     def delete_tags_by_chat_id_and_user_id(
-        self, db: Session, chat_id: str, user_id: str
+        self, chat_id: str, user_id: str
     ) -> bool:
     ) -> bool:
-        tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id)
+        tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
 
 
         for tag in tags:
         for tag in tags:
             self.delete_tag_by_tag_name_and_chat_id_and_user_id(
             self.delete_tag_by_tag_name_and_chat_id_and_user_id(
-                db, tag.tag_name, chat_id, user_id
+                tag.tag_name, chat_id, user_id
             )
             )
 
 
         return True
         return True

+ 37 - 33
backend/apps/webui/models/tools.py

@@ -5,7 +5,7 @@ import logging
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy import String, Column, BigInteger
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import Base, JSONField
+from apps.webui.internal.db import Base, JSONField, get_session
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 import json
 import json
@@ -82,7 +82,7 @@ class ToolValves(BaseModel):
 class ToolsTable:
 class ToolsTable:
 
 
     def insert_new_tool(
     def insert_new_tool(
-        self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict]
+        self, user_id: str, form_data: ToolForm, specs: List[dict]
     ) -> Optional[ToolModel]:
     ) -> Optional[ToolModel]:
         tool = ToolModel(
         tool = ToolModel(
             **{
             **{
@@ -95,46 +95,48 @@ class ToolsTable:
         )
         )
 
 
         try:
         try:
-            result = Tool(**tool.dict())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
-            if result:
-                return ToolModel.model_validate(result)
-            else:
-                return None
+            with get_session() as db:
+                result = Tool(**tool.model_dump())
+                db.add(result)
+                db.commit()
+                db.refresh(result)
+                if result:
+                    return ToolModel.model_validate(result)
+                else:
+                    return None
         except Exception as e:
         except Exception as e:
             print(f"Error creating tool: {e}")
             print(f"Error creating tool: {e}")
             return None
             return None
 
 
-    def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]:
+    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
         try:
         try:
-            tool = db.get(Tool, id)
-            return ToolModel.model_validate(tool)
+            with get_session() as db:
+                tool = db.get(Tool, id)
+                return ToolModel.model_validate(tool)
         except:
         except:
             return None
             return None
 
 
-    def get_tools(self, db: Session) -> List[ToolModel]:
-        return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
+    def get_tools(self) -> List[ToolModel]:
+        with get_session() as db:
+            return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
 
 
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
         try:
-            tool = Tool.get(Tool.id == id)
-            return tool.valves if tool.valves else {}
+            with get_session() as db:
+                tool = db.get(Tool, id)
+                return tool.valves if tool.valves else {}
         except Exception as e:
         except Exception as e:
             print(f"An error occurred: {e}")
             print(f"An error occurred: {e}")
             return None
             return None
 
 
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
         try:
         try:
-            query = Tool.update(
-                **{"valves": valves},
-                updated_at=int(time.time()),
-            ).where(Tool.id == id)
-            query.execute()
-
-            tool = Tool.get(Tool.id == id)
-            return ToolValves(**model_to_dict(tool))
+            with get_session() as db:
+                db.query(Tool).filter_by(id=id).update(
+                    {"valves": valves, "updated_at": int(time.time())}
+                )
+                db.commit()
+                return self.get_tool_by_id(id)
         except:
         except:
             return None
             return None
 
 
@@ -172,8 +174,7 @@ class ToolsTable:
             user_settings["tools"]["valves"][id] = valves
             user_settings["tools"]["valves"][id] = valves
 
 
             # Update the user settings in the database
             # Update the user settings in the database
-            query = Users.update_user_by_id(user_id, {"settings": user_settings})
-            query.execute()
+            Users.update_user_by_id(user_id, {"settings": user_settings})
 
 
             return user_settings["tools"]["valves"][id]
             return user_settings["tools"]["valves"][id]
         except Exception as e:
         except Exception as e:
@@ -182,16 +183,19 @@ class ToolsTable:
 
 
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
     def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
         try:
         try:
-            db.query(Tool).filter_by(id=id).update(
-                {**updated, "updated_at": int(time.time())}
-            )
-            return self.get_tool_by_id(db, id)
+            with get_session() as db:
+                db.query(Tool).filter_by(id=id).update(
+                    {**updated, "updated_at": int(time.time())}
+                )
+                db.commit()
+                return self.get_tool_by_id(id)
         except:
         except:
             return None
             return None
 
 
-    def delete_tool_by_id(self, db: Session, id: str) -> bool:
+    def delete_tool_by_id(self, id: str) -> bool:
         try:
         try:
-            db.query(Tool).filter_by(id=id).delete()
+            with get_session() as db:
+                db.query(Tool).filter_by(id=id).delete()
             return True
             return True
         except:
         except:
             return False
             return False

+ 152 - 137
backend/apps/webui/models/users.py

@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
 
 
 from utils.misc import get_gravatar_url
 from utils.misc import get_gravatar_url
 
 
-from apps.webui.internal.db import Base, JSONField
+from apps.webui.internal.db import Base, JSONField, get_session
 from apps.webui.models.chats import Chats
 from apps.webui.models.chats import Chats
 
 
 ####################
 ####################
@@ -42,8 +42,6 @@ class UserSettings(BaseModel):
 
 
 
 
 class UserModel(BaseModel):
 class UserModel(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str
     id: str
     name: str
     name: str
     email: str
     email: str
@@ -60,6 +58,8 @@ class UserModel(BaseModel):
 
 
     oauth_sub: Optional[str] = None
     oauth_sub: Optional[str] = None
 
 
+    model_config = ConfigDict(from_attributes=True)
+
 
 
 ####################
 ####################
 # Forms
 # Forms
@@ -82,7 +82,6 @@ class UsersTable:
 
 
     def insert_new_user(
     def insert_new_user(
         self,
         self,
-        db: Session,
         id: str,
         id: str,
         name: str,
         name: str,
         email: str,
         email: str,
@@ -90,165 +89,181 @@ class UsersTable:
         role: str = "pending",
         role: str = "pending",
         oauth_sub: Optional[str] = None,
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        user = UserModel(
-            **{
-                "id": id,
-                "name": name,
-                "email": email,
-                "role": role,
-                "profile_image_url": profile_image_url,
-                "last_active_at": int(time.time()),
-                "created_at": int(time.time()),
-                "updated_at": int(time.time()),
-                "oauth_sub": oauth_sub,
-            }
-        )
-        result = User(**user.model_dump())
-        db.add(result)
-        db.commit()
-        db.refresh(result)
-        if result:
-            return user
-        else:
-            return None
-
-    def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]:
-        try:
-            user = db.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-        except Exception as e:
-            return None
-
-    def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]:
-        try:
-            user = db.query(User).filter_by(api_key=api_key).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
-
-    def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]:
-        try:
-            user = db.query(User).filter_by(email=email).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
+        with get_session() as db:
+            user = UserModel(
+                **{
+                    "id": id,
+                    "name": name,
+                    "email": email,
+                    "role": role,
+                    "profile_image_url": profile_image_url,
+                    "last_active_at": int(time.time()),
+                    "created_at": int(time.time()),
+                    "updated_at": int(time.time()),
+                    "oauth_sub": oauth_sub,
+                }
+            )
+            result = User(**user.model_dump())
+            db.add(result)
+            db.commit()
+            db.refresh(result)
+            if result:
+                return user
+            else:
+                return None
+
+    def get_user_by_id(self, id: str) -> Optional[UserModel]:
+        with get_session() as db:
+            try:
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+            except Exception as e:
+                return None
+
+    def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
+        with get_session() as db:
+            try:
+                user = db.query(User).filter_by(api_key=api_key).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
+
+    def get_user_by_email(self, email: str) -> Optional[UserModel]:
+        with get_session() as db:
+            try:
+                user = db.query(User).filter_by(email=email).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
 
 
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
     def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
-        try:
-            user = User.get(User.oauth_sub == sub)
-            return UserModel(**model_to_dict(user))
-        except:
-            return None
-
-    def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]:
-        users = (
-            db.query(User)
-            # .offset(skip).limit(limit)
-            .all()
-        )
-        return [UserModel.model_validate(user) for user in users]
-
-    def get_num_users(self, db: Session) -> Optional[int]:
-        return db.query(User).count()
-
-    def get_first_user(self, db: Session) -> UserModel:
-        try:
-            user = db.query(User).order_by(User.created_at).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
+        with get_session() as db:
+            try:
+                user = db.query(User).filter_by(oauth_sub=sub).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
+
+    def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
+        with get_session() as db:
+            users = (
+                db.query(User)
+                # .offset(skip).limit(limit)
+                .all()
+            )
+            return [UserModel.model_validate(user) for user in users]
+
+    def get_num_users(self) -> Optional[int]:
+        with get_session() as db:
+            return db.query(User).count()
+
+    def get_first_user(self) -> UserModel:
+        with get_session() as db:
+            try:
+                user = db.query(User).order_by(User.created_at).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
 
 
     def update_user_role_by_id(
     def update_user_role_by_id(
-        self, db: Session, id: str, role: str
+        self, id: str, role: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        try:
-            db.query(User).filter_by(id=id).update({"role": role})
-            db.commit()
+        with get_session() as db:
+            try:
+                db.query(User).filter_by(id=id).update({"role": role})
+                db.commit()
 
 
-            user = db.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
 
 
     def update_user_profile_image_url_by_id(
     def update_user_profile_image_url_by_id(
-        self, db: Session, id: str, profile_image_url: str
+        self, id: str, profile_image_url: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        try:
-            db.query(User).filter_by(id=id).update(
-                {"profile_image_url": profile_image_url}
-            )
-            db.commit()
+        with get_session() as db:
+            try:
+                db.query(User).filter_by(id=id).update(
+                    {"profile_image_url": profile_image_url}
+                )
+                db.commit()
 
 
-            user = db.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
 
 
     def update_user_last_active_by_id(
     def update_user_last_active_by_id(
-        self, db: Session, id: str
+        self, id: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        try:
-            db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
+        with get_session() as db:
+            try:
+                db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
 
 
-            user = db.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
 
 
     def update_user_oauth_sub_by_id(
     def update_user_oauth_sub_by_id(
-        self, db: Session, id: str, oauth_sub: str
+        self, id: str, oauth_sub: str
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        try:
-            db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
+        with get_session() as db:
+            try:
+                db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
 
 
-            user = db.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-        except:
-            return None
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+            except:
+                return None
 
 
     def update_user_by_id(
     def update_user_by_id(
-        self, db: Session, id: str, updated: dict
+        self, id: str, updated: dict
     ) -> Optional[UserModel]:
     ) -> Optional[UserModel]:
-        try:
-            db.query(User).filter_by(id=id).update(updated)
-            db.commit()
-
-            user = db.query(User).filter_by(id=id).first()
-            return UserModel.model_validate(user)
-            # return UserModel(**user.dict())
-        except Exception as e:
-            return None
+        with get_session() as db:
+            try:
+                db.query(User).filter_by(id=id).update(updated)
+                db.commit()
 
 
-    def delete_user_by_id(self, db: Session, id: str) -> bool:
-        try:
-            # Delete User Chats
-            result = Chats.delete_chats_by_user_id(db, id)
+                user = db.query(User).filter_by(id=id).first()
+                return UserModel.model_validate(user)
+                # return UserModel(**user.dict())
+            except Exception as e:
+                return None
+
+    def delete_user_by_id(self, id: str) -> bool:
+        with get_session() as db:
+            try:
+                # Delete User Chats
+                result = Chats.delete_chats_by_user_id(id)
+
+                if result:
+                    # Delete User
+                    db.query(User).filter_by(id=id).delete()
+                    db.commit()
+
+                    return True
+                else:
+                    return False
+            except:
+                return False
 
 
-            if result:
-                # Delete User
-                db.query(User).filter_by(id=id).delete()
+    def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
+        with get_session() as db:
+            try:
+                result = db.query(User).filter_by(id=id).update({"api_key": api_key})
                 db.commit()
                 db.commit()
-
-                return True
-            else:
+                return True if result == 1 else False
+            except:
                 return False
                 return False
-        except:
-            return False
 
 
-    def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str:
-        try:
-            result = db.query(User).filter_by(id=id).update({"api_key": api_key})
-            db.commit()
-            return True if result == 1 else False
-        except:
-            return False
-
-    def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]:
-        try:
-            user = db.query(User).filter_by(id=id).first()
-            return user.api_key
-        except Exception as e:
-            return None
+    def get_user_api_key_by_id(self, id: str) -> Optional[str]:
+        with get_session() as db:
+            try:
+                user = db.query(User).filter_by(id=id).first()
+                return user.api_key
+            except Exception as e:
+                return None
 
 
 
 
 Users = UsersTable()
 Users = UsersTable()

+ 26 - 34
backend/apps/webui/routers/auths.py

@@ -10,7 +10,6 @@ import re
 import uuid
 import uuid
 import csv
 import csv
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.auths import (
 from apps.webui.models.auths import (
     SigninForm,
     SigninForm,
     SignupForm,
     SignupForm,
@@ -80,12 +79,10 @@ async def get_session_user(
 @router.post("/update/profile", response_model=UserResponse)
 @router.post("/update/profile", response_model=UserResponse)
 async def update_profile(
 async def update_profile(
     form_data: UpdateProfileForm,
     form_data: UpdateProfileForm,
-    session_user=Depends(get_current_user),
-    db=Depends(get_db),
+    session_user=Depends(get_current_user)
 ):
 ):
     if session_user:
     if session_user:
         user = Users.update_user_by_id(
         user = Users.update_user_by_id(
-            db,
             session_user.id,
             session_user.id,
             {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
             {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
         )
         )
@@ -105,17 +102,16 @@ async def update_profile(
 @router.post("/update/password", response_model=bool)
 @router.post("/update/password", response_model=bool)
 async def update_password(
 async def update_password(
     form_data: UpdatePasswordForm,
     form_data: UpdatePasswordForm,
-    session_user=Depends(get_current_user),
-    db=Depends(get_db),
+    session_user=Depends(get_current_user)
 ):
 ):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
     if session_user:
     if session_user:
-        user = Auths.authenticate_user(db, session_user.email, form_data.password)
+        user = Auths.authenticate_user(session_user.email, form_data.password)
 
 
         if user:
         if user:
             hashed = get_password_hash(form_data.new_password)
             hashed = get_password_hash(form_data.new_password)
-            return Auths.update_user_password_by_id(db, user.id, hashed)
+            return Auths.update_user_password_by_id(user.id, hashed)
         else:
         else:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
     else:
     else:
@@ -128,7 +124,7 @@ async def update_password(
 
 
 
 
 @router.post("/signin", response_model=SigninResponse)
 @router.post("/signin", response_model=SigninResponse)
-async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)):
+async def signin(request: Request, response: Response, form_data: SigninForm):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
         if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
             raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
@@ -139,34 +135,32 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
             trusted_name = request.headers.get(
             trusted_name = request.headers.get(
                 WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
                 WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
             )
             )
-        if not Users.get_user_by_email(db, trusted_email.lower()):
+        if not Users.get_user_by_email(trusted_email.lower()):
             await signup(
             await signup(
                 request,
                 request,
                 SignupForm(
                 SignupForm(
                     email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
                     email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
                 ),
                 ),
-                db,
             )
             )
-        user = Auths.authenticate_user_by_trusted_header(db, trusted_email)
+        user = Auths.authenticate_user_by_trusted_header(trusted_email)
     elif WEBUI_AUTH == False:
     elif WEBUI_AUTH == False:
         admin_email = "admin@localhost"
         admin_email = "admin@localhost"
         admin_password = "admin"
         admin_password = "admin"
 
 
-        if Users.get_user_by_email(db, admin_email.lower()):
-            user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
+        if Users.get_user_by_email(admin_email.lower()):
+            user = Auths.authenticate_user(admin_email.lower(), admin_password)
         else:
         else:
-            if Users.get_num_users(db) != 0:
+            if Users.get_num_users() != 0:
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
                 raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
 
 
             await signup(
             await signup(
                 request,
                 request,
                 SignupForm(email=admin_email, password=admin_password, name="User"),
                 SignupForm(email=admin_email, password=admin_password, name="User"),
-                db,
             )
             )
 
 
-            user = Auths.authenticate_user(db, admin_email.lower(), admin_password)
+            user = Auths.authenticate_user(admin_email.lower(), admin_password)
     else:
     else:
-        user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password)
+        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
 
 
     if user:
     if user:
         token = create_token(
         token = create_token(
@@ -200,7 +194,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
 
 
 
 
 @router.post("/signup", response_model=SigninResponse)
 @router.post("/signup", response_model=SigninResponse)
-async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)):
+async def signup(request: Request, response: Response, form_data: SignupForm):
     if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
     if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
         raise HTTPException(
         raise HTTPException(
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
@@ -211,18 +205,17 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
         )
         )
 
 
-    if Users.get_user_by_email(db, form_data.email.lower()):
+    if Users.get_user_by_email(form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
 
     try:
     try:
         role = (
         role = (
             "admin"
             "admin"
-            if Users.get_num_users(db) == 0
+            if Users.get_num_users() == 0
             else request.app.state.config.DEFAULT_USER_ROLE
             else request.app.state.config.DEFAULT_USER_ROLE
         )
         )
         hashed = get_password_hash(form_data.password)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
         user = Auths.insert_new_auth(
-            db,
             form_data.email.lower(),
             form_data.email.lower(),
             hashed,
             hashed,
             form_data.name,
             form_data.name,
@@ -277,7 +270,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
 
 
 @router.post("/add", response_model=SigninResponse)
 @router.post("/add", response_model=SigninResponse)
 async def add_user(
 async def add_user(
-    form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db)
+    form_data: AddUserForm, user=Depends(get_admin_user)
 ):
 ):
 
 
     if not validate_email_format(form_data.email.lower()):
     if not validate_email_format(form_data.email.lower()):
@@ -285,7 +278,7 @@ async def add_user(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
         )
         )
 
 
-    if Users.get_user_by_email(db, form_data.email.lower()):
+    if Users.get_user_by_email(form_data.email.lower()):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
 
     try:
     try:
@@ -293,7 +286,6 @@ async def add_user(
         print(form_data)
         print(form_data)
         hashed = get_password_hash(form_data.password)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
         user = Auths.insert_new_auth(
-            db,
             form_data.email.lower(),
             form_data.email.lower(),
             hashed,
             hashed,
             form_data.name,
             form_data.name,
@@ -325,7 +317,7 @@ async def add_user(
 
 
 @router.get("/admin/details")
 @router.get("/admin/details")
 async def get_admin_details(
 async def get_admin_details(
-    request: Request, user=Depends(get_current_user), db=Depends(get_db)
+    request: Request, user=Depends(get_current_user)
 ):
 ):
     if request.app.state.config.SHOW_ADMIN_DETAILS:
     if request.app.state.config.SHOW_ADMIN_DETAILS:
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_email = request.app.state.config.ADMIN_EMAIL
@@ -334,11 +326,11 @@ async def get_admin_details(
         print(admin_email, admin_name)
         print(admin_email, admin_name)
 
 
         if admin_email:
         if admin_email:
-            admin = Users.get_user_by_email(db, admin_email)
+            admin = Users.get_user_by_email(admin_email)
             if admin:
             if admin:
                 admin_name = admin.name
                 admin_name = admin.name
         else:
         else:
-            admin = Users.get_first_user(db)
+            admin = Users.get_first_user()
             if admin:
             if admin:
                 admin_email = admin.email
                 admin_email = admin.email
                 admin_name = admin.name
                 admin_name = admin.name
@@ -411,9 +403,9 @@ async def update_admin_config(
 
 
 # create api key
 # create api key
 @router.post("/api_key", response_model=ApiKey)
 @router.post("/api_key", response_model=ApiKey)
-async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
+async def create_api_key_(user=Depends(get_current_user)):
     api_key = create_api_key()
     api_key = create_api_key()
-    success = Users.update_user_api_key_by_id(db, user.id, api_key)
+    success = Users.update_user_api_key_by_id(user.id, api_key)
     if success:
     if success:
         return {
         return {
             "api_key": api_key,
             "api_key": api_key,
@@ -424,15 +416,15 @@ async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
 
 
 # delete api key
 # delete api key
 @router.delete("/api_key", response_model=bool)
 @router.delete("/api_key", response_model=bool)
-async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)):
-    success = Users.update_user_api_key_by_id(db, user.id, None)
+async def delete_api_key(user=Depends(get_current_user)):
+    success = Users.update_user_api_key_by_id(user.id, None)
     return success
     return success
 
 
 
 
 # get api key
 # get api key
 @router.get("/api_key", response_model=ApiKey)
 @router.get("/api_key", response_model=ApiKey)
-async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)):
-    api_key = Users.get_user_api_key_by_id(db, user.id)
+async def get_api_key(user=Depends(get_current_user)):
+    api_key = Users.get_user_api_key_by_id(user.id)
     if api_key:
     if api_key:
         return {
         return {
             "api_key": api_key,
             "api_key": api_key,

+ 56 - 60
backend/apps/webui/routers/chats.py

@@ -2,7 +2,6 @@ from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
 from typing import List, Union, Optional
 
 
-from apps.webui.internal.db import get_db
 from utils.utils import get_current_user, get_admin_user
 from utils.utils import get_current_user, get_admin_user
 from fastapi import APIRouter
 from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -45,9 +44,9 @@ router = APIRouter()
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/list", response_model=List[ChatTitleIdResponse])
 @router.get("/list", response_model=List[ChatTitleIdResponse])
 async def get_session_user_chat_list(
 async def get_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
 ):
 ):
-    return Chats.get_chat_list_by_user_id(db, user.id, skip, limit)
+    return Chats.get_chat_list_by_user_id(user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -57,7 +56,7 @@ async def get_session_user_chat_list(
 
 
 @router.delete("/", response_model=bool)
 @router.delete("/", response_model=bool)
 async def delete_all_user_chats(
 async def delete_all_user_chats(
-    request: Request, user=Depends(get_current_user), db=Depends(get_db)
+    request: Request, user=Depends(get_current_user)
 ):
 ):
 
 
     if (
     if (
@@ -69,7 +68,7 @@ async def delete_all_user_chats(
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
         )
 
 
-    result = Chats.delete_chats_by_user_id(db, user.id)
+    result = Chats.delete_chats_by_user_id(user.id)
     return result
     return result
 
 
 
 
@@ -84,10 +83,9 @@ async def get_user_chat_list_by_user_id(
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
     skip: int = 0,
     skip: int = 0,
     limit: int = 50,
     limit: int = 50,
-    db=Depends(get_db),
 ):
 ):
     return Chats.get_chat_list_by_user_id(
     return Chats.get_chat_list_by_user_id(
-        db, user_id, include_archived=True, skip=skip, limit=limit
+        user_id, include_archived=True, skip=skip, limit=limit
     )
     )
 
 
 
 
@@ -98,10 +96,10 @@ async def get_user_chat_list_by_user_id(
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
 @router.post("/new", response_model=Optional[ChatResponse])
 async def create_new_chat(
 async def create_new_chat(
-    form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
+    form_data: ChatForm, user=Depends(get_current_user)
 ):
 ):
     try:
     try:
-        chat = Chats.insert_new_chat(db, user.id, form_data)
+        chat = Chats.insert_new_chat(user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
@@ -116,10 +114,10 @@ async def create_new_chat(
 
 
 
 
 @router.get("/all", response_model=List[ChatResponse])
 @router.get("/all", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
+async def get_user_chats(user=Depends(get_current_user)):
     return [
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats_by_user_id(db, user.id)
+        for chat in Chats.get_chats_by_user_id(user.id)
     ]
     ]
 
 
 
 
@@ -129,10 +127,10 @@ async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
 
 
 
 
 @router.get("/all/archived", response_model=List[ChatResponse])
 @router.get("/all/archived", response_model=List[ChatResponse])
-async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)):
+async def get_user_archived_chats(user=Depends(get_current_user)):
     return [
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_archived_chats_by_user_id(db, user.id)
+        for chat in Chats.get_archived_chats_by_user_id(user.id)
     ]
     ]
 
 
 
 
@@ -142,7 +140,7 @@ async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get
 
 
 
 
 @router.get("/all/db", response_model=List[ChatResponse])
 @router.get("/all/db", response_model=List[ChatResponse])
-async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)):
+async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
     if not ENABLE_ADMIN_EXPORT:
     if not ENABLE_ADMIN_EXPORT:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -150,7 +148,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
         )
         )
     return [
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats(db)
+        for chat in Chats.get_chats()
     ]
     ]
 
 
 
 
@@ -161,9 +159,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
 
 
 @router.get("/archived", response_model=List[ChatTitleIdResponse])
 @router.get("/archived", response_model=List[ChatTitleIdResponse])
 async def get_archived_session_user_chat_list(
 async def get_archived_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
 ):
 ):
-    return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit)
+    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
 
 
 
 
 ############################
 ############################
@@ -172,8 +170,8 @@ async def get_archived_session_user_chat_list(
 
 
 
 
 @router.post("/archive/all", response_model=bool)
 @router.post("/archive/all", response_model=bool)
-async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
-    return Chats.archive_all_chats_by_user_id(db, user.id)
+async def archive_all_chats(user=Depends(get_current_user)):
+    return Chats.archive_all_chats_by_user_id(user.id)
 
 
 
 
 ############################
 ############################
@@ -183,7 +181,7 @@ async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
 
 
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
 async def get_shared_chat_by_id(
 async def get_shared_chat_by_id(
-    share_id: str, user=Depends(get_current_user), db=Depends(get_db)
+    share_id: str, user=Depends(get_current_user)
 ):
 ):
     if user.role == "pending":
     if user.role == "pending":
         raise HTTPException(
         raise HTTPException(
@@ -191,9 +189,9 @@ async def get_shared_chat_by_id(
         )
         )
 
 
     if user.role == "user":
     if user.role == "user":
-        chat = Chats.get_chat_by_share_id(db, share_id)
+        chat = Chats.get_chat_by_share_id(share_id)
     elif user.role == "admin":
     elif user.role == "admin":
-        chat = Chats.get_chat_by_id(db, share_id)
+        chat = Chats.get_chat_by_id(share_id)
 
 
     if chat:
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -216,23 +214,23 @@ class TagNameForm(BaseModel):
 
 
 @router.post("/tags", response_model=List[ChatTitleIdResponse])
 @router.post("/tags", response_model=List[ChatTitleIdResponse])
 async def get_user_chat_list_by_tag_name(
 async def get_user_chat_list_by_tag_name(
-    form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db)
+    form_data: TagNameForm, user=Depends(get_current_user)
 ):
 ):
 
 
     print(form_data)
     print(form_data)
     chat_ids = [
     chat_ids = [
         chat_id_tag.chat_id
         chat_id_tag.chat_id
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
-            db, form_data.name, user.id
+            form_data.name, user.id
         )
         )
     ]
     ]
 
 
     chats = Chats.get_chat_list_by_chat_ids(
     chats = Chats.get_chat_list_by_chat_ids(
-        db, chat_ids, form_data.skip, form_data.limit
+        chat_ids, form_data.skip, form_data.limit
     )
     )
 
 
     if len(chats) == 0:
     if len(chats) == 0:
-        Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id)
+        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
 
 
     return chats
     return chats
 
 
@@ -243,9 +241,9 @@ async def get_user_chat_list_by_tag_name(
 
 
 
 
 @router.get("/tags/all", response_model=List[TagModel])
 @router.get("/tags/all", response_model=List[TagModel])
-async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
+async def get_all_tags(user=Depends(get_current_user)):
     try:
     try:
-        tags = Tags.get_tags_by_user_id(db, user.id)
+        tags = Tags.get_tags_by_user_id(user.id)
         return tags
         return tags
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
@@ -260,8 +258,8 @@ async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
 
 
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
-    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
+async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
 
 
     if chat:
     if chat:
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -278,13 +276,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 @router.post("/{id}", response_model=Optional[ChatResponse])
 async def update_chat_by_id(
 async def update_chat_by_id(
-    id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
+    id: str, form_data: ChatForm, user=Depends(get_current_user)
 ):
 ):
-    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
         updated_chat = {**json.loads(chat.chat), **form_data.chat}
         updated_chat = {**json.loads(chat.chat), **form_data.chat}
 
 
-        chat = Chats.update_chat_by_id(db, id, updated_chat)
+        chat = Chats.update_chat_by_id(id, updated_chat)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -300,11 +298,11 @@ async def update_chat_by_id(
 
 
 @router.delete("/{id}", response_model=bool)
 @router.delete("/{id}", response_model=bool)
 async def delete_chat_by_id(
 async def delete_chat_by_id(
-    request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db)
+    request: Request, id: str, user=Depends(get_current_user)
 ):
 ):
 
 
     if user.role == "admin":
     if user.role == "admin":
-        result = Chats.delete_chat_by_id(db, id)
+        result = Chats.delete_chat_by_id(id)
         return result
         return result
     else:
     else:
         if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
         if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
@@ -313,7 +311,7 @@ async def delete_chat_by_id(
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
             )
             )
 
 
-        result = Chats.delete_chat_by_id_and_user_id(db, id, user.id)
+        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
         return result
         return result
 
 
 
 
@@ -323,8 +321,8 @@ async def delete_chat_by_id(
 
 
 
 
 @router.get("/{id}/clone", response_model=Optional[ChatResponse])
 @router.get("/{id}/clone", response_model=Optional[ChatResponse])
-async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
-    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
+async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
 
 
         chat_body = json.loads(chat.chat)
         chat_body = json.loads(chat.chat)
@@ -335,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
             "title": f"Clone of {chat.title}",
             "title": f"Clone of {chat.title}",
         }
         }
 
 
-        chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat}))
+        chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -350,11 +348,11 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
 
 
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
 async def archive_chat_by_id(
 async def archive_chat_by_id(
-    id: str, user=Depends(get_current_user), db=Depends(get_db)
+    id: str, user=Depends(get_current_user)
 ):
 ):
-    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
-        chat = Chats.toggle_chat_archive_by_id(db, id)
+        chat = Chats.toggle_chat_archive_by_id(id)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
     else:
         raise HTTPException(
         raise HTTPException(
@@ -368,16 +366,16 @@ async def archive_chat_by_id(
 
 
 
 
 @router.post("/{id}/share", response_model=Optional[ChatResponse])
 @router.post("/{id}/share", response_model=Optional[ChatResponse])
-async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
-    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
+async def share_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
         if chat.share_id:
         if chat.share_id:
-            shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id)
+            shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
             return ChatResponse(
             return ChatResponse(
                 **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
                 **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
             )
             )
 
 
-        shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id)
+        shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
         if not shared_chat:
         if not shared_chat:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -401,15 +399,15 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
 
 
 @router.delete("/{id}/share", response_model=Optional[bool])
 @router.delete("/{id}/share", response_model=Optional[bool])
 async def delete_shared_chat_by_id(
 async def delete_shared_chat_by_id(
-    id: str, user=Depends(get_current_user), db=Depends(get_db)
+    id: str, user=Depends(get_current_user)
 ):
 ):
-    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
         if not chat.share_id:
         if not chat.share_id:
             return False
             return False
 
 
-        result = Chats.delete_shared_chat_by_chat_id(db, id)
-        update_result = Chats.update_chat_share_id_by_id(db, id, None)
+        result = Chats.delete_shared_chat_by_chat_id(id)
+        update_result = Chats.update_chat_share_id_by_id(id, None)
 
 
         return result and update_result != None
         return result and update_result != None
     else:
     else:
@@ -426,9 +424,9 @@ async def delete_shared_chat_by_id(
 
 
 @router.get("/{id}/tags", response_model=List[TagModel])
 @router.get("/{id}/tags", response_model=List[TagModel])
 async def get_chat_tags_by_id(
 async def get_chat_tags_by_id(
-    id: str, user=Depends(get_current_user), db=Depends(get_db)
+    id: str, user=Depends(get_current_user)
 ):
 ):
-    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
+    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
 
     if tags != None:
     if tags != None:
         return tags
         return tags
@@ -447,13 +445,12 @@ async def get_chat_tags_by_id(
 async def add_chat_tag_by_id(
 async def add_chat_tag_by_id(
     id: str,
     id: str,
     form_data: ChatIdTagForm,
     form_data: ChatIdTagForm,
-    user=Depends(get_current_user),
-    db=Depends(get_db),
+    user=Depends(get_current_user)
 ):
 ):
-    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
+    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
 
     if form_data.tag_name not in tags:
     if form_data.tag_name not in tags:
-        tag = Tags.add_tag_to_chat(db, user.id, form_data)
+        tag = Tags.add_tag_to_chat(user.id, form_data)
 
 
         if tag:
         if tag:
             return tag
             return tag
@@ -478,10 +475,9 @@ async def delete_chat_tag_by_id(
     id: str,
     id: str,
     form_data: ChatIdTagForm,
     form_data: ChatIdTagForm,
     user=Depends(get_current_user),
     user=Depends(get_current_user),
-    db=Depends(get_db),
 ):
 ):
     result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
     result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
-        db, form_data.tag_name, id, user.id
+        form_data.tag_name, id, user.id
     )
     )
 
 
     if result:
     if result:
@@ -499,9 +495,9 @@ async def delete_chat_tag_by_id(
 
 
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
 async def delete_all_chat_tags_by_id(
 async def delete_all_chat_tags_by_id(
-    id: str, user=Depends(get_current_user), db=Depends(get_db)
+    id: str, user=Depends(get_current_user)
 ):
 ):
-    result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id)
+    result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
 
 
     if result:
     if result:
         return result
         return result

+ 12 - 14
backend/apps/webui/routers/documents.py

@@ -6,7 +6,6 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.documents import (
 from apps.webui.models.documents import (
     Documents,
     Documents,
     DocumentForm,
     DocumentForm,
@@ -26,7 +25,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[DocumentResponse])
 @router.get("/", response_model=List[DocumentResponse])
-async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
+async def get_documents(user=Depends(get_current_user)):
     docs = [
     docs = [
         DocumentResponse(
         DocumentResponse(
             **{
             **{
@@ -34,7 +33,7 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
                 "content": json.loads(doc.content if doc.content else "{}"),
                 "content": json.loads(doc.content if doc.content else "{}"),
             }
             }
         )
         )
-        for doc in Documents.get_docs(db)
+        for doc in Documents.get_docs()
     ]
     ]
     return docs
     return docs
 
 
@@ -46,11 +45,11 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
 
 
 @router.post("/create", response_model=Optional[DocumentResponse])
 @router.post("/create", response_model=Optional[DocumentResponse])
 async def create_new_doc(
 async def create_new_doc(
-    form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db)
+    form_data: DocumentForm, user=Depends(get_admin_user)
 ):
 ):
-    doc = Documents.get_doc_by_name(db, form_data.name)
+    doc = Documents.get_doc_by_name(form_data.name)
     if doc == None:
     if doc == None:
-        doc = Documents.insert_new_doc(db, user.id, form_data)
+        doc = Documents.insert_new_doc(user.id, form_data)
 
 
         if doc:
         if doc:
             return DocumentResponse(
             return DocumentResponse(
@@ -78,9 +77,9 @@ async def create_new_doc(
 
 
 @router.get("/doc", response_model=Optional[DocumentResponse])
 @router.get("/doc", response_model=Optional[DocumentResponse])
 async def get_doc_by_name(
 async def get_doc_by_name(
-    name: str, user=Depends(get_current_user), db=Depends(get_db)
+    name: str, user=Depends(get_current_user)
 ):
 ):
-    doc = Documents.get_doc_by_name(db, name)
+    doc = Documents.get_doc_by_name(name)
 
 
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
@@ -112,10 +111,10 @@ class TagDocumentForm(BaseModel):
 
 
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
 async def tag_doc_by_name(
 async def tag_doc_by_name(
-    form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db)
+    form_data: TagDocumentForm, user=Depends(get_current_user)
 ):
 ):
     doc = Documents.update_doc_content_by_name(
     doc = Documents.update_doc_content_by_name(
-        db, form_data.name, {"tags": form_data.tags}
+        form_data.name, {"tags": form_data.tags}
     )
     )
 
 
     if doc:
     if doc:
@@ -142,9 +141,8 @@ async def update_doc_by_name(
     name: str,
     name: str,
     form_data: DocumentUpdateForm,
     form_data: DocumentUpdateForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    doc = Documents.update_doc_by_name(db, name, form_data)
+    doc = Documents.update_doc_by_name(name, form_data)
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
             **{
             **{
@@ -166,7 +164,7 @@ async def update_doc_by_name(
 
 
 @router.delete("/doc/delete", response_model=bool)
 @router.delete("/doc/delete", response_model=bool)
 async def delete_doc_by_name(
 async def delete_doc_by_name(
-    name: str, user=Depends(get_admin_user), db=Depends(get_db)
+    name: str, user=Depends(get_admin_user)
 ):
 ):
-    result = Documents.delete_doc_by_name(db, name)
+    result = Documents.delete_doc_by_name(name)
     return result
     return result

+ 12 - 15
backend/apps/webui/routers/files.py

@@ -20,7 +20,6 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.files import (
 from apps.webui.models.files import (
     Files,
     Files,
     FileForm,
     FileForm,
@@ -53,8 +52,7 @@ router = APIRouter()
 @router.post("/")
 @router.post("/")
 def upload_file(
 def upload_file(
     file: UploadFile = File(...),
     file: UploadFile = File(...),
-    user=Depends(get_verified_user),
-    db=Depends(get_db)
+    user=Depends(get_verified_user)
 ):
 ):
     log.info(f"file.content_type: {file.content_type}")
     log.info(f"file.content_type: {file.content_type}")
     try:
     try:
@@ -72,7 +70,6 @@ def upload_file(
             f.close()
             f.close()
 
 
         file = Files.insert_new_file(
         file = Files.insert_new_file(
-            db,
             user.id,
             user.id,
             FileForm(
             FileForm(
                 **{
                 **{
@@ -109,8 +106,8 @@ def upload_file(
 
 
 
 
 @router.get("/", response_model=List[FileModel])
 @router.get("/", response_model=List[FileModel])
-async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
-    files = Files.get_files(db)
+async def list_files(user=Depends(get_verified_user)):
+    files = Files.get_files()
     return files
     return files
 
 
 
 
@@ -120,8 +117,8 @@ async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
 
 
 
 
 @router.delete("/all")
 @router.delete("/all")
-async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
-    result = Files.delete_all_files(db)
+async def delete_all_files(user=Depends(get_admin_user)):
+    result = Files.delete_all_files()
 
 
     if result:
     if result:
         folder = f"{UPLOAD_DIR}"
         folder = f"{UPLOAD_DIR}"
@@ -157,8 +154,8 @@ async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
 
 
 
 
 @router.get("/{id}", response_model=Optional[FileModel])
 @router.get("/{id}", response_model=Optional[FileModel])
-async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
-    file = Files.get_file_by_id(db, id)
+async def get_file_by_id(id: str, user=Depends(get_verified_user)):
+    file = Files.get_file_by_id(id)
 
 
     if file:
     if file:
         return file
         return file
@@ -175,8 +172,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(ge
 
 
 
 
 @router.get("/{id}/content", response_model=Optional[FileModel])
 @router.get("/{id}/content", response_model=Optional[FileModel])
-async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
-    file = Files.get_file_by_id(db, id)
+async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+    file = Files.get_file_by_id(id)
 
 
     if file:
     if file:
         file_path = Path(file.meta["path"])
         file_path = Path(file.meta["path"])
@@ -226,11 +223,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
 
 
 
 
 @router.delete("/{id}")
 @router.delete("/{id}")
-async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
-    file = Files.get_file_by_id(db, id)
+async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
+    file = Files.get_file_by_id(id)
 
 
     if file:
     if file:
-        result = Files.delete_file_by_id(db, id)
+        result = Files.delete_file_by_id(id)
         if result:
         if result:
             return {"message": "File deleted successfully"}
             return {"message": "File deleted successfully"}
         else:
         else:

+ 13 - 14
backend/apps/webui/routers/functions.py

@@ -6,7 +6,6 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.functions import (
 from apps.webui.models.functions import (
     Functions,
     Functions,
     FunctionForm,
     FunctionForm,
@@ -32,8 +31,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[FunctionResponse])
 @router.get("/", response_model=List[FunctionResponse])
-async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
-    return Functions.get_functions(db)
+async def get_functions(user=Depends(get_verified_user)):
+    return Functions.get_functions()
 
 
 
 
 ############################
 ############################
@@ -42,8 +41,8 @@ async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
 
 
 
 
 @router.get("/export", response_model=List[FunctionModel])
 @router.get("/export", response_model=List[FunctionModel])
-async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
-    return Functions.get_functions(db)
+async def get_functions(user=Depends(get_admin_user)):
+    return Functions.get_functions()
 
 
 
 
 ############################
 ############################
@@ -53,7 +52,7 @@ async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
 
 
 @router.post("/create", response_model=Optional[FunctionResponse])
 @router.post("/create", response_model=Optional[FunctionResponse])
 async def create_new_function(
 async def create_new_function(
-    request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
+    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
 ):
 ):
     if not form_data.id.isidentifier():
     if not form_data.id.isidentifier():
         raise HTTPException(
         raise HTTPException(
@@ -63,7 +62,7 @@ async def create_new_function(
 
 
     form_data.id = form_data.id.lower()
     form_data.id = form_data.id.lower()
 
 
-    function = Functions.get_function_by_id(db, form_data.id)
+    function = Functions.get_function_by_id(form_data.id)
     if function == None:
     if function == None:
         function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
         function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
         try:
         try:
@@ -78,7 +77,7 @@ async def create_new_function(
             FUNCTIONS = request.app.state.FUNCTIONS
             FUNCTIONS = request.app.state.FUNCTIONS
             FUNCTIONS[form_data.id] = function_module
             FUNCTIONS[form_data.id] = function_module
 
 
-            function = Functions.insert_new_function(db, user.id, function_type, form_data)
+            function = Functions.insert_new_function(user.id, function_type, form_data)
 
 
             function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
             function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
             function_cache_dir.mkdir(parents=True, exist_ok=True)
             function_cache_dir.mkdir(parents=True, exist_ok=True)
@@ -109,8 +108,8 @@ async def create_new_function(
 
 
 
 
 @router.get("/id/{id}", response_model=Optional[FunctionModel])
 @router.get("/id/{id}", response_model=Optional[FunctionModel])
-async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
-    function = Functions.get_function_by_id(db, id)
+async def get_function_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
 
 
     if function:
     if function:
         return function
         return function
@@ -155,7 +154,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
 
 
 @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
 @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
 async def update_function_by_id(
 async def update_function_by_id(
-    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
+    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
 ):
 ):
     function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
     function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
 
 
@@ -172,7 +171,7 @@ async def update_function_by_id(
         updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
         updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
         print(updated)
         print(updated)
 
 
-        function = Functions.update_function_by_id(db, id, updated)
+        function = Functions.update_function_by_id(id, updated)
 
 
         if function:
         if function:
             return function
             return function
@@ -196,9 +195,9 @@ async def update_function_by_id(
 
 
 @router.delete("/id/{id}/delete", response_model=bool)
 @router.delete("/id/{id}/delete", response_model=bool)
 async def delete_function_by_id(
 async def delete_function_by_id(
-    request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
+    request: Request, id: str, user=Depends(get_admin_user)
 ):
 ):
-    result = Functions.delete_function_by_id(db, id)
+    result = Functions.delete_function_by_id(id)
 
 
     if result:
     if result:
         FUNCTIONS = request.app.state.FUNCTIONS
         FUNCTIONS = request.app.state.FUNCTIONS

+ 10 - 13
backend/apps/webui/routers/memories.py

@@ -7,7 +7,6 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import logging
 import logging
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.memories import Memories, MemoryModel
 from apps.webui.models.memories import Memories, MemoryModel
 
 
 from utils.utils import get_verified_user
 from utils.utils import get_verified_user
@@ -32,8 +31,8 @@ async def get_embeddings(request: Request):
 
 
 
 
 @router.get("/", response_model=List[MemoryModel])
 @router.get("/", response_model=List[MemoryModel])
-async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)):
-    return Memories.get_memories_by_user_id(db, user.id)
+async def get_memories(user=Depends(get_verified_user)):
+    return Memories.get_memories_by_user_id(user.id)
 
 
 
 
 ############################
 ############################
@@ -54,9 +53,8 @@ async def add_memory(
     request: Request,
     request: Request,
     form_data: AddMemoryForm,
     form_data: AddMemoryForm,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
-    db=Depends(get_db),
 ):
 ):
-    memory = Memories.insert_new_memory(db, user.id, form_data.content)
+    memory = Memories.insert_new_memory(user.id, form_data.content)
     memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
     memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
 
 
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
@@ -76,9 +74,8 @@ async def update_memory_by_id(
     request: Request,
     request: Request,
     form_data: MemoryUpdateModel,
     form_data: MemoryUpdateModel,
     user=Depends(get_verified_user),
     user=Depends(get_verified_user),
-    db=Depends(get_db),
 ):
 ):
-    memory = Memories.update_memory_by_id(db, memory_id, form_data.content)
+    memory = Memories.update_memory_by_id(memory_id, form_data.content)
     if memory is None:
     if memory is None:
         raise HTTPException(status_code=404, detail="Memory not found")
         raise HTTPException(status_code=404, detail="Memory not found")
 
 
@@ -129,12 +126,12 @@ async def query_memory(
 ############################
 ############################
 @router.get("/reset", response_model=bool)
 @router.get("/reset", response_model=bool)
 async def reset_memory_from_vector_db(
 async def reset_memory_from_vector_db(
-    request: Request, user=Depends(get_verified_user), db=Depends(get_db)
+    request: Request, user=Depends(get_verified_user)
 ):
 ):
     CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
     CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
     collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
 
 
-    memories = Memories.get_memories_by_user_id(db, user.id)
+    memories = Memories.get_memories_by_user_id(user.id)
     for memory in memories:
     for memory in memories:
         memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
         memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
         collection.upsert(
         collection.upsert(
@@ -151,8 +148,8 @@ async def reset_memory_from_vector_db(
 
 
 
 
 @router.delete("/user", response_model=bool)
 @router.delete("/user", response_model=bool)
-async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)):
-    result = Memories.delete_memories_by_user_id(db, user.id)
+async def delete_memory_by_user_id(user=Depends(get_verified_user)):
+    result = Memories.delete_memories_by_user_id(user.id)
 
 
     if result:
     if result:
         try:
         try:
@@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g
 
 
 @router.delete("/{memory_id}", response_model=bool)
 @router.delete("/{memory_id}", response_model=bool)
 async def delete_memory_by_id(
 async def delete_memory_by_id(
-    memory_id: str, user=Depends(get_verified_user), db=Depends(get_db)
+    memory_id: str, user=Depends(get_verified_user)
 ):
 ):
-    result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id)
+    result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
 
 
     if result:
     if result:
         collection = CHROMA_CLIENT.get_or_create_collection(
         collection = CHROMA_CLIENT.get_or_create_collection(

+ 10 - 13
backend/apps/webui/routers/models.py

@@ -6,7 +6,6 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
 from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
 
 
 from utils.utils import get_verified_user, get_admin_user
 from utils.utils import get_verified_user, get_admin_user
@@ -20,8 +19,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ModelResponse])
 @router.get("/", response_model=List[ModelResponse])
-async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
-    return Models.get_all_models(db)
+async def get_models(user=Depends(get_verified_user)):
+    return Models.get_all_models()
 
 
 
 
 ############################
 ############################
@@ -34,7 +33,6 @@ async def add_new_model(
     request: Request,
     request: Request,
     form_data: ModelForm,
     form_data: ModelForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
     if form_data.id in request.app.state.MODELS:
     if form_data.id in request.app.state.MODELS:
         raise HTTPException(
         raise HTTPException(
@@ -42,7 +40,7 @@ async def add_new_model(
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
             detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
         )
         )
     else:
     else:
-        model = Models.insert_new_model(db, form_data, user.id)
+        model = Models.insert_new_model(form_data, user.id)
 
 
         if model:
         if model:
             return model
             return model
@@ -59,8 +57,8 @@ async def add_new_model(
 
 
 
 
 @router.get("/{id}", response_model=Optional[ModelModel])
 @router.get("/{id}", response_model=Optional[ModelModel])
-async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
-    model = Models.get_model_by_id(db, id)
+async def get_model_by_id(id: str, user=Depends(get_verified_user)):
+    model = Models.get_model_by_id(id)
 
 
     if model:
     if model:
         return model
         return model
@@ -82,15 +80,14 @@ async def update_model_by_id(
     id: str,
     id: str,
     form_data: ModelForm,
     form_data: ModelForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    model = Models.get_model_by_id(db, id)
+    model = Models.get_model_by_id(id)
     if model:
     if model:
-        model = Models.update_model_by_id(db, id, form_data)
+        model = Models.update_model_by_id(id, form_data)
         return model
         return model
     else:
     else:
         if form_data.id in request.app.state.MODELS:
         if form_data.id in request.app.state.MODELS:
-            model = Models.insert_new_model(db, form_data, user.id)
+            model = Models.insert_new_model(form_data, user.id)
             if model:
             if model:
                 return model
                 return model
             else:
             else:
@@ -111,6 +108,6 @@ async def update_model_by_id(
 
 
 
 
 @router.delete("/delete", response_model=bool)
 @router.delete("/delete", response_model=bool)
-async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
-    result = Models.delete_model_by_id(db, id)
+async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
+    result = Models.delete_model_by_id(id)
     return result
     return result

+ 10 - 12
backend/apps/webui/routers/prompts.py

@@ -6,7 +6,6 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
 from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
 
 
 from utils.utils import get_current_user, get_admin_user
 from utils.utils import get_current_user, get_admin_user
@@ -20,8 +19,8 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[PromptModel])
 @router.get("/", response_model=List[PromptModel])
-async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
-    return Prompts.get_prompts(db)
+async def get_prompts(user=Depends(get_current_user)):
+    return Prompts.get_prompts()
 
 
 
 
 ############################
 ############################
@@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
 
 
 @router.post("/create", response_model=Optional[PromptModel])
 @router.post("/create", response_model=Optional[PromptModel])
 async def create_new_prompt(
 async def create_new_prompt(
-    form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db)
+    form_data: PromptForm, user=Depends(get_admin_user)
 ):
 ):
-    prompt = Prompts.get_prompt_by_command(db, form_data.command)
+    prompt = Prompts.get_prompt_by_command(form_data.command)
     if prompt == None:
     if prompt == None:
-        prompt = Prompts.insert_new_prompt(db, user.id, form_data)
+        prompt = Prompts.insert_new_prompt(user.id, form_data)
 
 
         if prompt:
         if prompt:
             return prompt
             return prompt
@@ -56,9 +55,9 @@ async def create_new_prompt(
 
 
 @router.get("/command/{command}", response_model=Optional[PromptModel])
 @router.get("/command/{command}", response_model=Optional[PromptModel])
 async def get_prompt_by_command(
 async def get_prompt_by_command(
-    command: str, user=Depends(get_current_user), db=Depends(get_db)
+    command: str, user=Depends(get_current_user)
 ):
 ):
-    prompt = Prompts.get_prompt_by_command(db, f"/{command}")
+    prompt = Prompts.get_prompt_by_command(f"/{command}")
 
 
     if prompt:
     if prompt:
         return prompt
         return prompt
@@ -79,9 +78,8 @@ async def update_prompt_by_command(
     command: str,
     command: str,
     form_data: PromptForm,
     form_data: PromptForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data)
+    prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
     if prompt:
     if prompt:
         return prompt
         return prompt
     else:
     else:
@@ -98,7 +96,7 @@ async def update_prompt_by_command(
 
 
 @router.delete("/command/{command}/delete", response_model=bool)
 @router.delete("/command/{command}/delete", response_model=bool)
 async def delete_prompt_by_command(
 async def delete_prompt_by_command(
-    command: str, user=Depends(get_admin_user), db=Depends(get_db)
+    command: str, user=Depends(get_admin_user)
 ):
 ):
-    result = Prompts.delete_prompt_by_command(db, f"/{command}")
+    result = Prompts.delete_prompt_by_command(f"/{command}")
     return result
     return result

+ 10 - 13
backend/apps/webui/routers/tools.py

@@ -6,7 +6,6 @@ from fastapi import APIRouter
 from pydantic import BaseModel
 from pydantic import BaseModel
 import json
 import json
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
 from apps.webui.utils import load_toolkit_module_by_id
 from apps.webui.utils import load_toolkit_module_by_id
@@ -34,7 +33,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[ToolResponse])
 @router.get("/", response_model=List[ToolResponse])
-async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
+async def get_toolkits(user=Depends(get_verified_user)):
     toolkits = [toolkit for toolkit in Tools.get_tools()]
     toolkits = [toolkit for toolkit in Tools.get_tools()]
     return toolkits
     return toolkits
 
 
@@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
 
 
 
 
 @router.get("/export", response_model=List[ToolModel])
 @router.get("/export", response_model=List[ToolModel])
-async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)):
-    toolkits = [toolkit for toolkit in Tools.get_tools(db)]
+async def get_toolkits(user=Depends(get_admin_user)):
+    toolkits = [toolkit for toolkit in Tools.get_tools()]
     return toolkits
     return toolkits
 
 
 
 
@@ -60,7 +59,6 @@ async def create_new_toolkit(
     request: Request,
     request: Request,
     form_data: ToolForm,
     form_data: ToolForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
     if not form_data.id.isidentifier():
     if not form_data.id.isidentifier():
         raise HTTPException(
         raise HTTPException(
@@ -70,7 +68,7 @@ async def create_new_toolkit(
 
 
     form_data.id = form_data.id.lower()
     form_data.id = form_data.id.lower()
 
 
-    toolkit = Tools.get_tool_by_id(db, form_data.id)
+    toolkit = Tools.get_tool_by_id(form_data.id)
     if toolkit == None:
     if toolkit == None:
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
         try:
         try:
@@ -84,7 +82,7 @@ async def create_new_toolkit(
             TOOLS[form_data.id] = toolkit_module
             TOOLS[form_data.id] = toolkit_module
 
 
             specs = get_tools_specs(TOOLS[form_data.id])
             specs = get_tools_specs(TOOLS[form_data.id])
-            toolkit = Tools.insert_new_tool(db, user.id, form_data, specs)
+            toolkit = Tools.insert_new_tool(user.id, form_data, specs)
 
 
             tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
             tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
             tool_cache_dir.mkdir(parents=True, exist_ok=True)
             tool_cache_dir.mkdir(parents=True, exist_ok=True)
@@ -115,8 +113,8 @@ async def create_new_toolkit(
 
 
 
 
 @router.get("/id/{id}", response_model=Optional[ToolModel])
 @router.get("/id/{id}", response_model=Optional[ToolModel])
-async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
-    toolkit = Tools.get_tool_by_id(db, id)
+async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
+    toolkit = Tools.get_tool_by_id(id)
 
 
     if toolkit:
     if toolkit:
         return toolkit
         return toolkit
@@ -138,7 +136,6 @@ async def update_toolkit_by_id(
     id: str,
     id: str,
     form_data: ToolForm,
     form_data: ToolForm,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
     toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
     toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
 
 
@@ -160,7 +157,7 @@ async def update_toolkit_by_id(
         }
         }
 
 
         print(updated)
         print(updated)
-        toolkit = Tools.update_tool_by_id(db, id, updated)
+        toolkit = Tools.update_tool_by_id(id, updated)
 
 
         if toolkit:
         if toolkit:
             return toolkit
             return toolkit
@@ -184,9 +181,9 @@ async def update_toolkit_by_id(
 
 
 @router.delete("/id/{id}/delete", response_model=bool)
 @router.delete("/id/{id}/delete", response_model=bool)
 async def delete_toolkit_by_id(
 async def delete_toolkit_by_id(
-    request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
+    request: Request, id: str, user=Depends(get_admin_user)
 ):
 ):
-    result = Tools.delete_tool_by_id(db, id)
+    result = Tools.delete_tool_by_id(id)
 
 
     if result:
     if result:
         TOOLS = request.app.state.TOOLS
         TOOLS = request.app.state.TOOLS

+ 23 - 26
backend/apps/webui/routers/users.py

@@ -9,7 +9,6 @@ import time
 import uuid
 import uuid
 import logging
 import logging
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.users import (
 from apps.webui.models.users import (
     UserModel,
     UserModel,
     UserUpdateForm,
     UserUpdateForm,
@@ -42,9 +41,9 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[UserModel])
 @router.get("/", response_model=List[UserModel])
 async def get_users(
 async def get_users(
-    skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
+    skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
 ):
 ):
-    return Users.get_users(db, skip, limit)
+    return Users.get_users(skip, limit)
 
 
 
 
 ############################
 ############################
@@ -72,11 +71,11 @@ async def update_user_permissions(
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
 @router.post("/update/role", response_model=Optional[UserModel])
 async def update_user_role(
 async def update_user_role(
-    form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
+    form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
 ):
 ):
 
 
-    if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
-        return Users.update_user_role_by_id(db, form_data.id, form_data.role)
+    if user.id != form_data.id and form_data.id != Users.get_first_user().id:
+        return Users.update_user_role_by_id(form_data.id, form_data.role)
 
 
     raise HTTPException(
     raise HTTPException(
         status_code=status.HTTP_403_FORBIDDEN,
         status_code=status.HTTP_403_FORBIDDEN,
@@ -91,9 +90,9 @@ async def update_user_role(
 
 
 @router.get("/user/settings", response_model=Optional[UserSettings])
 @router.get("/user/settings", response_model=Optional[UserSettings])
 async def get_user_settings_by_session_user(
 async def get_user_settings_by_session_user(
-    user=Depends(get_verified_user), db=Depends(get_db)
+    user=Depends(get_verified_user)
 ):
 ):
-    user = Users.get_user_by_id(db, user.id)
+    user = Users.get_user_by_id(user.id)
     if user:
     if user:
         return user.settings
         return user.settings
     else:
     else:
@@ -110,9 +109,9 @@ async def get_user_settings_by_session_user(
 
 
 @router.post("/user/settings/update", response_model=UserSettings)
 @router.post("/user/settings/update", response_model=UserSettings)
 async def update_user_settings_by_session_user(
 async def update_user_settings_by_session_user(
-    form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
+    form_data: UserSettings, user=Depends(get_verified_user)
 ):
 ):
-    user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
+    user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
     if user:
     if user:
         return user.settings
         return user.settings
     else:
     else:
@@ -129,9 +128,9 @@ async def update_user_settings_by_session_user(
 
 
 @router.get("/user/info", response_model=Optional[dict])
 @router.get("/user/info", response_model=Optional[dict])
 async def get_user_info_by_session_user(
 async def get_user_info_by_session_user(
-    user=Depends(get_verified_user), db=Depends(get_db)
+    user=Depends(get_verified_user)
 ):
 ):
-    user = Users.get_user_by_id(db, user.id)
+    user = Users.get_user_by_id(user.id)
     if user:
     if user:
         return user.info
         return user.info
     else:
     else:
@@ -148,15 +147,15 @@ async def get_user_info_by_session_user(
 
 
 @router.post("/user/info/update", response_model=Optional[dict])
 @router.post("/user/info/update", response_model=Optional[dict])
 async def update_user_info_by_session_user(
 async def update_user_info_by_session_user(
-    form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
+    form_data: dict, user=Depends(get_verified_user)
 ):
 ):
-    user = Users.get_user_by_id(db, user.id)
+    user = Users.get_user_by_id(user.id)
     if user:
     if user:
         if user.info is None:
         if user.info is None:
             user.info = {}
             user.info = {}
 
 
         user = Users.update_user_by_id(
         user = Users.update_user_by_id(
-            db, user.id, {"info": {**user.info, **form_data}}
+            user.id, {"info": {**user.info, **form_data}}
         )
         )
         if user:
         if user:
             return user.info
             return user.info
@@ -184,14 +183,14 @@ class UserResponse(BaseModel):
 
 
 @router.get("/{user_id}", response_model=UserResponse)
 @router.get("/{user_id}", response_model=UserResponse)
 async def get_user_by_id(
 async def get_user_by_id(
-    user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
+    user_id: str, user=Depends(get_verified_user)
 ):
 ):
 
 
     # Check if user_id is a shared chat
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
     # If it is, get the user_id from the chat
     if user_id.startswith("shared-"):
     if user_id.startswith("shared-"):
         chat_id = user_id.replace("shared-", "")
         chat_id = user_id.replace("shared-", "")
-        chat = Chats.get_chat_by_id(db, chat_id)
+        chat = Chats.get_chat_by_id(chat_id)
         if chat:
         if chat:
             user_id = chat.user_id
             user_id = chat.user_id
         else:
         else:
@@ -200,7 +199,7 @@ async def get_user_by_id(
                 detail=ERROR_MESSAGES.USER_NOT_FOUND,
                 detail=ERROR_MESSAGES.USER_NOT_FOUND,
             )
             )
 
 
-    user = Users.get_user_by_id(db, user_id)
+    user = Users.get_user_by_id(user_id)
 
 
     if user:
     if user:
         return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
         return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
@@ -221,13 +220,12 @@ async def update_user_by_id(
     user_id: str,
     user_id: str,
     form_data: UserUpdateForm,
     form_data: UserUpdateForm,
     session_user=Depends(get_admin_user),
     session_user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    user = Users.get_user_by_id(db, user_id)
+    user = Users.get_user_by_id(user_id)
 
 
     if user:
     if user:
         if form_data.email.lower() != user.email:
         if form_data.email.lower() != user.email:
-            email_user = Users.get_user_by_email(db, form_data.email.lower())
+            email_user = Users.get_user_by_email(form_data.email.lower())
             if email_user:
             if email_user:
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_400_BAD_REQUEST,
                     status_code=status.HTTP_400_BAD_REQUEST,
@@ -237,11 +235,10 @@ async def update_user_by_id(
         if form_data.password:
         if form_data.password:
             hashed = get_password_hash(form_data.password)
             hashed = get_password_hash(form_data.password)
             log.debug(f"hashed: {hashed}")
             log.debug(f"hashed: {hashed}")
-            Auths.update_user_password_by_id(db, user_id, hashed)
+            Auths.update_user_password_by_id(user_id, hashed)
 
 
-        Auths.update_email_by_id(db, user_id, form_data.email.lower())
+        Auths.update_email_by_id(user_id, form_data.email.lower())
         updated_user = Users.update_user_by_id(
         updated_user = Users.update_user_by_id(
-            db,
             user_id,
             user_id,
             {
             {
                 "name": form_data.name,
                 "name": form_data.name,
@@ -271,10 +268,10 @@ async def update_user_by_id(
 
 
 @router.delete("/{user_id}", response_model=bool)
 @router.delete("/{user_id}", response_model=bool)
 async def delete_user_by_id(
 async def delete_user_by_id(
-    user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
+    user_id: str, user=Depends(get_admin_user)
 ):
 ):
     if user.id != user_id:
     if user.id != user_id:
-        result = Auths.delete_auth_by_id(db, user_id)
+        result = Auths.delete_auth_by_id(user_id)
 
 
         if result:
         if result:
             return True
             return True

+ 13 - 18
backend/main.py

@@ -57,7 +57,7 @@ from apps.webui.main import (
     get_pipe_models,
     get_pipe_models,
     generate_function_chat_completion,
     generate_function_chat_completion,
 )
 )
-from apps.webui.internal.db import get_db, SessionLocal
+from apps.webui.internal.db import get_session, SessionLocal
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             user = get_current_user(
             user = get_current_user(
                 request,
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
                 get_http_authorization_cred(request.headers.get("Authorization")),
-                SessionLocal(),
             )
             )
             # Flag to skip RAG completions if file_handler is present in tools/functions
             # Flag to skip RAG completions if file_handler is present in tools/functions
             skip_files = False
             skip_files = False
@@ -800,9 +799,7 @@ app.add_middleware(
 @app.middleware("http")
 @app.middleware("http")
 async def check_url(request: Request, call_next):
 async def check_url(request: Request, call_next):
     if len(app.state.MODELS) == 0:
     if len(app.state.MODELS) == 0:
-        db = SessionLocal()
-        await get_all_models(db)
-        db.commit()
+        await get_all_models()
     else:
     else:
         pass
         pass
 
 
@@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app)
 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 
 
-async def get_all_models(db: Session):
+async def get_all_models():
     pipe_models = []
     pipe_models = []
     openai_models = []
     openai_models = []
     ollama_models = []
     ollama_models = []
 
 
-    pipe_models = await get_pipe_models(db)
+    pipe_models = await get_pipe_models()
 
 
     if app.state.config.ENABLE_OPENAI_API:
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
         openai_models = await get_openai_models()
@@ -863,7 +860,7 @@ async def get_all_models(db: Session):
 
 
     models = pipe_models + openai_models + ollama_models
     models = pipe_models + openai_models + ollama_models
 
 
-    custom_models = Models.get_all_models(db)
+    custom_models = Models.get_all_models()
     for custom_model in custom_models:
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
         if custom_model.base_model_id == None:
             for model in models:
             for model in models:
@@ -903,8 +900,8 @@ async def get_all_models(db: Session):
 
 
 
 
 @app.get("/api/models")
 @app.get("/api/models")
-async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
-    models = await get_all_models(db)
+async def get_models(user=Depends(get_verified_user)):
+    models = await get_all_models()
 
 
     # Filter out filter pipelines
     # Filter out filter pipelines
     models = [
     models = [
@@ -1608,9 +1605,8 @@ async def get_pipeline_valves(
     urlIdx: Optional[int],
     urlIdx: Optional[int],
     pipeline_id: str,
     pipeline_id: str,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    models = await get_all_models(db)
+    models = await get_all_models()
     r = None
     r = None
     try:
     try:
 
 
@@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec(
     urlIdx: Optional[int],
     urlIdx: Optional[int],
     pipeline_id: str,
     pipeline_id: str,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    models = await get_all_models(db)
+    models = await get_all_models()
 
 
     r = None
     r = None
     try:
     try:
@@ -1690,9 +1685,8 @@ async def update_pipeline_valves(
     pipeline_id: str,
     pipeline_id: str,
     form_data: dict,
     form_data: dict,
     user=Depends(get_admin_user),
     user=Depends(get_admin_user),
-    db=Depends(get_db),
 ):
 ):
-    models = await get_all_models(db)
+    models = await get_all_models()
 
 
     r = None
     r = None
     try:
     try:
@@ -2040,8 +2034,9 @@ async def healthcheck():
 
 
 
 
 @app.get("/health/db")
 @app.get("/health/db")
-async def healthcheck_with_db(db: Session = Depends(get_db)):
-    result = db.execute(text("SELECT 1;")).all()
+async def healthcheck_with_db():
+    with get_session() as db:
+        result = db.execute(text("SELECT 1;")).all()
     return {"status": True}
     return {"status": True}
 
 
 
 

+ 0 - 188
backend/migrations/versions/22b5ab2667b8_init.py

@@ -1,188 +0,0 @@
-"""init
-
-Revision ID: 22b5ab2667b8
-Revises: 
-Create Date: 2024-06-20 13:22:40.397002
-
-"""
-
-from typing import Sequence, Union
-
-from alembic import op
-import sqlalchemy as sa
-from sqlalchemy.engine.reflection import Inspector
-
-import apps.webui.internal.db
-
-
-# revision identifiers, used by Alembic.
-revision: str = "22b5ab2667b8"
-down_revision: Union[str, None] = None
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
-
-
-def upgrade() -> None:
-    con = op.get_bind()
-    inspector = Inspector.from_engine(con)
-    tables = set(inspector.get_table_names())
-
-    # ### commands auto generated by Alembic - please adjust! ###
-    if not "auth" in tables:
-        op.create_table(
-            "auth",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("email", sa.String(), nullable=True),
-            sa.Column("password", sa.String(), nullable=True),
-            sa.Column("active", sa.Boolean(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-        )
-
-    if not "chat" in tables:
-        op.create_table(
-            "chat",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("title", sa.String(), nullable=True),
-            sa.Column("chat", sa.String(), nullable=True),
-            sa.Column("created_at", sa.BigInteger(), nullable=True),
-            sa.Column("updated_at", sa.BigInteger(), nullable=True),
-            sa.Column("share_id", sa.String(), nullable=True),
-            sa.Column("archived", sa.Boolean(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-            sa.UniqueConstraint("share_id"),
-        )
-
-    if not "chatidtag" in tables:
-        op.create_table(
-            "chatidtag",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("tag_name", sa.String(), nullable=True),
-            sa.Column("chat_id", sa.String(), nullable=True),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("timestamp", sa.BigInteger(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-        )
-
-    if not "document" in tables:
-        op.create_table(
-            "document",
-            sa.Column("collection_name", sa.String(), nullable=False),
-            sa.Column("name", sa.String(), nullable=True),
-            sa.Column("title", sa.String(), nullable=True),
-            sa.Column("filename", sa.String(), nullable=True),
-            sa.Column("content", sa.String(), nullable=True),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("timestamp", sa.BigInteger(), nullable=True),
-            sa.PrimaryKeyConstraint("collection_name"),
-            sa.UniqueConstraint("name"),
-        )
-
-    if not "memory" in tables:
-        op.create_table(
-            "memory",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("content", sa.String(), nullable=True),
-            sa.Column("updated_at", sa.BigInteger(), nullable=True),
-            sa.Column("created_at", sa.BigInteger(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-        )
-
-    if not "model" in tables:
-        op.create_table(
-            "model",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("base_model_id", sa.String(), nullable=True),
-            sa.Column("name", sa.String(), nullable=True),
-            sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
-            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
-            sa.Column("updated_at", sa.BigInteger(), nullable=True),
-            sa.Column("created_at", sa.BigInteger(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-        )
-
-    if not "prompt" in tables:
-        op.create_table(
-            "prompt",
-            sa.Column("command", sa.String(), nullable=False),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("title", sa.String(), nullable=True),
-            sa.Column("content", sa.String(), nullable=True),
-            sa.Column("timestamp", sa.BigInteger(), nullable=True),
-            sa.PrimaryKeyConstraint("command"),
-        )
-
-    if not "tag" in tables:
-        op.create_table(
-            "tag",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("name", sa.String(), nullable=True),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("data", sa.String(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-        )
-
-    if not "tool" in tables:
-        op.create_table(
-            "tool",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("user_id", sa.String(), nullable=True),
-            sa.Column("name", sa.String(), nullable=True),
-            sa.Column("content", sa.String(), nullable=True),
-            sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
-            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
-            sa.Column("updated_at", sa.BigInteger(), nullable=True),
-            sa.Column("created_at", sa.BigInteger(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-        )
-
-    if not "user" in tables:
-        op.create_table(
-            "user",
-            sa.Column("id", sa.String(), nullable=False),
-            sa.Column("name", sa.String(), nullable=True),
-            sa.Column("email", sa.String(), nullable=True),
-            sa.Column("role", sa.String(), nullable=True),
-            sa.Column("profile_image_url", sa.String(), nullable=True),
-            sa.Column("last_active_at", sa.BigInteger(), nullable=True),
-            sa.Column("updated_at", sa.BigInteger(), nullable=True),
-            sa.Column("created_at", sa.BigInteger(), nullable=True),
-            sa.Column("api_key", sa.String(), nullable=True),
-            sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
-            sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
-            sa.PrimaryKeyConstraint("id"),
-            sa.UniqueConstraint("api_key"),
-        )
-
-    if not "file" in tables:
-        op.create_table('file',
-                        sa.Column('id', sa.String(), nullable=False),
-                        sa.Column('user_id', sa.String(), nullable=True),
-                        sa.Column('filename', sa.String(), nullable=True),
-                        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-                        sa.Column('created_at', sa.BigInteger(), nullable=True),
-                        sa.PrimaryKeyConstraint('id')
-                        )
-
-    if not "function" in tables:
-        op.create_table('function',
-                        sa.Column('id', sa.String(), nullable=False),
-                        sa.Column('user_id', sa.String(), nullable=True),
-                        sa.Column('name', sa.Text(), nullable=True),
-                        sa.Column('type', sa.Text(), nullable=True),
-                        sa.Column('content', sa.Text(), nullable=True),
-                        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-                        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-                        sa.Column('created_at', sa.BigInteger(), nullable=True),
-                        sa.PrimaryKeyConstraint('id')
-                        )
-    # ### end Alembic commands ###
-
-
-def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
-    # do nothing as we assume we had previous migrations from peewee-migrate
-    pass
-    # ### end Alembic commands ###

+ 161 - 0
backend/migrations/versions/ba76b0bae648_init.py

@@ -0,0 +1,161 @@
+"""init
+
+Revision ID: ba76b0bae648
+Revises: 
+Create Date: 2024-06-24 09:09:11.636336
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import apps.webui.internal.db
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'ba76b0bae648'
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('auth',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('email', sa.String(), nullable=True),
+    sa.Column('password', sa.String(), nullable=True),
+    sa.Column('active', sa.Boolean(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('chat',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('title', sa.String(), nullable=True),
+    sa.Column('chat', sa.String(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.Column('updated_at', sa.BigInteger(), nullable=True),
+    sa.Column('share_id', sa.String(), nullable=True),
+    sa.Column('archived', sa.Boolean(), nullable=True),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('share_id')
+    )
+    op.create_table('chatidtag',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('tag_name', sa.String(), nullable=True),
+    sa.Column('chat_id', sa.String(), nullable=True),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('timestamp', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('document',
+    sa.Column('collection_name', sa.String(), nullable=False),
+    sa.Column('name', sa.String(), nullable=True),
+    sa.Column('title', sa.String(), nullable=True),
+    sa.Column('filename', sa.String(), nullable=True),
+    sa.Column('content', sa.String(), nullable=True),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('timestamp', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('collection_name'),
+    sa.UniqueConstraint('name')
+    )
+    op.create_table('file',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('filename', sa.String(), nullable=True),
+    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('function',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('name', sa.Text(), nullable=True),
+    sa.Column('type', sa.Text(), nullable=True),
+    sa.Column('content', sa.Text(), nullable=True),
+    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('is_active', sa.Boolean(), nullable=True),
+    sa.Column('updated_at', sa.BigInteger(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('memory',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('content', sa.String(), nullable=True),
+    sa.Column('updated_at', sa.BigInteger(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('model',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('base_model_id', sa.String(), nullable=True),
+    sa.Column('name', sa.String(), nullable=True),
+    sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('updated_at', sa.BigInteger(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('prompt',
+    sa.Column('command', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('title', sa.String(), nullable=True),
+    sa.Column('content', sa.String(), nullable=True),
+    sa.Column('timestamp', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('command')
+    )
+    op.create_table('tag',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('name', sa.String(), nullable=True),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('data', sa.String(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('tool',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('user_id', sa.String(), nullable=True),
+    sa.Column('name', sa.String(), nullable=True),
+    sa.Column('content', sa.String(), nullable=True),
+    sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('updated_at', sa.BigInteger(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_table('user',
+    sa.Column('id', sa.String(), nullable=False),
+    sa.Column('name', sa.String(), nullable=True),
+    sa.Column('email', sa.String(), nullable=True),
+    sa.Column('role', sa.String(), nullable=True),
+    sa.Column('profile_image_url', sa.String(), nullable=True),
+    sa.Column('last_active_at', sa.BigInteger(), nullable=True),
+    sa.Column('updated_at', sa.BigInteger(), nullable=True),
+    sa.Column('created_at', sa.BigInteger(), nullable=True),
+    sa.Column('api_key', sa.String(), nullable=True),
+    sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('api_key')
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('user')
+    op.drop_table('tool')
+    op.drop_table('tag')
+    op.drop_table('prompt')
+    op.drop_table('model')
+    op.drop_table('memory')
+    op.drop_table('function')
+    op.drop_table('file')
+    op.drop_table('document')
+    op.drop_table('chatidtag')
+    op.drop_table('chat')
+    op.drop_table('auth')
+    # ### end Alembic commands ###

+ 6 - 13
backend/test/apps/webui/routers/test_auths.py

@@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest):
         from utils.utils import get_password_hash
         from utils.utils import get_password_hash
 
 
         user = self.auths.insert_new_auth(
         user = self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password=get_password_hash("old_password"),
             password=get_password_hash("old_password"),
             name="John Doe",
             name="John Doe",
@@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest):
                 json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
                 json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
             )
             )
         assert response.status_code == 200
         assert response.status_code == 200
-        db_user = self.users.get_user_by_id(self.db_session, user.id)
+        db_user = self.users.get_user_by_id(user.id)
         assert db_user.name == "John Doe 2"
         assert db_user.name == "John Doe 2"
         assert db_user.profile_image_url == "/user2.png"
         assert db_user.profile_image_url == "/user2.png"
 
 
@@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest):
         from utils.utils import get_password_hash
         from utils.utils import get_password_hash
 
 
         user = self.auths.insert_new_auth(
         user = self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password=get_password_hash("old_password"),
             password=get_password_hash("old_password"),
             name="John Doe",
             name="John Doe",
@@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest):
         assert response.status_code == 200
         assert response.status_code == 200
 
 
         old_auth = self.auths.authenticate_user(
         old_auth = self.auths.authenticate_user(
-            self.db_session, "john.doe@openwebui.com", "old_password"
+            "john.doe@openwebui.com", "old_password"
         )
         )
         assert old_auth is None
         assert old_auth is None
         new_auth = self.auths.authenticate_user(
         new_auth = self.auths.authenticate_user(
-            self.db_session, "john.doe@openwebui.com", "new_password"
+            "john.doe@openwebui.com", "new_password"
         )
         )
         assert new_auth is not None
         assert new_auth is not None
 
 
@@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest):
         from utils.utils import get_password_hash
         from utils.utils import get_password_hash
 
 
         user = self.auths.insert_new_auth(
         user = self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password=get_password_hash("password"),
             password=get_password_hash("password"),
             name="John Doe",
             name="John Doe",
@@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest):
 
 
     def test_get_admin_details(self):
     def test_get_admin_details(self):
         self.auths.insert_new_auth(
         self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password="password",
             password="password",
             name="John Doe",
             name="John Doe",
@@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest):
 
 
     def test_create_api_key_(self):
     def test_create_api_key_(self):
         user = self.auths.insert_new_auth(
         user = self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password="password",
             password="password",
             name="John Doe",
             name="John Doe",
@@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest):
 
 
     def test_delete_api_key(self):
     def test_delete_api_key(self):
         user = self.auths.insert_new_auth(
         user = self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password="password",
             password="password",
             name="John Doe",
             name="John Doe",
             profile_image_url="/user.png",
             profile_image_url="/user.png",
             role="admin",
             role="admin",
         )
         )
-        self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
+        self.users.update_user_api_key_by_id(user.id, "abc")
         with mock_webui_user(id=user.id):
         with mock_webui_user(id=user.id):
             response = self.fast_api_client.delete(self.create_url("/api_key"))
             response = self.fast_api_client.delete(self.create_url("/api_key"))
         assert response.status_code == 200
         assert response.status_code == 200
         assert response.json() == True
         assert response.json() == True
-        db_user = self.users.get_user_by_id(self.db_session, user.id)
+        db_user = self.users.get_user_by_id(user.id)
         assert db_user.api_key is None
         assert db_user.api_key is None
 
 
     def test_get_api_key(self):
     def test_get_api_key(self):
         user = self.auths.insert_new_auth(
         user = self.auths.insert_new_auth(
-            self.db_session,
             email="john.doe@openwebui.com",
             email="john.doe@openwebui.com",
             password="password",
             password="password",
             name="John Doe",
             name="John Doe",
             profile_image_url="/user.png",
             profile_image_url="/user.png",
             role="admin",
             role="admin",
         )
         )
-        self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
+        self.users.update_user_api_key_by_id(user.id, "abc")
         with mock_webui_user(id=user.id):
         with mock_webui_user(id=user.id):
             response = self.fast_api_client.get(self.create_url("/api_key"))
             response = self.fast_api_client.get(self.create_url("/api_key"))
         assert response.status_code == 200
         assert response.status_code == 200

+ 17 - 21
backend/test/apps/webui/routers/test_chats.py

@@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest):
 
 
         self.chats = Chats
         self.chats = Chats
         self.chats.insert_new_chat(
         self.chats.insert_new_chat(
-            self.db_session,
             "2",
             "2",
             ChatForm(
             ChatForm(
                 **{
                 **{
@@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest):
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.delete(self.create_url("/"))
             response = self.fast_api_client.delete(self.create_url("/"))
         assert response.status_code == 200
         assert response.status_code == 200
-        assert len(self.chats.get_chats(self.db_session)) == 0
+        assert len(self.chats.get_chats()) == 0
 
 
     def test_get_user_chat_list_by_user_id(self):
     def test_get_user_chat_list_by_user_id(self):
         with mock_webui_user(id="3"):
         with mock_webui_user(id="3"):
@@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest):
         assert data["title"] == "New Chat"
         assert data["title"] == "New Chat"
         assert data["updated_at"] is not None
         assert data["updated_at"] is not None
         assert data["created_at"] is not None
         assert data["created_at"] is not None
-        assert len(self.chats.get_chats(self.db_session)) == 2
+        assert len(self.chats.get_chats()) == 2
 
 
     def test_get_user_chats(self):
     def test_get_user_chats(self):
         self.test_get_session_user_chat_list()
         self.test_get_session_user_chat_list()
 
 
     def test_get_user_archived_chats(self):
     def test_get_user_archived_chats(self):
-        self.chats.archive_all_chats_by_user_id(self.db_session, "2")
-        self.db_session.commit()
+        self.chats.archive_all_chats_by_user_id("2")
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url("/all/archived"))
             response = self.fast_api_client.get(self.create_url("/all/archived"))
         assert response.status_code == 200
         assert response.status_code == 200
@@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest):
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.post(self.create_url("/archive/all"))
             response = self.fast_api_client.post(self.create_url("/archive/all"))
         assert response.status_code == 200
         assert response.status_code == 200
-        assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1
+        assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
 
 
     def test_get_shared_chat_by_id(self):
     def test_get_shared_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
-        self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id)
-        self.db_session.commit()
+        chat_id = self.chats.get_chats()[0].id
+        self.chats.update_chat_share_id_by_id(chat_id, chat_id)
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
             response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
         assert response.status_code == 200
         assert response.status_code == 200
@@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest):
         assert data["title"] == "New Chat"
         assert data["title"] == "New Chat"
 
 
     def test_get_chat_by_id(self):
     def test_get_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
             response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
         assert response.status_code == 200
         assert response.status_code == 200
@@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest):
         assert data["user_id"] == "2"
         assert data["user_id"] == "2"
 
 
     def test_update_chat_by_id(self):
     def test_update_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.post(
             response = self.fast_api_client.post(
                 self.create_url(f"/{chat_id}"),
                 self.create_url(f"/{chat_id}"),
@@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest):
         assert data["user_id"] == "2"
         assert data["user_id"] == "2"
 
 
     def test_delete_chat_by_id(self):
     def test_delete_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
             response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
         assert response.status_code == 200
         assert response.status_code == 200
         assert response.json() is True
         assert response.json() is True
 
 
     def test_clone_chat_by_id(self):
     def test_clone_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
             response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
 
 
@@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest):
         assert data["user_id"] == "2"
         assert data["user_id"] == "2"
 
 
     def test_archive_chat_by_id(self):
     def test_archive_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
             response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
         assert response.status_code == 200
         assert response.status_code == 200
 
 
-        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        chat = self.chats.get_chat_by_id(chat_id)
         assert chat.archived is True
         assert chat.archived is True
 
 
     def test_share_chat_by_id(self):
     def test_share_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
             response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
         assert response.status_code == 200
         assert response.status_code == 200
 
 
-        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        chat = self.chats.get_chat_by_id(chat_id)
         assert chat.share_id is not None
         assert chat.share_id is not None
 
 
     def test_delete_shared_chat_by_id(self):
     def test_delete_shared_chat_by_id(self):
-        chat_id = self.chats.get_chats(self.db_session)[0].id
+        chat_id = self.chats.get_chats()[0].id
         share_id = str(uuid.uuid4())
         share_id = str(uuid.uuid4())
-        self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id)
-        self.db_session.commit()
+        self.chats.update_chat_share_id_by_id(chat_id, share_id)
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
             response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
         assert response.status_code
         assert response.status_code
 
 
-        chat = self.chats.get_chat_by_id(self.db_session, chat_id)
+        chat = self.chats.get_chat_by_id(chat_id)
         assert chat.share_id is None
         assert chat.share_id is None

+ 5 - 5
backend/test/apps/webui/routers/test_documents.py

@@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest):
 
 
     def test_documents(self):
     def test_documents(self):
         # Empty database
         # Empty database
-        assert len(self.documents.get_docs(self.db_session)) == 0
+        assert len(self.documents.get_docs()) == 0
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url("/"))
             response = self.fast_api_client.get(self.create_url("/"))
         assert response.status_code == 200
         assert response.status_code == 200
@@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest):
             )
             )
         assert response.status_code == 200
         assert response.status_code == 200
         assert response.json()["name"] == "doc_name"
         assert response.json()["name"] == "doc_name"
-        assert len(self.documents.get_docs(self.db_session)) == 1
+        assert len(self.documents.get_docs()) == 1
 
 
         # Get the document
         # Get the document
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
@@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest):
             )
             )
         assert response.status_code == 200
         assert response.status_code == 200
         assert response.json()["name"] == "doc_name 2"
         assert response.json()["name"] == "doc_name 2"
-        assert len(self.documents.get_docs(self.db_session)) == 2
+        assert len(self.documents.get_docs()) == 2
 
 
         # Get all documents
         # Get all documents
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
@@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest):
         assert data["content"] == {
         assert data["content"] == {
             "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
             "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
         }
         }
-        assert len(self.documents.get_docs(self.db_session)) == 2
+        assert len(self.documents.get_docs()) == 2
 
 
         # Delete the first document
         # Delete the first document
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
@@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest):
                 self.create_url("/doc/delete?name=doc_name rework")
                 self.create_url("/doc/delete?name=doc_name rework")
             )
             )
         assert response.status_code == 200
         assert response.status_code == 200
-        assert len(self.documents.get_docs(self.db_session)) == 1
+        assert len(self.documents.get_docs()) == 1

+ 10 - 0
backend/test/apps/webui/routers/test_prompts.py

@@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest):
         assert data["content"] == "description Updated"
         assert data["content"] == "description Updated"
         assert data["user_id"] == "3"
         assert data["user_id"] == "3"
 
 
+        # Get prompt by command
+        with mock_webui_user(id="2"):
+            response = self.fast_api_client.get(self.create_url("/command/my-command2"))
+        assert response.status_code == 200
+        data = response.json()
+        assert data["command"] == "/my-command2"
+        assert data["title"] == "Hello World Updated"
+        assert data["content"] == "description Updated"
+        assert data["user_id"] == "3"
+
         # Delete prompt
         # Delete prompt
         with mock_webui_user(id="2"):
         with mock_webui_user(id="2"):
             response = self.fast_api_client.delete(
             response = self.fast_api_client.delete(

+ 0 - 2
backend/test/apps/webui/routers/test_users.py

@@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest):
     def setup_method(self):
     def setup_method(self):
         super().setup_method()
         super().setup_method()
         self.users.insert_new_user(
         self.users.insert_new_user(
-            self.db_session,
             id="1",
             id="1",
             name="user 1",
             name="user 1",
             email="user1@openwebui.com",
             email="user1@openwebui.com",
@@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest):
             role="user",
             role="user",
         )
         )
         self.users.insert_new_user(
         self.users.insert_new_user(
-            self.db_session,
             id="2",
             id="2",
             name="user 2",
             name="user 2",
             email="user2@openwebui.com",
             email="user2@openwebui.com",

+ 3 - 5
backend/utils/utils.py

@@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi import HTTPException, status, Depends, Request
 from fastapi import HTTPException, status, Depends, Request
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import Session
 
 
-from apps.webui.internal.db import get_db
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
@@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str):
 def get_current_user(
 def get_current_user(
     request: Request,
     request: Request,
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
     auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
-    db=Depends(get_db),
 ):
 ):
     token = None
     token = None
 
 
@@ -94,19 +92,19 @@ def get_current_user(
 
 
     # auth by api key
     # auth by api key
     if token.startswith("sk-"):
     if token.startswith("sk-"):
-        return get_current_user_by_api_key(db, token)
+        return get_current_user_by_api_key(token)
 
 
     # auth by jwt token
     # auth by jwt token
     data = decode_token(token)
     data = decode_token(token)
     if data != None and "id" in data:
     if data != None and "id" in data:
-        user = Users.get_user_by_id(db, data["id"])
+        user = Users.get_user_by_id(data["id"])
         if user is None:
         if user is None:
             raise HTTPException(
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
                 detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
             )
         else:
         else:
-            Users.update_user_last_active_by_id(db, user.id)
+            Users.update_user_last_active_by_id(user.id)
         return user
         return user
     else:
     else:
         raise HTTPException(
         raise HTTPException(