Explorar o código

feat(sqlalchemy): format backend

Jonathan Rohde hai 10 meses
pai
achega
c134eab27a

+ 3 - 2
backend/apps/webui/internal/db.py

@@ -53,7 +53,9 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
     )
     )
 else:
 else:
     engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
     engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
+SessionLocal = sessionmaker(
+    autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
+)
 Base = declarative_base()
 Base = declarative_base()
 
 
 
 
@@ -66,4 +68,3 @@ def get_session():
     except Exception as e:
     except Exception as e:
         db.rollback()
         db.rollback()
         raise e
         raise e
-

+ 7 - 13
backend/apps/webui/models/auths.py

@@ -126,9 +126,7 @@ class AuthsTable:
             else:
             else:
                 return None
                 return None
 
 
-    def authenticate_user(
-        self, email: str, password: str
-    ) -> Optional[UserModel]:
+    def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
         log.info(f"authenticate_user: {email}")
         log.info(f"authenticate_user: {email}")
         with get_session() as db:
         with get_session() as db:
             try:
             try:
@@ -144,9 +142,7 @@ class AuthsTable:
             except:
             except:
                 return None
                 return None
 
 
-    def authenticate_user_by_api_key(
-        self, api_key: str
-    ) -> Optional[UserModel]:
+    def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_api_key: {api_key}")
         log.info(f"authenticate_user_by_api_key: {api_key}")
         with get_session() as db:
         with get_session() as db:
             # if no api_key, return None
             # if no api_key, return None
@@ -159,9 +155,7 @@ class AuthsTable:
             except:
             except:
                 return False
                 return False
 
 
-    def authenticate_user_by_trusted_header(
-        self, email: str
-    ) -> Optional[UserModel]:
+    def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
         log.info(f"authenticate_user_by_trusted_header: {email}")
         log.info(f"authenticate_user_by_trusted_header: {email}")
         with get_session() as db:
         with get_session() as db:
             try:
             try:
@@ -172,12 +166,12 @@ class AuthsTable:
             except:
             except:
                 return None
                 return None
 
 
-    def update_user_password_by_id(
-        self, id: str, new_password: str
-    ) -> bool:
+    def update_user_password_by_id(self, id: str, new_password: str) -> bool:
         with get_session() as db:
         with get_session() as db:
             try:
             try:
-                result = db.query(Auth).filter_by(id=id).update({"password": new_password})
+                result = (
+                    db.query(Auth).filter_by(id=id).update({"password": new_password})
+                )
                 return True if result == 1 else False
                 return True if result == 1 else False
             except:
             except:
                 return False
                 return False

+ 15 - 21
backend/apps/webui/models/chats.py

@@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel):
 
 
 class ChatTable:
 class ChatTable:
 
 
-    def insert_new_chat(
-        self, user_id: str, form_data: ChatForm
-    ) -> Optional[ChatModel]:
+    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
         with get_session() as db:
         with get_session() as db:
             id = str(uuid.uuid4())
             id = str(uuid.uuid4())
             chat = ChatModel(
             chat = ChatModel(
@@ -89,7 +87,9 @@ class ChatTable:
                     "id": id,
                     "id": id,
                     "user_id": user_id,
                     "user_id": user_id,
                     "title": (
                     "title": (
-                        form_data.chat["title"] if "title" in form_data.chat else "New Chat"
+                        form_data.chat["title"]
+                        if "title" in form_data.chat
+                        else "New Chat"
                     ),
                     ),
                     "chat": json.dumps(form_data.chat),
                     "chat": json.dumps(form_data.chat),
                     "created_at": int(time.time()),
                     "created_at": int(time.time()),
@@ -103,9 +103,7 @@ class ChatTable:
             db.refresh(result)
             db.refresh(result)
             return ChatModel.model_validate(result) if result else None
             return ChatModel.model_validate(result) if result else None
 
 
-    def update_chat_by_id(
-        self, id: str, chat: dict
-    ) -> Optional[ChatModel]:
+    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
         with get_session() as db:
         with get_session() as db:
             try:
             try:
                 chat_obj = db.get(Chat, id)
                 chat_obj = db.get(Chat, id)
@@ -119,9 +117,7 @@ class ChatTable:
             except Exception as e:
             except Exception as e:
                 return None
                 return None
 
 
-    def insert_shared_chat_by_chat_id(
-        self, chat_id: str
-    ) -> Optional[ChatModel]:
+    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_session() as db:
         with get_session() as db:
             # Get the existing chat to share
             # Get the existing chat to share
             chat = db.get(Chat, chat_id)
             chat = db.get(Chat, chat_id)
@@ -145,14 +141,14 @@ class ChatTable:
             db.refresh(shared_result)
             db.refresh(shared_result)
             # Update the original chat with the share_id
             # Update the original chat with the share_id
             result = (
             result = (
-                db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
+                db.query(Chat)
+                .filter_by(id=chat_id)
+                .update({"share_id": shared_chat.id})
             )
             )
 
 
             return shared_chat if (shared_result and result) else None
             return shared_chat if (shared_result and result) else None
 
 
-    def update_shared_chat_by_chat_id(
-        self, chat_id: str
-    ) -> Optional[ChatModel]:
+    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_session() as db:
         with get_session() as db:
             try:
             try:
                 print("update_shared_chat_by_id")
                 print("update_shared_chat_by_id")
@@ -271,9 +267,7 @@ class ChatTable:
         except Exception as e:
         except Exception as e:
             return None
             return None
 
 
-    def get_chat_by_id_and_user_id(
-        self, id: str, user_id: str
-    ) -> Optional[ChatModel]:
+    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
         try:
             with get_session() as db:
             with get_session() as db:
                 chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                 chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
@@ -293,13 +287,13 @@ class ChatTable:
     def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
     def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
         with get_session() as db:
         with get_session() as db:
             all_chats = (
             all_chats = (
-                db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
+                db.query(Chat)
+                .filter_by(user_id=user_id)
+                .order_by(Chat.updated_at.desc())
             )
             )
             return [ChatModel.model_validate(chat) for chat in all_chats]
             return [ChatModel.model_validate(chat) for chat in all_chats]
 
 
-    def get_archived_chats_by_user_id(
-        self, user_id: str
-    ) -> List[ChatModel]:
+    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
         with get_session() as db:
         with get_session() as db:
             all_chats = (
             all_chats = (
                 db.query(Chat)
                 db.query(Chat)

+ 3 - 1
backend/apps/webui/models/documents.py

@@ -106,7 +106,9 @@ class DocumentsTable:
 
 
     def get_docs(self) -> List[DocumentModel]:
     def get_docs(self) -> List[DocumentModel]:
         with get_session() as db:
         with get_session() as db:
-            return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
+            return [
+                DocumentModel.model_validate(doc) for doc in db.query(Document).all()
+            ]
 
 
     def update_doc_by_name(
     def update_doc_by_name(
         self, name: str, form_data: DocumentUpdateForm
         self, name: str, form_data: DocumentUpdateForm

+ 1 - 0
backend/apps/webui/models/files.py

@@ -39,6 +39,7 @@ class FileModel(BaseModel):
 
 
     model_config = ConfigDict(from_attributes=True)
     model_config = ConfigDict(from_attributes=True)
 
 
+
 ####################
 ####################
 # Forms
 # Forms
 ####################
 ####################

+ 15 - 11
backend/apps/webui/models/functions.py

@@ -142,9 +142,9 @@ class FunctionsTable:
             with get_session() as db:
             with get_session() as db:
                 return [
                 return [
                     FunctionModel.model_validate(function)
                     FunctionModel.model_validate(function)
-                    for function in db.query(Function).filter_by(
-                        type=type, is_active=True
-                    ).all()
+                    for function in db.query(Function)
+                    .filter_by(type=type, is_active=True)
+                    .all()
                 ]
                 ]
         else:
         else:
             with get_session() as db:
             with get_session() as db:
@@ -220,10 +220,12 @@ class FunctionsTable:
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         try:
         try:
             with get_session() as db:
             with get_session() as db:
-                db.query(Function).filter_by(id=id).update({
-                    **updated,
-                    "updated_at": int(time.time()),
-                })
+                db.query(Function).filter_by(id=id).update(
+                    {
+                        **updated,
+                        "updated_at": int(time.time()),
+                    }
+                )
                 db.commit()
                 db.commit()
                 return self.get_function_by_id(id)
                 return self.get_function_by_id(id)
         except:
         except:
@@ -232,10 +234,12 @@ class FunctionsTable:
     def deactivate_all_functions(self) -> Optional[bool]:
     def deactivate_all_functions(self) -> Optional[bool]:
         try:
         try:
             with get_session() as db:
             with get_session() as db:
-                db.query(Function).update({
-                    "is_active": False,
-                    "updated_at": int(time.time()),
-                })
+                db.query(Function).update(
+                    {
+                        "is_active": False,
+                        "updated_at": int(time.time()),
+                    }
+                )
                 db.commit()
                 db.commit()
             return True
             return True
         except:
         except:

+ 1 - 3
backend/apps/webui/models/models.py

@@ -153,9 +153,7 @@ class ModelsTable:
         except:
         except:
             return None
             return None
 
 
-    def update_model_by_id(
-        self, id: str, model: ModelForm
-    ) -> Optional[ModelModel]:
+    def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
         try:
         try:
             # update only the fields that are present in the model
             # update only the fields that are present in the model
             with get_session() as db:
             with get_session() as db:

+ 3 - 1
backend/apps/webui/models/prompts.py

@@ -83,7 +83,9 @@ class PromptsTable:
 
 
     def get_prompts(self) -> List[PromptModel]:
     def get_prompts(self) -> List[PromptModel]:
         with get_session() as db:
         with get_session() as db:
-            return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
+            return [
+                PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
+            ]
 
 
     def update_prompt_by_command(
     def update_prompt_by_command(
         self, command: str, form_data: PromptForm
         self, command: str, form_data: PromptForm

+ 8 - 10
backend/apps/webui/models/tags.py

@@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel):
 
 
 class TagTable:
 class TagTable:
 
 
-    def insert_new_tag(
-        self, name: str, user_id: str
-    ) -> Optional[TagModel]:
+    def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
         id = str(uuid.uuid4())
         id = str(uuid.uuid4())
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
         try:
         try:
@@ -201,11 +199,13 @@ class TagTable:
         self, tag_name: str, user_id: str
         self, tag_name: str, user_id: str
     ) -> int:
     ) -> int:
         with get_session() as db:
         with get_session() as db:
-            return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
+            return (
+                db.query(ChatIdTag)
+                .filter_by(tag_name=tag_name, user_id=user_id)
+                .count()
+            )
 
 
-    def delete_tag_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
-    ) -> bool:
+    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
         try:
         try:
             with get_session() as db:
             with get_session() as db:
                 res = (
                 res = (
@@ -252,9 +252,7 @@ class TagTable:
             log.error(f"delete_tag: {e}")
             log.error(f"delete_tag: {e}")
             return False
             return False
 
 
-    def delete_tags_by_chat_id_and_user_id(
-        self, chat_id: str, user_id: str
-    ) -> bool:
+    def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
         tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
         tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
 
 
         for tag in tags:
         for tag in tags:

+ 6 - 10
backend/apps/webui/models/users.py

@@ -165,9 +165,7 @@ class UsersTable:
             except:
             except:
                 return None
                 return None
 
 
-    def update_user_role_by_id(
-        self, id: str, role: str
-    ) -> Optional[UserModel]:
+    def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
         with get_session() as db:
         with get_session() as db:
             try:
             try:
                 db.query(User).filter_by(id=id).update({"role": role})
                 db.query(User).filter_by(id=id).update({"role": role})
@@ -193,12 +191,12 @@ class UsersTable:
             except:
             except:
                 return None
                 return None
 
 
-    def update_user_last_active_by_id(
-        self, id: str
-    ) -> Optional[UserModel]:
+    def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
         with get_session() as db:
         with get_session() as db:
             try:
             try:
-                db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
+                db.query(User).filter_by(id=id).update(
+                    {"last_active_at": int(time.time())}
+                )
 
 
                 user = db.query(User).filter_by(id=id).first()
                 user = db.query(User).filter_by(id=id).first()
                 return UserModel.model_validate(user)
                 return UserModel.model_validate(user)
@@ -217,9 +215,7 @@ class UsersTable:
             except:
             except:
                 return None
                 return None
 
 
-    def update_user_by_id(
-        self, id: str, updated: dict
-    ) -> Optional[UserModel]:
+    def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
         with get_session() as db:
         with get_session() as db:
             try:
             try:
                 db.query(User).filter_by(id=id).update(updated)
                 db.query(User).filter_by(id=id).update(updated)

+ 4 - 10
backend/apps/webui/routers/auths.py

@@ -78,8 +78,7 @@ async def get_session_user(
 
 
 @router.post("/update/profile", response_model=UserResponse)
 @router.post("/update/profile", response_model=UserResponse)
 async def update_profile(
 async def update_profile(
-    form_data: UpdateProfileForm,
-    session_user=Depends(get_current_user)
+    form_data: UpdateProfileForm, session_user=Depends(get_current_user)
 ):
 ):
     if session_user:
     if session_user:
         user = Users.update_user_by_id(
         user = Users.update_user_by_id(
@@ -101,8 +100,7 @@ async def update_profile(
 
 
 @router.post("/update/password", response_model=bool)
 @router.post("/update/password", response_model=bool)
 async def update_password(
 async def update_password(
-    form_data: UpdatePasswordForm,
-    session_user=Depends(get_current_user)
+    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
 ):
 ):
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
     if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
         raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
@@ -269,9 +267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 
 
 
 @router.post("/add", response_model=SigninResponse)
 @router.post("/add", response_model=SigninResponse)
-async def add_user(
-    form_data: AddUserForm, user=Depends(get_admin_user)
-):
+async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
 
 
     if not validate_email_format(form_data.email.lower()):
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
         raise HTTPException(
@@ -316,9 +312,7 @@ async def add_user(
 
 
 
 
 @router.get("/admin/details")
 @router.get("/admin/details")
-async def get_admin_details(
-    request: Request, user=Depends(get_current_user)
-):
+async def get_admin_details(request: Request, user=Depends(get_current_user)):
     if request.app.state.config.SHOW_ADMIN_DETAILS:
     if request.app.state.config.SHOW_ADMIN_DETAILS:
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_email = request.app.state.config.ADMIN_EMAIL
         admin_name = None
         admin_name = None

+ 10 - 30
backend/apps/webui/routers/chats.py

@@ -55,9 +55,7 @@ async def get_session_user_chat_list(
 
 
 
 
 @router.delete("/", response_model=bool)
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(
-    request: Request, user=Depends(get_current_user)
-):
+async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
 
 
     if (
     if (
         user.role == "user"
         user.role == "user"
@@ -95,9 +93,7 @@ async def get_user_chat_list_by_user_id(
 
 
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(
-    form_data: ChatForm, user=Depends(get_current_user)
-):
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
     try:
     try:
         chat = Chats.insert_new_chat(user.id, form_data)
         chat = Chats.insert_new_chat(user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -180,9 +176,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
 
 
 
 
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(
-    share_id: str, user=Depends(get_current_user)
-):
+async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
     if user.role == "pending":
     if user.role == "pending":
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -225,9 +219,7 @@ async def get_user_chat_list_by_tag_name(
         )
         )
     ]
     ]
 
 
-    chats = Chats.get_chat_list_by_chat_ids(
-        chat_ids, form_data.skip, form_data.limit
-    )
+    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
 
 
     if len(chats) == 0:
     if len(chats) == 0:
         Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
         Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
@@ -297,9 +289,7 @@ async def update_chat_by_id(
 
 
 
 
 @router.delete("/{id}", response_model=bool)
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(
-    request: Request, id: str, user=Depends(get_current_user)
-):
+async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
 
 
     if user.role == "admin":
     if user.role == "admin":
         result = Chats.delete_chat_by_id(id)
         result = Chats.delete_chat_by_id(id)
@@ -347,9 +337,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 
 
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
-async def archive_chat_by_id(
-    id: str, user=Depends(get_current_user)
-):
+async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
         chat = Chats.toggle_chat_archive_by_id(id)
         chat = Chats.toggle_chat_archive_by_id(id)
@@ -398,9 +386,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 
 
 @router.delete("/{id}/share", response_model=Optional[bool])
 @router.delete("/{id}/share", response_model=Optional[bool])
-async def delete_shared_chat_by_id(
-    id: str, user=Depends(get_current_user)
-):
+async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
         if not chat.share_id:
         if not chat.share_id:
@@ -423,9 +409,7 @@ async def delete_shared_chat_by_id(
 
 
 
 
 @router.get("/{id}/tags", response_model=List[TagModel])
 @router.get("/{id}/tags", response_model=List[TagModel])
-async def get_chat_tags_by_id(
-    id: str, user=Depends(get_current_user)
-):
+async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
 
     if tags != None:
     if tags != None:
@@ -443,9 +427,7 @@ async def get_chat_tags_by_id(
 
 
 @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 async def add_chat_tag_by_id(
 async def add_chat_tag_by_id(
-    id: str,
-    form_data: ChatIdTagForm,
-    user=Depends(get_current_user)
+    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
 ):
 ):
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
 
@@ -494,9 +476,7 @@ async def delete_chat_tag_by_id(
 
 
 
 
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
-async def delete_all_chat_tags_by_id(
-    id: str, user=Depends(get_current_user)
-):
+async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
     result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
     result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
 
 
     if result:
     if result:

+ 5 - 15
backend/apps/webui/routers/documents.py

@@ -44,9 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
 
 
 
 
 @router.post("/create", response_model=Optional[DocumentResponse])
 @router.post("/create", response_model=Optional[DocumentResponse])
-async def create_new_doc(
-    form_data: DocumentForm, user=Depends(get_admin_user)
-):
+async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
     doc = Documents.get_doc_by_name(form_data.name)
     doc = Documents.get_doc_by_name(form_data.name)
     if doc == None:
     if doc == None:
         doc = Documents.insert_new_doc(user.id, form_data)
         doc = Documents.insert_new_doc(user.id, form_data)
@@ -76,9 +74,7 @@ async def create_new_doc(
 
 
 
 
 @router.get("/doc", response_model=Optional[DocumentResponse])
 @router.get("/doc", response_model=Optional[DocumentResponse])
-async def get_doc_by_name(
-    name: str, user=Depends(get_current_user)
-):
+async def get_doc_by_name(name: str, user=Depends(get_current_user)):
     doc = Documents.get_doc_by_name(name)
     doc = Documents.get_doc_by_name(name)
 
 
     if doc:
     if doc:
@@ -110,12 +106,8 @@ class TagDocumentForm(BaseModel):
 
 
 
 
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
-async def tag_doc_by_name(
-    form_data: TagDocumentForm, user=Depends(get_current_user)
-):
-    doc = Documents.update_doc_content_by_name(
-        form_data.name, {"tags": form_data.tags}
-    )
+async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
+    doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
 
 
     if doc:
     if doc:
         return DocumentResponse(
         return DocumentResponse(
@@ -163,8 +155,6 @@ async def update_doc_by_name(
 
 
 
 
 @router.delete("/doc/delete", response_model=bool)
 @router.delete("/doc/delete", response_model=bool)
-async def delete_doc_by_name(
-    name: str, user=Depends(get_admin_user)
-):
+async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
     result = Documents.delete_doc_by_name(name)
     result = Documents.delete_doc_by_name(name)
     return result
     return result

+ 1 - 4
backend/apps/webui/routers/files.py

@@ -50,10 +50,7 @@ router = APIRouter()
 
 
 
 
 @router.post("/")
 @router.post("/")
-def upload_file(
-    file: UploadFile = File(...),
-    user=Depends(get_verified_user)
-):
+def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
     log.info(f"file.content_type: {file.content_type}")
     log.info(f"file.content_type: {file.content_type}")
     try:
     try:
         unsanitized_filename = file.filename
         unsanitized_filename = file.filename

+ 1 - 3
backend/apps/webui/routers/memories.py

@@ -167,9 +167,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
 
 
 
 
 @router.delete("/{memory_id}", response_model=bool)
 @router.delete("/{memory_id}", response_model=bool)
-async def delete_memory_by_id(
-    memory_id: str, user=Depends(get_verified_user)
-):
+async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
     result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
     result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
 
 
     if result:
     if result:

+ 3 - 9
backend/apps/webui/routers/prompts.py

@@ -29,9 +29,7 @@ async def get_prompts(user=Depends(get_current_user)):
 
 
 
 
 @router.post("/create", response_model=Optional[PromptModel])
 @router.post("/create", response_model=Optional[PromptModel])
-async def create_new_prompt(
-    form_data: PromptForm, user=Depends(get_admin_user)
-):
+async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
     prompt = Prompts.get_prompt_by_command(form_data.command)
     prompt = Prompts.get_prompt_by_command(form_data.command)
     if prompt == None:
     if prompt == None:
         prompt = Prompts.insert_new_prompt(user.id, form_data)
         prompt = Prompts.insert_new_prompt(user.id, form_data)
@@ -54,9 +52,7 @@ async def create_new_prompt(
 
 
 
 
 @router.get("/command/{command}", response_model=Optional[PromptModel])
 @router.get("/command/{command}", response_model=Optional[PromptModel])
-async def get_prompt_by_command(
-    command: str, user=Depends(get_current_user)
-):
+async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
     prompt = Prompts.get_prompt_by_command(f"/{command}")
     prompt = Prompts.get_prompt_by_command(f"/{command}")
 
 
     if prompt:
     if prompt:
@@ -95,8 +91,6 @@ async def update_prompt_by_command(
 
 
 
 
 @router.delete("/command/{command}/delete", response_model=bool)
 @router.delete("/command/{command}/delete", response_model=bool)
-async def delete_prompt_by_command(
-    command: str, user=Depends(get_admin_user)
-):
+async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
     result = Prompts.delete_prompt_by_command(f"/{command}")
     result = Prompts.delete_prompt_by_command(f"/{command}")
     return result
     return result

+ 1 - 3
backend/apps/webui/routers/tools.py

@@ -180,9 +180,7 @@ async def update_toolkit_by_id(
 
 
 
 
 @router.delete("/id/{id}/delete", response_model=bool)
 @router.delete("/id/{id}/delete", response_model=bool)
-async def delete_toolkit_by_id(
-    request: Request, id: str, user=Depends(get_admin_user)
-):
+async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
     result = Tools.delete_tool_by_id(id)
     result = Tools.delete_tool_by_id(id)
 
 
     if result:
     if result:

+ 7 - 21
backend/apps/webui/routers/users.py

@@ -40,9 +40,7 @@ router = APIRouter()
 
 
 
 
 @router.get("/", response_model=List[UserModel])
 @router.get("/", response_model=List[UserModel])
-async def get_users(
-    skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
-):
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
     return Users.get_users(skip, limit)
     return Users.get_users(skip, limit)
 
 
 
 
@@ -70,9 +68,7 @@ async def update_user_permissions(
 
 
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(
-    form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
-):
+async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
 
 
     if user.id != form_data.id and form_data.id != Users.get_first_user().id:
     if user.id != form_data.id and form_data.id != Users.get_first_user().id:
         return Users.update_user_role_by_id(form_data.id, form_data.role)
         return Users.update_user_role_by_id(form_data.id, form_data.role)
@@ -89,9 +85,7 @@ async def update_user_role(
 
 
 
 
 @router.get("/user/settings", response_model=Optional[UserSettings])
 @router.get("/user/settings", response_model=Optional[UserSettings])
-async def get_user_settings_by_session_user(
-    user=Depends(get_verified_user)
-):
+async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
     user = Users.get_user_by_id(user.id)
     user = Users.get_user_by_id(user.id)
     if user:
     if user:
         return user.settings
         return user.settings
@@ -127,9 +121,7 @@ async def update_user_settings_by_session_user(
 
 
 
 
 @router.get("/user/info", response_model=Optional[dict])
 @router.get("/user/info", response_model=Optional[dict])
-async def get_user_info_by_session_user(
-    user=Depends(get_verified_user)
-):
+async def get_user_info_by_session_user(user=Depends(get_verified_user)):
     user = Users.get_user_by_id(user.id)
     user = Users.get_user_by_id(user.id)
     if user:
     if user:
         return user.info
         return user.info
@@ -154,9 +146,7 @@ async def update_user_info_by_session_user(
         if user.info is None:
         if user.info is None:
             user.info = {}
             user.info = {}
 
 
-        user = Users.update_user_by_id(
-            user.id, {"info": {**user.info, **form_data}}
-        )
+        user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
         if user:
         if user:
             return user.info
             return user.info
         else:
         else:
@@ -182,9 +172,7 @@ class UserResponse(BaseModel):
 
 
 
 
 @router.get("/{user_id}", response_model=UserResponse)
 @router.get("/{user_id}", response_model=UserResponse)
-async def get_user_by_id(
-    user_id: str, user=Depends(get_verified_user)
-):
+async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
 
 
     # Check if user_id is a shared chat
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
     # If it is, get the user_id from the chat
@@ -267,9 +255,7 @@ async def update_user_by_id(
 
 
 
 
 @router.delete("/{user_id}", response_model=bool)
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(
-    user_id: str, user=Depends(get_admin_user)
-):
+async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
     if user.id != user_id:
     if user.id != user_id:
         result = Auths.delete_auth_by_id(user_id)
         result = Auths.delete_auth_by_id(user_id)
 
 

+ 3 - 1
backend/main.py

@@ -175,7 +175,9 @@ https://github.com/open-webui/open-webui
 def run_migrations():
 def run_migrations():
     env = os.environ.copy()
     env = os.environ.copy()
     env["DATABASE_URL"] = DATABASE_URL
     env["DATABASE_URL"] = DATABASE_URL
-    migration_task = subprocess.run(["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env)
+    migration_task = subprocess.run(
+        ["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env
+    )
     if migration_task.returncode > 0:
     if migration_task.returncode > 0:
         raise ValueError("Error running migrations")
         raise ValueError("Error running migrations")
 
 

+ 134 - 121
backend/migrations/versions/ba76b0bae648_init.py

@@ -5,6 +5,7 @@ Revises:
 Create Date: 2024-06-24 09:09:11.636336
 Create Date: 2024-06-24 09:09:11.636336
 
 
 """
 """
+
 from typing import Sequence, Union
 from typing import Sequence, Union
 
 
 from alembic import op
 from alembic import op
@@ -13,7 +14,7 @@ import apps.webui.internal.db
 
 
 
 
 # revision identifiers, used by Alembic.
 # revision identifiers, used by Alembic.
-revision: str = 'ba76b0bae648'
+revision: str = "ba76b0bae648"
 down_revision: Union[str, None] = None
 down_revision: Union[str, None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -21,141 +22,153 @@ depends_on: Union[str, Sequence[str], None] = None
 
 
 def upgrade() -> None:
 def upgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
     # ### commands auto generated by Alembic - please adjust! ###
-    op.create_table('auth',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('email', sa.String(), nullable=True),
-    sa.Column('password', sa.String(), nullable=True),
-    sa.Column('active', sa.Boolean(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "auth",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("email", sa.String(), nullable=True),
+        sa.Column("password", sa.String(), nullable=True),
+        sa.Column("active", sa.Boolean(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('chat',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('title', sa.String(), nullable=True),
-    sa.Column('chat', sa.String(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.Column('updated_at', sa.BigInteger(), nullable=True),
-    sa.Column('share_id', sa.String(), nullable=True),
-    sa.Column('archived', sa.Boolean(), nullable=True),
-    sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('share_id')
+    op.create_table(
+        "chat",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("title", sa.String(), nullable=True),
+        sa.Column("chat", sa.String(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.Column("share_id", sa.String(), nullable=True),
+        sa.Column("archived", sa.Boolean(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint("share_id"),
     )
     )
-    op.create_table('chatidtag',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('tag_name', sa.String(), nullable=True),
-    sa.Column('chat_id', sa.String(), nullable=True),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('timestamp', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "chatidtag",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("tag_name", sa.String(), nullable=True),
+        sa.Column("chat_id", sa.String(), nullable=True),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("timestamp", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('document',
-    sa.Column('collection_name', sa.String(), nullable=False),
-    sa.Column('name', sa.String(), nullable=True),
-    sa.Column('title', sa.String(), nullable=True),
-    sa.Column('filename', sa.String(), nullable=True),
-    sa.Column('content', sa.String(), nullable=True),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('timestamp', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('collection_name'),
-    sa.UniqueConstraint('name')
+    op.create_table(
+        "document",
+        sa.Column("collection_name", sa.String(), nullable=False),
+        sa.Column("name", sa.String(), nullable=True),
+        sa.Column("title", sa.String(), nullable=True),
+        sa.Column("filename", sa.String(), nullable=True),
+        sa.Column("content", sa.String(), nullable=True),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("timestamp", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("collection_name"),
+        sa.UniqueConstraint("name"),
     )
     )
-    op.create_table('file',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('filename', sa.String(), nullable=True),
-    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "file",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("filename", sa.String(), nullable=True),
+        sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('function',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('name', sa.Text(), nullable=True),
-    sa.Column('type', sa.Text(), nullable=True),
-    sa.Column('content', sa.Text(), nullable=True),
-    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('is_active', sa.Boolean(), nullable=True),
-    sa.Column('updated_at', sa.BigInteger(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "function",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("name", sa.Text(), nullable=True),
+        sa.Column("type", sa.Text(), nullable=True),
+        sa.Column("content", sa.Text(), nullable=True),
+        sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("is_active", sa.Boolean(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('memory',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('content', sa.String(), nullable=True),
-    sa.Column('updated_at', sa.BigInteger(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "memory",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("content", sa.String(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('model',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('base_model_id', sa.String(), nullable=True),
-    sa.Column('name', sa.String(), nullable=True),
-    sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('updated_at', sa.BigInteger(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "model",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("base_model_id", sa.String(), nullable=True),
+        sa.Column("name", sa.String(), nullable=True),
+        sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('prompt',
-    sa.Column('command', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('title', sa.String(), nullable=True),
-    sa.Column('content', sa.String(), nullable=True),
-    sa.Column('timestamp', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('command')
+    op.create_table(
+        "prompt",
+        sa.Column("command", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("title", sa.String(), nullable=True),
+        sa.Column("content", sa.String(), nullable=True),
+        sa.Column("timestamp", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("command"),
     )
     )
-    op.create_table('tag',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('name', sa.String(), nullable=True),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('data', sa.String(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "tag",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("name", sa.String(), nullable=True),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("data", sa.String(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('tool',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('user_id', sa.String(), nullable=True),
-    sa.Column('name', sa.String(), nullable=True),
-    sa.Column('content', sa.String(), nullable=True),
-    sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('updated_at', sa.BigInteger(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "tool",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.String(), nullable=True),
+        sa.Column("name", sa.String(), nullable=True),
+        sa.Column("content", sa.String(), nullable=True),
+        sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
     )
     )
-    op.create_table('user',
-    sa.Column('id', sa.String(), nullable=False),
-    sa.Column('name', sa.String(), nullable=True),
-    sa.Column('email', sa.String(), nullable=True),
-    sa.Column('role', sa.String(), nullable=True),
-    sa.Column('profile_image_url', sa.String(), nullable=True),
-    sa.Column('last_active_at', sa.BigInteger(), nullable=True),
-    sa.Column('updated_at', sa.BigInteger(), nullable=True),
-    sa.Column('created_at', sa.BigInteger(), nullable=True),
-    sa.Column('api_key', sa.String(), nullable=True),
-    sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
-    sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('api_key')
+    op.create_table(
+        "user",
+        sa.Column("id", sa.String(), nullable=False),
+        sa.Column("name", sa.String(), nullable=True),
+        sa.Column("email", sa.String(), nullable=True),
+        sa.Column("role", sa.String(), nullable=True),
+        sa.Column("profile_image_url", sa.String(), nullable=True),
+        sa.Column("last_active_at", sa.BigInteger(), nullable=True),
+        sa.Column("updated_at", sa.BigInteger(), nullable=True),
+        sa.Column("created_at", sa.BigInteger(), nullable=True),
+        sa.Column("api_key", sa.String(), nullable=True),
+        sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint("api_key"),
     )
     )
     # ### end Alembic commands ###
     # ### end Alembic commands ###
 
 
 
 
 def downgrade() -> None:
 def downgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
     # ### commands auto generated by Alembic - please adjust! ###
-    op.drop_table('user')
-    op.drop_table('tool')
-    op.drop_table('tag')
-    op.drop_table('prompt')
-    op.drop_table('model')
-    op.drop_table('memory')
-    op.drop_table('function')
-    op.drop_table('file')
-    op.drop_table('document')
-    op.drop_table('chatidtag')
-    op.drop_table('chat')
-    op.drop_table('auth')
+    op.drop_table("user")
+    op.drop_table("tool")
+    op.drop_table("tag")
+    op.drop_table("prompt")
+    op.drop_table("model")
+    op.drop_table("memory")
+    op.drop_table("function")
+    op.drop_table("file")
+    op.drop_table("document")
+    op.drop_table("chatidtag")
+    op.drop_table("chat")
+    op.drop_table("auth")
     # ### end Alembic commands ###
     # ### end Alembic commands ###

+ 1 - 0
backend/test/util/abstract_integration_test.py

@@ -91,6 +91,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
             while retries > 0:
             while retries > 0:
                 try:
                 try:
                     from config import BACKEND_DIR
                     from config import BACKEND_DIR
+
                     db = create_engine(database_url, pool_pre_ping=True)
                     db = create_engine(database_url, pool_pre_ping=True)
                     db = db.connect()
                     db = db.connect()
                     log.info("postgres is ready!")
                     log.info("postgres is ready!")