Просмотр исходного кода

feat(sqlalchemy): format backend

Jonathan Rohde 10 месяцев назад
Родитель
Сommit
d88bd51e3c

+ 9 - 11
backend/apps/webui/models/chats.py

@@ -85,9 +85,7 @@ class ChatTable:
                 "id": id,
                 "user_id": user_id,
                 "title": (
-                    form_data.chat["title"]
-                    if "title" in form_data.chat
-                    else "New Chat"
+                    form_data.chat["title"] if "title" in form_data.chat else "New Chat"
                 ),
                 "chat": json.dumps(form_data.chat),
                 "created_at": int(time.time()),
@@ -197,14 +195,14 @@ class ChatTable:
     def get_archived_chat_list_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
-            all_chats = (
-                Session.query(Chat)
-                .filter_by(user_id=user_id, archived=True)
-                .order_by(Chat.updated_at.desc())
-                # .limit(limit).offset(skip)
-                .all()
-            )
-            return [ChatModel.model_validate(chat) for chat in all_chats]
+        all_chats = (
+            Session.query(Chat)
+            .filter_by(user_id=user_id, archived=True)
+            .order_by(Chat.updated_at.desc())
+            # .limit(limit).offset(skip)
+            .all()
+        )
+        return [ChatModel.model_validate(chat) for chat in all_chats]
 
     def get_chat_list_by_user_id(
         self,

+ 1 - 3
backend/apps/webui/models/memories.py

@@ -115,9 +115,7 @@ class MemoriesTable:
         except:
             return False
 
-    def delete_memory_by_id_and_user_id(
-        self, id: str, user_id: str
-    ) -> bool:
+    def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
             Session.query(Memory).filter_by(id=id, user_id=user_id).delete()
             return True

+ 3 - 1
backend/apps/webui/models/models.py

@@ -140,7 +140,9 @@ class ModelsTable:
             return None
 
     def get_all_models(self) -> List[ModelModel]:
-        return [ModelModel.model_validate(model) for model in Session.query(Model).all()]
+        return [
+            ModelModel.model_validate(model) for model in Session.query(Model).all()
+        ]
 
     def get_model_by_id(self, id: str) -> Optional[ModelModel]:
         try:

+ 2 - 6
backend/apps/webui/models/tags.py

@@ -207,9 +207,7 @@ class TagTable:
             log.debug(f"res: {res}")
             Session.commit()
 
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                tag_name, user_id
-            )
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
             if tag_count == 0:
                 # Remove tag item from Tag col as well
                 Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
@@ -230,9 +228,7 @@ class TagTable:
             log.debug(f"res: {res}")
             Session.commit()
 
-            tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                tag_name, user_id
-            )
+            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
             if tag_count == 0:
                 # Remove tag item from Tag col as well
                 Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()

+ 1 - 0
backend/main.py

@@ -793,6 +793,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 @app.middleware("http")
 async def commit_session_after_request(request: Request, call_next):
     response = await call_next(request)

+ 134 - 121
backend/migrations/versions/7e5b5dc7342b_init.py

@@ -5,6 +5,7 @@ Revises:
 Create Date: 2024-06-24 13:15:33.808998
 
 """
+
 from typing import Sequence, Union
 
 from alembic import op
@@ -13,7 +14,7 @@ import apps.webui.internal.db
 from migrations.util import get_existing_tables
 
 # revision identifiers, used by Alembic.
-revision: str = '7e5b5dc7342b'
+revision: str = "7e5b5dc7342b"
 down_revision: Union[str, None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -24,163 +25,175 @@ def upgrade() -> None:
 
     # ### commands auto generated by Alembic - please adjust! ###
     if "auth" not in existing_tables:
-        op.create_table('auth',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('email', sa.String(), nullable=True),
-        sa.Column('password', sa.Text(), nullable=True),
-        sa.Column('active', sa.Boolean(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "auth",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("email", sa.String(), nullable=True),
+            sa.Column("password", sa.Text(), nullable=True),
+            sa.Column("active", sa.Boolean(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "chat" not in existing_tables:
-        op.create_table('chat',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('title', sa.Text(), nullable=True),
-        sa.Column('chat', sa.Text(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-        sa.Column('share_id', sa.Text(), nullable=True),
-        sa.Column('archived', sa.Boolean(), nullable=True),
-        sa.PrimaryKeyConstraint('id'),
-        sa.UniqueConstraint('share_id')
+        op.create_table(
+            "chat",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("title", sa.Text(), nullable=True),
+            sa.Column("chat", sa.Text(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("share_id", sa.Text(), nullable=True),
+            sa.Column("archived", sa.Boolean(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+            sa.UniqueConstraint("share_id"),
         )
 
     if "chatidtag" not in existing_tables:
-        op.create_table('chatidtag',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('tag_name', sa.String(), nullable=True),
-        sa.Column('chat_id', sa.String(), nullable=True),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('timestamp', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "chatidtag",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("tag_name", sa.String(), nullable=True),
+            sa.Column("chat_id", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "document" not in existing_tables:
-        op.create_table('document',
-        sa.Column('collection_name', sa.String(), nullable=False),
-        sa.Column('name', sa.String(), nullable=True),
-        sa.Column('title', sa.Text(), nullable=True),
-        sa.Column('filename', sa.Text(), nullable=True),
-        sa.Column('content', sa.Text(), nullable=True),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('timestamp', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('collection_name'),
-        sa.UniqueConstraint('name')
+        op.create_table(
+            "document",
+            sa.Column("collection_name", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("title", sa.Text(), nullable=True),
+            sa.Column("filename", sa.Text(), nullable=True),
+            sa.Column("content", sa.Text(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("collection_name"),
+            sa.UniqueConstraint("name"),
         )
 
     if "file" not in existing_tables:
-        op.create_table('file',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('filename', sa.Text(), nullable=True),
-        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "file",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("filename", sa.Text(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "function" not in existing_tables:
-        op.create_table('function',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('name', sa.Text(), nullable=True),
-        sa.Column('type', sa.Text(), nullable=True),
-        sa.Column('content', sa.Text(), nullable=True),
-        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('is_active', sa.Boolean(), nullable=True),
-        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "function",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("name", sa.Text(), nullable=True),
+            sa.Column("type", sa.Text(), nullable=True),
+            sa.Column("content", sa.Text(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("is_active", sa.Boolean(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "memory" not in existing_tables:
-        op.create_table('memory',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('content', sa.Text(), nullable=True),
-        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "memory",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("content", sa.Text(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "model" not in existing_tables:
-        op.create_table('model',
-        sa.Column('id', sa.Text(), nullable=False),
-        sa.Column('user_id', sa.Text(), nullable=True),
-        sa.Column('base_model_id', sa.Text(), nullable=True),
-        sa.Column('name', sa.Text(), nullable=True),
-        sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "model",
+            sa.Column("id", sa.Text(), nullable=False),
+            sa.Column("user_id", sa.Text(), nullable=True),
+            sa.Column("base_model_id", sa.Text(), nullable=True),
+            sa.Column("name", sa.Text(), nullable=True),
+            sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "prompt" not in existing_tables:
-        op.create_table('prompt',
-        sa.Column('command', sa.String(), nullable=False),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('title', sa.Text(), nullable=True),
-        sa.Column('content', sa.Text(), nullable=True),
-        sa.Column('timestamp', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('command')
+        op.create_table(
+            "prompt",
+            sa.Column("command", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("title", sa.Text(), nullable=True),
+            sa.Column("content", sa.Text(), nullable=True),
+            sa.Column("timestamp", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("command"),
         )
 
     if "tag" not in existing_tables:
-        op.create_table('tag',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('name', sa.String(), nullable=True),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('data', sa.Text(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "tag",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("data", sa.Text(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "tool" not in existing_tables:
-        op.create_table('tool',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('user_id', sa.String(), nullable=True),
-        sa.Column('name', sa.Text(), nullable=True),
-        sa.Column('content', sa.Text(), nullable=True),
-        sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.PrimaryKeyConstraint('id')
+        op.create_table(
+            "tool",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("user_id", sa.String(), nullable=True),
+            sa.Column("name", sa.Text(), nullable=True),
+            sa.Column("content", sa.Text(), nullable=True),
+            sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
         )
 
     if "user" not in existing_tables:
-        op.create_table('user',
-        sa.Column('id', sa.String(), nullable=False),
-        sa.Column('name', sa.String(), nullable=True),
-        sa.Column('email', sa.String(), nullable=True),
-        sa.Column('role', sa.String(), nullable=True),
-        sa.Column('profile_image_url', sa.Text(), nullable=True),
-        sa.Column('last_active_at', sa.BigInteger(), nullable=True),
-        sa.Column('updated_at', sa.BigInteger(), nullable=True),
-        sa.Column('created_at', sa.BigInteger(), nullable=True),
-        sa.Column('api_key', sa.String(), nullable=True),
-        sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
-        sa.PrimaryKeyConstraint('id'),
-        sa.UniqueConstraint('api_key')
+        op.create_table(
+            "user",
+            sa.Column("id", sa.String(), nullable=False),
+            sa.Column("name", sa.String(), nullable=True),
+            sa.Column("email", sa.String(), nullable=True),
+            sa.Column("role", sa.String(), nullable=True),
+            sa.Column("profile_image_url", sa.Text(), nullable=True),
+            sa.Column("last_active_at", sa.BigInteger(), nullable=True),
+            sa.Column("updated_at", sa.BigInteger(), nullable=True),
+            sa.Column("created_at", sa.BigInteger(), nullable=True),
+            sa.Column("api_key", sa.String(), nullable=True),
+            sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
+            sa.PrimaryKeyConstraint("id"),
+            sa.UniqueConstraint("api_key"),
         )
     # ### end Alembic commands ###
 
 
 def downgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
-    op.drop_table('user')
-    op.drop_table('tool')
-    op.drop_table('tag')
-    op.drop_table('prompt')
-    op.drop_table('model')
-    op.drop_table('memory')
-    op.drop_table('function')
-    op.drop_table('file')
-    op.drop_table('document')
-    op.drop_table('chatidtag')
-    op.drop_table('chat')
-    op.drop_table('auth')
+    op.drop_table("user")
+    op.drop_table("tool")
+    op.drop_table("tag")
+    op.drop_table("prompt")
+    op.drop_table("model")
+    op.drop_table("memory")
+    op.drop_table("function")
+    op.drop_table("file")
+    op.drop_table("document")
+    op.drop_table("chatidtag")
+    op.drop_table("chat")
+    op.drop_table("auth")
     # ### end Alembic commands ###

+ 1 - 0
backend/test/apps/webui/routers/test_chats.py

@@ -91,6 +91,7 @@ class TestChats(AbstractPostgresTest):
     def test_get_user_archived_chats(self):
         self.chats.archive_all_chats_by_user_id("2")
         from apps.webui.internal.db import Session
+
         Session.commit()
         with mock_webui_user(id="2"):
             response = self.fast_api_client.get(self.create_url("/all/archived"))

+ 2 - 0
backend/test/util/abstract_integration_test.py

@@ -110,6 +110,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
 
     def _check_db_connection(self):
         from apps.webui.internal.db import Session
+
         retries = 10
         while retries > 0:
             try:
@@ -133,6 +134,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
 
     def teardown_method(self):
         from apps.webui.internal.db import Session
+
         # rollback everything not yet committed
         Session.commit()