Ver Fonte

fix: multi-user tags issue

Timothy J. Baek há 6 meses atrás
pai
commit
8ae605ec4b

+ 13 - 5
backend/open_webui/apps/webui/models/tags.py

@@ -8,7 +8,7 @@ from open_webui.apps.webui.internal.db import Base, get_db
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, JSON
+from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -19,11 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
 ####################
 ####################
 class Tag(Base):
 class Tag(Base):
     __tablename__ = "tag"
     __tablename__ = "tag"
-    id = Column(String, primary_key=True)
+    id = Column(String)
     name = Column(String)
     name = Column(String)
     user_id = Column(String)
     user_id = Column(String)
     meta = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)
 
 
+    # Unique constraint ensuring (id, user_id) is unique, not just the `id` column
+    __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
+
 
 
 class TagModel(BaseModel):
 class TagModel(BaseModel):
     id: str
     id: str
@@ -57,7 +60,8 @@ class TagTable:
                     return TagModel.model_validate(result)
                     return TagModel.model_validate(result)
                 else:
                 else:
                     return None
                     return None
-            except Exception:
+            except Exception as e:
+                print(e)
                 return None
                 return None
 
 
     def get_tag_by_name_and_user_id(
     def get_tag_by_name_and_user_id(
@@ -78,11 +82,15 @@ class TagTable:
                 for tag in (db.query(Tag).filter_by(user_id=user_id).all())
                 for tag in (db.query(Tag).filter_by(user_id=user_id).all())
             ]
             ]
 
 
-    def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
+    def get_tags_by_ids_and_user_id(
+        self, ids: list[str], user_id: str
+    ) -> list[TagModel]:
         with get_db() as db:
         with get_db() as db:
             return [
             return [
                 TagModel.model_validate(tag)
                 TagModel.model_validate(tag)
-                for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
+                for tag in (
+                    db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
+                )
             ]
             ]
 
 
     def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
     def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:

+ 4 - 4
backend/open_webui/apps/webui/routers/chats.py

@@ -465,7 +465,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
     if chat:
         tags = chat.meta.get("tags", [])
         tags = chat.meta.get("tags", [])
-        return Tags.get_tags_by_ids(tags)
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -494,7 +494,7 @@ async def add_tag_by_id_and_tag_name(
 
 
         chat = Chats.get_chat_by_id_and_user_id(id, user.id)
         chat = Chats.get_chat_by_id_and_user_id(id, user.id)
         tags = chat.meta.get("tags", [])
         tags = chat.meta.get("tags", [])
-        return Tags.get_tags_by_ids(tags)
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -519,7 +519,7 @@ async def delete_tag_by_id_and_tag_name(
 
 
         chat = Chats.get_chat_by_id_and_user_id(id, user.id)
         chat = Chats.get_chat_by_id_and_user_id(id, user.id)
         tags = chat.meta.get("tags", [])
         tags = chat.meta.get("tags", [])
-        return Tags.get_tags_by_ids(tags)
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -543,7 +543,7 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
 
 
         chat = Chats.get_chat_by_id_and_user_id(id, user.id)
         chat = Chats.get_chat_by_id_and_user_id(id, user.id)
         tags = chat.meta.get("tags", [])
         tags = chat.meta.get("tags", [])
-        return Tags.get_tags_by_ids(tags)
+        return Tags.get_tags_by_ids_and_user_id(tags, user.id)
     else:
     else:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

+ 67 - 0
backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py

@@ -0,0 +1,67 @@
+"""Update tags
+
+Revision ID: 3ab32c4b8f59
+Revises: 1af9b942657b
+Create Date: 2024-10-09 21:02:35.241684
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.sql import table, select, update, column
+from sqlalchemy.engine.reflection import Inspector
+
+import json
+
+revision = "3ab32c4b8f59"
+down_revision = "1af9b942657b"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    conn = op.get_bind()
+    inspector = Inspector.from_engine(conn)
+
+    # Inspecting the 'tag' table constraints and structure
+    existing_pk = inspector.get_pk_constraint("tag")
+    unique_constraints = inspector.get_unique_constraints("tag")
+    existing_indexes = inspector.get_indexes("tag")
+
+    print(existing_pk, unique_constraints)
+
+    with op.batch_alter_table("tag", schema=None) as batch_op:
+        # Drop unique constraints that could conflict with new primary key
+        for constraint in unique_constraints:
+            if constraint["name"] == "uq_id_user_id":
+                batch_op.drop_constraint(constraint["name"], type_="unique")
+
+        for index in existing_indexes:
+            if index["unique"]:
+                # Drop the unique index
+                batch_op.drop_index(index["name"])
+
+        # Drop existing primary key constraint if it exists
+        if existing_pk and existing_pk.get("constrained_columns"):
+            batch_op.drop_constraint(existing_pk["name"], type_="primary")
+
+        # Immediately after dropping the old primary key, create the new one
+        batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
+
+
+def downgrade():
+    conn = op.get_bind()
+    inspector = Inspector.from_engine(conn)
+
+    current_pk = inspector.get_pk_constraint("tag")
+
+    with op.batch_alter_table("tag", schema=None) as batch_op:
+        # Drop the current primary key first, if it matches the one we know we added in upgrade
+        if current_pk and "pk_id_user_id" == current_pk.get("name"):
+            batch_op.drop_constraint("pk_id_user_id", type_="primary")
+
+        # Restore the original primary key
+        batch_op.create_primary_key("pk_id", ["id"])
+
+        # Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
+        batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])