فهرست منبع

refac: convert chat.chat to json data type

Timothy J. Baek 6 ماه پیش
والد
کامیت
d7a00af576

+ 10 - 10
backend/open_webui/apps/webui/models/chats.py

@@ -5,7 +5,7 @@ from typing import Optional
 
 from open_webui.apps.webui.internal.db import Base, get_db
 from pydantic import BaseModel, ConfigDict
-from sqlalchemy import BigInteger, Boolean, Column, String, Text
+from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
 
 ####################
 # Chat DB Schema
@@ -18,7 +18,7 @@ class Chat(Base):
     id = Column(String, primary_key=True)
     user_id = Column(String)
     title = Column(Text)
-    chat = Column(Text)  # Save Chat JSON as Text
+    chat = Column(JSON)
 
     created_at = Column(BigInteger)
     updated_at = Column(BigInteger)
@@ -33,7 +33,7 @@ class ChatModel(BaseModel):
     id: str
     user_id: str
     title: str
-    chat: str
+    chat: dict
 
     created_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
@@ -86,7 +86,7 @@ class ChatTable:
                         if "title" in form_data.chat
                         else "New Chat"
                     ),
-                    "chat": json.dumps(form_data.chat),
+                    "chat": form_data.chat,
                     "created_at": int(time.time()),
                     "updated_at": int(time.time()),
                 }
@@ -101,14 +101,14 @@ class ChatTable:
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-                chat_obj = db.get(Chat, id)
-                chat_obj.chat = json.dumps(chat)
-                chat_obj.title = chat["title"] if "title" in chat else "New Chat"
-                chat_obj.updated_at = int(time.time())
+                chat_item = db.get(Chat, id)
+                chat_item.chat = chat
+                chat_item.title = chat["title"] if "title" in chat else "New Chat"
+                chat_item.updated_at = int(time.time())
                 db.commit()
-                db.refresh(chat_obj)
+                db.refresh(chat_item)
 
-                return ChatModel.model_validate(chat_obj)
+                return ChatModel.model_validate(chat_item)
         except Exception:
             return None
 

+ 16 - 23
backend/open_webui/apps/webui/routers/chats.py

@@ -95,7 +95,7 @@ async def get_user_chat_list_by_user_id(
 async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
     try:
         chat = Chats.insert_new_chat(user.id, form_data)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        return ChatResponse(**chat.model_dump())
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -111,7 +111,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
 @router.get("/all", response_model=list[ChatResponse])
 async def get_user_chats(user=Depends(get_verified_user)):
     return [
-        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        ChatResponse(**chat.model_dump())
         for chat in Chats.get_chats_by_user_id(user.id)
     ]
 
@@ -124,7 +124,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
 @router.get("/all/archived", response_model=list[ChatResponse])
 async def get_user_archived_chats(user=Depends(get_verified_user)):
     return [
-        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        ChatResponse(**chat.model_dump())
         for chat in Chats.get_archived_chats_by_user_id(user.id)
     ]
 
@@ -141,10 +141,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
-    return [
-        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        for chat in Chats.get_chats()
-    ]
+    return [ChatResponse(**chat.model_dump()) for chat in Chats.get_chats()]
 
 
 ############################
@@ -187,7 +184,8 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
         chat = Chats.get_chat_by_id(share_id)
 
     if chat:
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        return ChatResponse(**chat.model_dump())
+
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -251,7 +249,8 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
 
     if chat:
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        return ChatResponse(**chat.model_dump())
+
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -269,10 +268,9 @@ async def update_chat_by_id(
 ):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
-        updated_chat = {**json.loads(chat.chat), **form_data.chat}
-
+        updated_chat = {**chat.chat, **form_data.chat}
         chat = Chats.update_chat_by_id(id, updated_chat)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        return ChatResponse(**chat.model_dump())
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -312,16 +310,15 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
 async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
-        chat_body = json.loads(chat.chat)
         updated_chat = {
-            **chat_body,
+            **chat.chat,
             "originalChatId": chat.id,
-            "branchPointMessageId": chat_body["history"]["currentId"],
+            "branchPointMessageId": chat.chat["history"]["currentId"],
             "title": f"Clone of {chat.title}",
         }
 
         chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        return ChatResponse(**chat.model_dump())
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -338,7 +335,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
         chat = Chats.toggle_chat_archive_by_id(id)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        return ChatResponse(**chat.model_dump())
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
@@ -356,9 +353,7 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
     if chat:
         if chat.share_id:
             shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
-            return ChatResponse(
-                **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
-            )
+            return ChatResponse(**shared_chat.model_dump())
 
         shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
         if not shared_chat:
@@ -366,10 +361,8 @@ async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                 detail=ERROR_MESSAGES.DEFAULT(),
             )
+        return ChatResponse(**shared_chat.model_dump())
 
-        return ChatResponse(
-            **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
-        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,

+ 82 - 0
backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py

@@ -0,0 +1,82 @@
+"""Update chat table
+
+Revision ID: 242a2047eae0
+Revises: 6a39f3d8e55c
+Create Date: 2024-10-09 21:02:35.241684
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.sql import table, select, update
+
+import json
+
+revision = "242a2047eae0"
+down_revision = "6a39f3d8e55c"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # Step 1: Rename current 'chat' column to 'old_chat'
+    op.alter_column("chat", "chat", new_column_name="old_chat", existing_type=sa.Text)
+
+    # Step 2: Add new 'chat' column of type JSON
+    op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
+
+    # Step 3: Migrate data from 'old_chat' to 'chat'
+    chat_table = table(
+        "chat",
+        sa.Column("id", sa.String, primary_key=True),
+        sa.Column("old_chat", sa.Text),
+        sa.Column("chat", sa.JSON()),
+    )
+
+    # - Selecting all data from the table
+    connection = op.get_bind()
+    results = connection.execute(select(chat_table.c.id, chat_table.c.old_chat))
+    for row in results:
+        try:
+            # Convert text JSON to actual JSON object, assuming the text is in JSON format
+            json_data = json.loads(row.old_chat)
+        except json.JSONDecodeError:
+            json_data = None  # Handle cases where the text cannot be converted to JSON
+
+        connection.execute(
+            sa.update(chat_table)
+            .where(chat_table.c.id == row.id)
+            .values(chat=json_data)
+        )
+
+    # Step 4: Drop 'old_chat' column
+    op.drop_column("chat", "old_chat")
+
+
+def downgrade():
+    # Step 1: Add 'old_chat' column back as Text
+    op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
+
+    # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
+    chat_table = table(
+        "chat",
+        sa.Column("id", sa.String, primary_key=True),
+        sa.Column("chat", sa.JSON()),
+        sa.Column("old_chat", sa.Text()),
+    )
+
+    connection = op.get_bind()
+    results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
+    for row in results:
+        text_data = json.dumps(row.chat) if row.chat is not None else None
+        connection.execute(
+            sa.update(chat_table)
+            .where(chat_table.c.id == row.id)
+            .values(old_chat=text_data)
+        )
+
+    # Step 3: Remove the new 'chat' JSON column
+    op.drop_column("chat", "chat")
+
+    # Step 4: Rename 'old_chat' back to 'chat'
+    op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text)