Browse Source

feat: `/messages` chat endpoint support

Timothy Jaeryang Baek 1 month ago
parent
commit
1be6ad1250

+ 58 - 0
backend/open_webui/routers/chats.py

@@ -2,6 +2,8 @@ import json
 import logging
 from typing import Optional
 
+
+from open_webui.socket.main import get_event_emitter
 from open_webui.models.chats import (
     ChatForm,
     ChatImportForm,
@@ -372,6 +374,62 @@ async def update_chat_by_id(
         )
 
 
+############################
+# UpdateChatMessageById
+############################
+class MessageForm(BaseModel):
+    content: str
+
+
+@router.post("/{id}/messages/{message_id}", response_model=Optional[ChatResponse])
+async def update_chat_message_by_id(
+    id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
+):
+    chat = Chats.get_chat_by_id(id)
+
+    if not chat:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if chat.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    chat = Chats.upsert_message_to_chat_by_id_and_message_id(
+        id,
+        message_id,
+        {
+            "content": form_data.content,
+        },
+    )
+
+    event_emitter = get_event_emitter(
+        {
+            "user_id": user.id,
+            "chat_id": id,
+            "message_id": message_id,
+        }
+    )
+
+    if event_emitter:
+        event_emitter(
+            {
+                "type": "chat:message",
+                "data": {
+                    "chat_id": id,
+                    "message_id": message_id,
+                    "content": form_data.content,
+                },
+            }
+        )
+
+    return ChatResponse(**chat.model_dump())
+
+
 ############################
 # DeleteChatById
 ############################

+ 27 - 25
backend/open_webui/socket/main.py

@@ -269,7 +269,7 @@ async def disconnect(sid):
         # print(f"Unknown session ID {sid} disconnected")
 
 
-def get_event_emitter(request_info):
+def get_event_emitter(request_info, update_db=True):
     async def __event_emitter__(event_data):
         user_id = request_info["user_id"]
         session_ids = list(
@@ -287,31 +287,33 @@ def get_event_emitter(request_info):
                 to=session_id,
             )
 
-        if "type" in event_data and event_data["type"] == "status":
-            Chats.add_message_status_to_chat_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-                event_data.get("data", {}),
-            )
-
-        if "type" in event_data and event_data["type"] == "message":
-            message = Chats.get_message_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-            )
-
-            content = message.get("content", "")
-            content += event_data.get("data", {}).get("content", "")
-
-            Chats.upsert_message_to_chat_by_id_and_message_id(
-                request_info["chat_id"],
-                request_info["message_id"],
-                {
-                    "content": content,
-                },
-            )
 
-        if "type" in event_data and event_data["type"] == "replace":
+        if update_db:
+            if "type" in event_data and event_data["type"] == "status":
+                Chats.add_message_status_to_chat_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                    event_data.get("data", {}),
+                )
+
+            if "type" in event_data and event_data["type"] == "message":
+                message = Chats.get_message_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                )
+
+                content = message.get("content", "")
+                content += event_data.get("data", {}).get("content", "")
+
+                Chats.upsert_message_to_chat_by_id_and_message_id(
+                    request_info["chat_id"],
+                    request_info["message_id"],
+                    {
+                        "content": content,
+                    },
+                )
+
+            if "type" in event_data and event_data["type"] == "replace":
             content = event_data.get("data", {}).get("content", "")
 
             Chats.upsert_message_to_chat_by_id_and_message_id(

+ 2 - 10
src/lib/components/chat/Chat.svelte

@@ -293,18 +293,10 @@
 				} else if (type === 'chat:tags') {
 					chat = await getChatById(localStorage.token, $chatId);
 					allTags.set(await getAllTags(localStorage.token));
-				} else if (type === 'message') {
+				} else if (type === 'chat:message:delta' || type === 'message') {
 					message.content += data.content;
-				} else if (type === 'replace') {
+				} else if (type === 'chat:message' || type === 'replace') {
 					message.content = data.content;
-				} else if (type === 'action') {
-					if (data.action === 'continue') {
-						const continueButton = document.getElementById('continue-response-button');
-
-						if (continueButton) {
-							continueButton.click();
-						}
-					}
 				} else if (type === 'confirmation') {
 					eventCallback = cb;