浏览代码

refac: tags

Timothy J. Baek 6 月之前
父节点
当前提交
acb5dcf30a

+ 164 - 2
backend/open_webui/apps/webui/models/chats.py

@@ -4,10 +4,13 @@ import uuid
 from typing import Optional
 
 from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
+
+
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
-from sqlalchemy import or_, func, select
-
+from sqlalchemy import or_, func, select, and_, text
+from sqlalchemy.sql import exists
 
 ####################
 # Chat DB Schema
@@ -27,6 +30,9 @@ class Chat(Base):
 
     share_id = Column(Text, unique=True, nullable=True)
     archived = Column(Boolean, default=False)
+    pinned = Column(Boolean, default=False, nullable=True)
+
+    meta = Column(JSON, server_default="{}")
 
 
 class ChatModel(BaseModel):
@@ -42,6 +48,9 @@ class ChatModel(BaseModel):
 
     share_id: Optional[str] = None
     archived: bool = False
+    pinned: Optional[bool] = False
+
+    meta: dict = {}
 
 
 ####################
@@ -66,6 +75,8 @@ class ChatResponse(BaseModel):
     created_at: int  # timestamp in epoch
     share_id: Optional[str] = None  # id of the chat to be shared
     archived: bool
+    pinned: Optional[bool] = False
+    meta: dict = {}
 
 
 class ChatTitleIdResponse(BaseModel):
@@ -184,11 +195,24 @@ class ChatTable:
         except Exception:
             return None
 
+    def toggle_chat_pinned_by_id(self, id: str) -> Optional[ChatModel]:
+        try:
+            with get_db() as db:
+                chat = db.get(Chat, id)
+                chat.pinned = not chat.pinned
+                chat.updated_at = int(time.time())
+                db.commit()
+                db.refresh(chat)
+                return ChatModel.model_validate(chat)
+        except Exception:
+            return None
+
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
                 chat = db.get(Chat, id)
                 chat.archived = not chat.archived
+                chat.updated_at = int(time.time())
                 db.commit()
                 db.refresh(chat)
                 return ChatModel.model_validate(chat)
@@ -330,6 +354,15 @@ class ChatTable:
             )
             return [ChatModel.model_validate(chat) for chat in all_chats]
 
+    def get_pinned_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
+        with get_db() as db:
+            all_chats = (
+                db.query(Chat)
+                .filter_by(user_id=user_id, pinned=True)
+                .order_by(Chat.updated_at.desc())
+            )
+            return [ChatModel.model_validate(chat) for chat in all_chats]
+
     def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
         with get_db() as db:
             all_chats = (
@@ -383,6 +416,135 @@ class ChatTable:
         paginated_chats = filtered_chats[skip : skip + limit]
         return [ChatModel.model_validate(chat) for chat in paginated_chats]
 
+    def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str) -> list[TagModel]:
+        with get_db() as db:
+            chat = db.get(Chat, id)
+            tags = chat.meta.get("tags", [])
+            return [Tags.get_tag_by_name_and_user_id(tag, user_id) for tag in tags]
+
+    def get_chat_list_by_user_id_and_tag_name(
+        self, user_id: str, tag_name: str, skip: int = 0, limit: int = 50
+    ) -> list[ChatModel]:
+        with get_db() as db:
+            query = db.query(Chat).filter_by(user_id=user_id)
+            tag_id = tag_name.replace(" ", "_").lower()
+
+            print(db.bind.dialect.name)
+            if db.bind.dialect.name == "sqlite":
+                # SQLite JSON1 querying for tags within the meta JSON field
+                query = query.filter(
+                    text(
+                        f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
+                    )
+                ).params(tag_id=tag_id)
+            elif db.bind.dialect.name == "postgresql":
+                # PostgreSQL JSON query for tags within the meta JSON field (for `json` type)
+                query = query.filter(
+                    text(
+                        "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
+                    )
+                ).params(tag_id=tag_id)
+            else:
+                raise NotImplementedError(
+                    f"Unsupported dialect: {db.bind.dialect.name}"
+                )
+
+            all_chats = query.all()
+            print("all_chats", all_chats)
+            return [ChatModel.model_validate(chat) for chat in all_chats]
+
+    def add_chat_tag_by_id_and_user_id_and_tag_name(
+        self, id: str, user_id: str, tag_name: str
+    ) -> Optional[ChatModel]:
+        tag = Tags.get_tag_by_name_and_user_id(tag_name, user_id)
+        if tag is None:
+            tag = Tags.insert_new_tag(tag_name, user_id)
+        try:
+            with get_db() as db:
+                chat = db.get(Chat, id)
+
+                tag_id = tag.id
+                if tag_id not in chat.meta.get("tags", []):
+                    chat.meta = {
+                        **chat.meta,
+                        "tags": chat.meta.get("tags", []) + [tag_id],
+                    }
+
+                db.commit()
+                db.refresh(chat)
+                return ChatModel.model_validate(chat)
+        except Exception:
+            return None
+
+    def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> int:
+        with get_db() as db:  # Assuming `get_db()` returns a session object
+            query = db.query(Chat).filter_by(user_id=user_id)
+
+            # Normalize the tag_name for consistency
+            tag_id = tag_name.replace(" ", "_").lower()
+
+            if db.bind.dialect.name == "sqlite":
+                # SQLite JSON1 support for querying the tags inside the `meta` JSON field
+                query = query.filter(
+                    text(
+                        f"EXISTS (SELECT 1 FROM json_each(Chat.meta, '$.tags') WHERE json_each.value = :tag_id)"
+                    )
+                ).params(tag_id=tag_id)
+
+            elif db.bind.dialect.name == "postgresql":
+                # PostgreSQL JSONB support for querying the tags inside the `meta` JSON field
+                query = query.filter(
+                    text(
+                        "EXISTS (SELECT 1 FROM json_array_elements_text(Chat.meta->'tags') elem WHERE elem = :tag_id)"
+                    )
+                ).params(tag_id=tag_id)
+
+            else:
+                raise NotImplementedError(
+                    f"Unsupported dialect: {db.bind.dialect.name}"
+                )
+
+            # Get the count of matching records
+            count = query.count()
+
+            # Debugging output for inspection
+            print(f"Count of chats for tag '{tag_name}':", count)
+
+            return count
+
+    def delete_tag_by_id_and_user_id_and_tag_name(
+        self, id: str, user_id: str, tag_name: str
+    ) -> bool:
+        try:
+            with get_db() as db:
+                chat = db.get(Chat, id)
+                tags = chat.meta.get("tags", [])
+                tag_id = tag_name.replace(" ", "_").lower()
+
+                tags = [tag for tag in tags if tag != tag_id]
+                chat.meta = {
+                    **chat.meta,
+                    "tags": tags,
+                }
+                db.commit()
+                return True
+        except Exception:
+            return False
+
+    def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+        try:
+            with get_db() as db:
+                chat = db.get(Chat, id)
+                chat.meta = {
+                    **chat.meta,
+                    "tags": [],
+                }
+                db.commit()
+
+                return True
+        except Exception:
+            return False
+
     def delete_chat_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:

+ 17 - 178
backend/open_webui/apps/webui/models/tags.py

@@ -4,53 +4,32 @@ import uuid
 from typing import Optional
 
 from open_webui.apps.webui.internal.db import Base, get_db
+
+
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Column, String, Text
+from sqlalchemy import BigInteger, Column, String, JSON
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
 
+
 ####################
 # Tag DB Schema
 ####################
-
-
 class Tag(Base):
     __tablename__ = "tag"
-
     id = Column(String, primary_key=True)
     name = Column(String)
     user_id = Column(String)
-    data = Column(Text, nullable=True)
-
-
-class ChatIdTag(Base):
-    __tablename__ = "chatidtag"
-
-    id = Column(String, primary_key=True)
-    tag_name = Column(String)
-    chat_id = Column(String)
-    user_id = Column(String)
-    timestamp = Column(BigInteger)
+    meta = Column(JSON, nullable=True)
 
 
 class TagModel(BaseModel):
     id: str
     name: str
     user_id: str
-    data: Optional[str] = None
-
-    model_config = ConfigDict(from_attributes=True)
-
-
-class ChatIdTagModel(BaseModel):
-    id: str
-    tag_name: str
-    chat_id: str
-    user_id: str
-    timestamp: int
-
+    meta: Optional[dict] = None
     model_config = ConfigDict(from_attributes=True)
 
 
@@ -59,23 +38,15 @@ class ChatIdTagModel(BaseModel):
 ####################
 
 
-class ChatIdTagForm(BaseModel):
-    tag_name: str
+class TagChatIdForm(BaseModel):
+    name: str
     chat_id: str
 
 
-class TagChatIdsResponse(BaseModel):
-    chat_ids: list[str]
-
-
-class ChatTagsResponse(BaseModel):
-    tags: list[str]
-
-
 class TagTable:
     def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
         with get_db() as db:
-            id = str(uuid.uuid4())
+            id = name.replace(" ", "_").lower()
             tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
             try:
                 result = Tag(**tag.model_dump())
@@ -93,170 +64,38 @@ class TagTable:
         self, name: str, user_id: str
     ) -> Optional[TagModel]:
         try:
+            id = name.replace(" ", "_").lower()
             with get_db() as db:
-                tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
+                tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
                 return TagModel.model_validate(tag)
         except Exception:
             return None
 
-    def add_tag_to_chat(
-        self, user_id: str, form_data: ChatIdTagForm
-    ) -> Optional[ChatIdTagModel]:
-        tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
-        if tag is None:
-            tag = self.insert_new_tag(form_data.tag_name, user_id)
-
-        id = str(uuid.uuid4())
-        chatIdTag = ChatIdTagModel(
-            **{
-                "id": id,
-                "user_id": user_id,
-                "chat_id": form_data.chat_id,
-                "tag_name": tag.name,
-                "timestamp": int(time.time()),
-            }
-        )
-        try:
-            with get_db() as db:
-                result = ChatIdTag(**chatIdTag.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
-                if result:
-                    return ChatIdTagModel.model_validate(result)
-                else:
-                    return None
-        except Exception:
-            return None
-
     def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
         with get_db() as db:
-            tag_names = [
-                chat_id_tag.tag_name
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
             return [
                 TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag)
-                    .filter_by(user_id=user_id)
-                    .filter(Tag.name.in_(tag_names))
-                    .all()
-                )
+                for tag in (db.query(Tag).filter_by(user_id=user_id).all())
             ]
 
-    def get_tags_by_chat_id_and_user_id(
-        self, chat_id: str, user_id: str
-    ) -> list[TagModel]:
+    def get_tags_by_ids(self, ids: list[str]) -> list[TagModel]:
         with get_db() as db:
-            tag_names = [
-                chat_id_tag.tag_name
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id, chat_id=chat_id)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
             return [
                 TagModel.model_validate(tag)
-                for tag in (
-                    db.query(Tag)
-                    .filter_by(user_id=user_id)
-                    .filter(Tag.name.in_(tag_names))
-                    .all()
-                )
+                for tag in (db.query(Tag).filter(Tag.id.in_(ids)).all())
             ]
 
-    def get_chat_ids_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
-    ) -> list[ChatIdTagModel]:
-        with get_db() as db:
-            return [
-                ChatIdTagModel.model_validate(chat_id_tag)
-                for chat_id_tag in (
-                    db.query(ChatIdTag)
-                    .filter_by(user_id=user_id, tag_name=tag_name)
-                    .order_by(ChatIdTag.timestamp.desc())
-                    .all()
-                )
-            ]
-
-    def count_chat_ids_by_tag_name_and_user_id(
-        self, tag_name: str, user_id: str
-    ) -> int:
-        with get_db() as db:
-            return (
-                db.query(ChatIdTag)
-                .filter_by(tag_name=tag_name, user_id=user_id)
-                .count()
-            )
-
-    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
+    def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
         try:
             with get_db() as db:
-                res = (
-                    db.query(ChatIdTag)
-                    .filter_by(tag_name=tag_name, user_id=user_id)
-                    .delete()
-                )
+                id = name.replace(" ", "_").lower()
+                res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
                 log.debug(f"res: {res}")
                 db.commit()
-
-                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                    tag_name, user_id
-                )
-                if tag_count == 0:
-                    # Remove tag item from Tag col as well
-                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
-                    db.commit()
                 return True
         except Exception as e:
             log.error(f"delete_tag: {e}")
             return False
 
-    def delete_tag_by_tag_name_and_chat_id_and_user_id(
-        self, tag_name: str, chat_id: str, user_id: str
-    ) -> bool:
-        try:
-            with get_db() as db:
-                res = (
-                    db.query(ChatIdTag)
-                    .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
-                    .delete()
-                )
-                log.debug(f"res: {res}")
-                db.commit()
-
-                tag_count = self.count_chat_ids_by_tag_name_and_user_id(
-                    tag_name, user_id
-                )
-                if tag_count == 0:
-                    # Remove tag item from Tag col as well
-                    db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
-                    db.commit()
-
-                return True
-        except Exception as e:
-            log.error(f"delete_tag: {e}")
-            return False
-
-    def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
-        tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
-
-        for tag in tags:
-            self.delete_tag_by_tag_name_and_chat_id_and_user_id(
-                tag.tag_name, chat_id, user_id
-            )
-
-        return True
-
 
 Tags = TagTable()

+ 117 - 65
backend/open_webui/apps/webui/routers/chats.py

@@ -8,12 +8,8 @@ from open_webui.apps.webui.models.chats import (
     Chats,
     ChatTitleIdResponse,
 )
-from open_webui.apps.webui.models.tags import (
-    ChatIdTagForm,
-    ChatIdTagModel,
-    TagModel,
-    Tags,
-)
+from open_webui.apps.webui.models.tags import TagModel, Tags
+
 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.env import SRC_LOG_LEVELS
@@ -126,6 +122,19 @@ async def search_user_chats(
     ]
 
 
+############################
+# GetPinnedChats
+############################
+
+
+@router.get("/pinned", response_model=list[ChatResponse])
+async def get_user_pinned_chats(user=Depends(get_verified_user)):
+    return [
+        ChatResponse(**chat.model_dump())
+        for chat in Chats.get_pinned_chats_by_user_id(user.id)
+    ]
+
+
 ############################
 # GetChats
 ############################
@@ -152,6 +161,23 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
     ]
 
 
+############################
+# GetAllTags
+############################
+
+
+@router.get("/all/tags", response_model=list[TagModel])
+async def get_all_user_tags(user=Depends(get_verified_user)):
+    try:
+        tags = Tags.get_tags_by_user_id(user.id)
+        return tags
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
 ############################
 # GetAllChatsInDB
 ############################
@@ -220,48 +246,28 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
 ############################
 
 
-class TagNameForm(BaseModel):
+class TagForm(BaseModel):
     name: str
+
+
+class TagFilterForm(TagForm):
     skip: Optional[int] = 0
     limit: Optional[int] = 50
 
 
 @router.post("/tags", response_model=list[ChatTitleIdResponse])
 async def get_user_chat_list_by_tag_name(
-    form_data: TagNameForm, user=Depends(get_verified_user)
+    form_data: TagFilterForm, user=Depends(get_verified_user)
 ):
-    chat_ids = [
-        chat_id_tag.chat_id
-        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
-            form_data.name, user.id
-        )
-    ]
-
-    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
-
+    chats = Chats.get_chat_list_by_user_id_and_tag_name(
+        user.id, form_data.name, form_data.skip, form_data.limit
+    )
     if len(chats) == 0:
-        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
+        Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
 
     return chats
 
 
-############################
-# GetAllTags
-############################
-
-
-@router.get("/tags/all", response_model=list[TagModel])
-async def get_all_tags(user=Depends(get_verified_user)):
-    try:
-        tags = Tags.get_tags_by_user_id(user.id)
-        return tags
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
-        )
-
-
 ############################
 # GetChatById
 ############################
@@ -324,12 +330,45 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
         return result
 
 
+############################
+# GetPinnedStatusById
+############################
+
+
+@router.get("/{id}/pinned", response_model=Optional[bool])
+async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        return chat.pinned
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
+############################
+# PinChatById
+############################
+
+
+@router.post("/{id}/pin", response_model=Optional[ChatResponse])
+async def pin_chat_by_id(id: str, user=Depends(get_verified_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        chat = Chats.toggle_chat_pinned_by_id(id)
+        return chat
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
+        )
+
+
 ############################
 # CloneChat
 ############################
 
 
-@router.get("/{id}/clone", response_model=Optional[ChatResponse])
+@router.post("/{id}/clone", response_model=Optional[ChatResponse])
 async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
@@ -353,7 +392,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
 ############################
 
 
-@router.get("/{id}/archive", response_model=Optional[ChatResponse])
+@router.post("/{id}/archive", response_model=Optional[ChatResponse])
 async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
@@ -423,10 +462,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
 
 @router.get("/{id}/tags", response_model=list[TagModel])
 async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
-
-    if tags != None:
-        return tags
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids(tags)
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -438,22 +477,24 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
 ############################
 
 
-@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
-async def add_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
+@router.post("/{id}/tags", response_model=list[TagModel])
+async def add_tag_by_id_and_tag_name(
+    id: str, form_data: TagForm, user=Depends(get_verified_user)
 ):
-    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
-
-    if form_data.tag_name not in tags:
-        tag = Tags.add_tag_to_chat(user.id, form_data)
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        tags = chat.meta.get("tags", [])
+        tag_id = form_data.name.replace(" ", "_").lower()
 
-        if tag:
-            return tag
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
+        print(tags, tag_id)
+        if tag_id not in tags:
+            Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
+                id, user.id, form_data.name
             )
+
+        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids(tags)
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -465,16 +506,20 @@ async def add_chat_tag_by_id(
 ############################
 
 
-@router.delete("/{id}/tags", response_model=Optional[bool])
-async def delete_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
+@router.delete("/{id}/tags", response_model=list[TagModel])
+async def delete_tag_by_id_and_tag_name(
+    id: str, form_data: TagForm, user=Depends(get_verified_user)
 ):
-    result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
-        form_data.tag_name, id, user.id
-    )
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        Chats.delete_tag_by_id_and_user_id_and_tag_name(id, user.id, form_data.name)
 
-    if result:
-        return result
+        if Chats.count_chats_by_tag_name_and_user_id(form_data.name, user.id) == 0:
+            Tags.delete_tag_by_name_and_user_id(form_data.name, user.id)
+
+        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids(tags)
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -488,10 +533,17 @@ async def delete_chat_tag_by_id(
 
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
 async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
-    result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        Chats.delete_all_tags_by_id_and_user_id(id, user.id)
 
-    if result:
-        return result
+        for tag in chat.meta.get("tags", []):
+            if Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
+        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+        tags = chat.meta.get("tags", [])
+        return Tags.get_tags_by_ids(tags)
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND

+ 109 - 0
backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py

@@ -0,0 +1,109 @@
+"""Migrate tags
+
+Revision ID: 1af9b942657b
+Revises: 242a2047eae0
+Create Date: 2024-10-09 21:02:35.241684
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.sql import table, select, update, column
+
+import json
+
+revision = "1af9b942657b"
+down_revision = "242a2047eae0"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # Step 1: Modify Tag table using batch mode for SQLite support
+    with op.batch_alter_table("tag", schema=None) as batch_op:
+        batch_op.create_unique_constraint(
+            "uq_id_user_id", ["id", "user_id"]
+        )  # Ensure unique (id, user_id)
+        batch_op.drop_column("data")
+        batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
+
+    tag = table(
+        "tag",
+        column("id", sa.String()),
+        column("name", sa.String()),
+        column("user_id", sa.String()),
+        column("meta", sa.JSON()),
+    )
+
+    # Step 2: Migrate tags
+    conn = op.get_bind()
+    result = conn.execute(sa.select(tag.c.id, tag.c.name, tag.c.user_id))
+
+    tag_updates = {}
+    for row in result:
+        new_id = row.name.replace(" ", "_").lower()
+        tag_updates[row.id] = new_id
+
+    for tag_id, new_tag_id in tag_updates.items():
+        print(f"Updating tag {tag_id} to {new_tag_id}")
+        if new_tag_id == "pinned":
+            # delete tag
+            delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
+            conn.execute(delete_stmt)
+        else:
+            update_stmt = sa.update(tag).where(tag.c.id == tag_id)
+            update_stmt = update_stmt.values(id=new_tag_id)
+            conn.execute(update_stmt)
+
+    # Add columns `pinned` and `meta` to 'chat'
+    op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
+    op.add_column(
+        "chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
+    )
+
+    chatidtag = table(
+        "chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
+    )
+    chat = table(
+        "chat",
+        column("id", sa.String()),
+        column("pinned", sa.Boolean()),
+        column("meta", sa.JSON()),
+    )
+
+    # Fetch existing tags
+    conn = op.get_bind()
+    result = conn.execute(sa.select(chatidtag.c.chat_id, chatidtag.c.tag_name))
+
+    chat_updates = {}
+    for row in result:
+        chat_id = row.chat_id
+        tag_name = row.tag_name.replace(" ", "_").lower()
+
+        if tag_name == "pinned":
+            # Specifically handle 'pinned' tag
+            if chat_id not in chat_updates:
+                chat_updates[chat_id] = {"pinned": True, "meta": {}}
+            else:
+                chat_updates[chat_id]["pinned"] = True
+        else:
+            if chat_id not in chat_updates:
+                chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
+            else:
+                tags = chat_updates[chat_id]["meta"].get("tags", [])
+                tags.append(tag_name)
+
+                chat_updates[chat_id]["meta"]["tags"] = tags
+
+    # Update chats based on accumulated changes
+    for chat_id, updates in chat_updates.items():
+        update_stmt = sa.update(chat).where(chat.c.id == chat_id)
+        update_stmt = update_stmt.values(
+            meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
+        )
+        conn.execute(update_stmt)
+    pass
+
+
+def downgrade():
+    pass

+ 115 - 7
src/lib/apis/chats/index.ts

@@ -267,7 +267,7 @@ export const getAllUserChats = async (token: string) => {
 export const getAllChatTags = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/all/tags`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -295,6 +295,40 @@ export const getAllChatTags = async (token: string) => {
 	return res;
 };
 
+export const getPinnedChatList = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/pinned`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.map((chat) => ({
+		...chat,
+		time_range: getTimeRange(chat.updated_at)
+	}));
+};
+
 export const getChatListByTagName = async (token: string = '', tagName: string) => {
 	let error = null;
 
@@ -396,11 +430,87 @@ export const getChatByShareId = async (token: string, share_id: string) => {
 	return res;
 };
 
+export const getChatPinnedStatusById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pinned`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const toggleChatPinnedStatusById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/pin`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const cloneChatById = async (token: string, id: string) => {
 	let error = null;
 
 	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/clone`, {
-		method: 'GET',
+		method: 'POST',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
@@ -470,7 +580,7 @@ export const archiveChatById = async (token: string, id: string) => {
 	let error = null;
 
 	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/archive`, {
-		method: 'GET',
+		method: 'POST',
 		headers: {
 			Accept: 'application/json',
 			'Content-Type': 'application/json',
@@ -640,8 +750,7 @@ export const addTagById = async (token: string, id: string, tagName: string) =>
 			...(token && { authorization: `Bearer ${token}` })
 		},
 		body: JSON.stringify({
-			tag_name: tagName,
-			chat_id: id
+			name: tagName
 		})
 	})
 		.then(async (res) => {
@@ -676,8 +785,7 @@ export const deleteTagById = async (token: string, id: string, tagName: string)
 			...(token && { authorization: `Bearer ${token}` })
 		},
 		body: JSON.stringify({
-			tag_name: tagName,
-			chat_id: id
+			name: tagName
 		})
 	})
 		.then(async (res) => {

+ 4 - 15
src/lib/components/chat/Tags.svelte

@@ -25,40 +25,30 @@
 	let tags = [];
 
 	const getTags = async () => {
-		return (
-			await getTagsById(localStorage.token, chatId).catch(async (error) => {
-				return [];
-			})
-		).filter((tag) => tag.name !== 'pinned');
+		return await getTagsById(localStorage.token, chatId).catch(async (error) => {
+			return [];
+		});
 	};
 
 	const addTag = async (tagName) => {
 		const res = await addTagById(localStorage.token, chatId, tagName);
 		tags = await getTags();
-
 		await updateChatById(localStorage.token, chatId, {
 			tags: tags
 		});
-
 		_tags.set(await getAllChatTags(localStorage.token));
-		await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
 	};
 
 	const deleteTag = async (tagName) => {
 		const res = await deleteTagById(localStorage.token, chatId, tagName);
 		tags = await getTags();
-
 		await updateChatById(localStorage.token, chatId, {
 			tags: tags
 		});
 
 		await _tags.set(await getAllChatTags(localStorage.token));
 		if ($_tags.map((t) => t.name).includes(tagName)) {
-			if (tagName === 'pinned') {
-				await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
-			} else {
-				await chats.set(await getChatListByTagName(localStorage.token, tagName));
-			}
+			await chats.set(await getChatListByTagName(localStorage.token, tagName));
 
 			if ($chats.find((chat) => chat.id === chatId)) {
 				dispatch('close');
@@ -67,7 +57,6 @@
 			// if the tag we deleted is no longer a valid tag, return to main chat list view
 			currentChatPage.set(1);
 			await chats.set(await getChatList(localStorage.token, $currentChatPage));
-			await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
 			await scrollPaginationEnabled.set(true);
 		}
 	};

+ 4 - 0
src/lib/components/layout/Navbar/Menu.svelte

@@ -24,6 +24,7 @@
 	import Clipboard from '$lib/components/icons/Clipboard.svelte';
 	import AdjustmentsHorizontal from '$lib/components/icons/AdjustmentsHorizontal.svelte';
 	import Cube from '$lib/components/icons/Cube.svelte';
+	import { getChatById } from '$lib/apis/chats';
 
 	const i18n = getContext('i18n');
 
@@ -81,6 +82,9 @@
 	};
 
 	const downloadJSONExport = async () => {
+		if (chat.id) {
+			chat = await getChatById(localStorage.token, chat.id);
+		}
 		let blob = new Blob([JSON.stringify([chat])], {
 			type: 'application/json'
 		});

+ 11 - 9
src/lib/components/layout/Sidebar.svelte

@@ -34,7 +34,8 @@
 		archiveChatById,
 		cloneChatById,
 		getChatListBySearchText,
-		createNewChat
+		createNewChat,
+		getPinnedChatList
 	} from '$lib/apis/chats';
 	import { WEBUI_BASE_URL } from '$lib/constants';
 
@@ -135,7 +136,7 @@
 			currentChatPage.set(1);
 			await chats.set(await getChatList(localStorage.token, $currentChatPage));
 
-			await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
+			await pinnedChats.set(await getPinnedChatList(localStorage.token));
 		}
 	};
 
@@ -255,7 +256,7 @@
 			localStorage.sidebar = value;
 		});
 
-		await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
+		await pinnedChats.set(await getPinnedChatList(localStorage.token));
 		await initChatList();
 
 		window.addEventListener('keydown', onKeyDown);
@@ -495,7 +496,7 @@
 				</div>
 			</div>
 
-			{#if $tags.filter((t) => t.name !== 'pinned').length > 0}
+			{#if $tags.length > 0}
 				<div class="px-3.5 mb-1 flex gap-0.5 flex-wrap">
 					<button
 						class="px-2.5 py-[1px] text-xs transition {selectedTagName === null
@@ -508,7 +509,7 @@
 					>
 						{$i18n.t('all')}
 					</button>
-					{#each $tags.filter((t) => t.name !== 'pinned') as tag}
+					{#each $tags as tag}
 						<button
 							class="px-2.5 py-[1px] text-xs transition {selectedTagName === tag.name
 								? 'bg-gray-100 dark:bg-gray-900'
@@ -516,14 +517,15 @@
 							on:click={async () => {
 								selectedTagName = tag.name;
 								scrollPaginationEnabled.set(false);
-								let chatIds = await getChatListByTagName(localStorage.token, tag.name);
-								if (chatIds.length === 0) {
-									await tags.set(await getAllChatTags(localStorage.token));
 
+								let taggedChatList = await getChatListByTagName(localStorage.token, tag.name);
+								if (taggedChatList.length === 0) {
+									await tags.set(await getAllChatTags(localStorage.token));
 									// if the tag we deleted is no longer a valid tag, return to main chat list view
 									await initChatList();
+								} else {
+									await chats.set(taggedChatList);
 								}
-								await chats.set(chatIds);
 								chatListLoading = false;
 							}}
 						>

+ 5 - 4
src/lib/components/layout/Sidebar/ChatItem.svelte

@@ -12,6 +12,7 @@
 		deleteChatById,
 		getChatList,
 		getChatListByTagName,
+		getPinnedChatList,
 		updateChatById
 	} from '$lib/apis/chats';
 	import {
@@ -55,7 +56,7 @@
 
 			currentChatPage.set(1);
 			await chats.set(await getChatList(localStorage.token, $currentChatPage));
-			await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
+			await pinnedChats.set(await getPinnedChatList(localStorage.token));
 		}
 	};
 
@@ -70,7 +71,7 @@
 
 			currentChatPage.set(1);
 			await chats.set(await getChatList(localStorage.token, $currentChatPage));
-			await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
+			await pinnedChats.set(await getPinnedChatList(localStorage.token));
 		}
 	};
 
@@ -79,7 +80,7 @@
 
 		currentChatPage.set(1);
 		await chats.set(await getChatList(localStorage.token, $currentChatPage));
-		await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
+		await pinnedChats.set(await getPinnedChatList(localStorage.token));
 	};
 
 	const focusEdit = async (node: HTMLInputElement) => {
@@ -256,7 +257,7 @@
 						dispatch('unselect');
 					}}
 					on:change={async () => {
-						await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned'));
+						await pinnedChats.set(await getPinnedChatList(localStorage.token));
 					}}
 				>
 					<button

+ 9 - 11
src/lib/components/layout/Sidebar/ChatMenu.svelte

@@ -15,7 +15,13 @@
 	import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
 	import Bookmark from '$lib/components/icons/Bookmark.svelte';
 	import BookmarkSlash from '$lib/components/icons/BookmarkSlash.svelte';
-	import { addTagById, deleteTagById, getTagsById } from '$lib/apis/chats';
+	import {
+		addTagById,
+		deleteTagById,
+		getChatPinnedStatusById,
+		getTagsById,
+		toggleChatPinnedStatusById
+	} from '$lib/apis/chats';
 
 	const i18n = getContext('i18n');
 
@@ -32,20 +38,12 @@
 	let pinned = false;
 
 	const pinHandler = async () => {
-		if (pinned) {
-			await deleteTagById(localStorage.token, chatId, 'pinned');
-		} else {
-			await addTagById(localStorage.token, chatId, 'pinned');
-		}
+		await toggleChatPinnedStatusById(localStorage.token, chatId);
 		dispatch('change');
 	};
 
 	const checkPinned = async () => {
-		pinned = (
-			await getTagsById(localStorage.token, chatId).catch(async (error) => {
-				return [];
-			})
-		).find((tag) => tag.name === 'pinned');
+		pinned = await getChatPinnedStatusById(localStorage.token, chatId);
 	};
 
 	$: if (show) {