浏览代码

fix: non-stream chat completion

Timothy Jaeryang Baek 4 月之前
父节点
当前提交
da7fa09053
共有 2 个文件被更改,包括 190 次插入127 次删除
  1. 159 100
      backend/open_webui/utils/middleware.py
  2. 31 27
      src/lib/components/chat/Chat.svelte

+ 159 - 100
backend/open_webui/utils/middleware.py

@@ -748,14 +748,92 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 async def process_chat_response(
     request, response, form_data, user, events, metadata, tasks
 ):
-    if not isinstance(response, StreamingResponse):
-        return response
+    async def background_tasks_handler():
+        message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
+        message = message_map.get(metadata["message_id"])
+
+        if message:
+            messages = get_message_list(message_map, message.get("id"))
+
+            if tasks:
+                if TASKS.TITLE_GENERATION in tasks:
+                    if tasks[TASKS.TITLE_GENERATION]:
+                        res = await generate_title(
+                            request,
+                            {
+                                "model": message["model"],
+                                "messages": messages,
+                                "chat_id": metadata["chat_id"],
+                            },
+                            user,
+                        )
 
-    if not any(
-        content_type in response.headers["Content-Type"]
-        for content_type in ["text/event-stream", "application/x-ndjson"]
-    ):
-        return response
+                        if res and isinstance(res, dict):
+                            title = (
+                                res.get("choices", [])[0]
+                                .get("message", {})
+                                .get(
+                                    "content",
+                                    message.get("content", "New Chat"),
+                                )
+                            )
+
+                            Chats.update_chat_title_by_id(metadata["chat_id"], title)
+
+                            await event_emitter(
+                                {
+                                    "type": "chat:title",
+                                    "data": title,
+                                }
+                            )
+                    elif len(messages) == 2:
+                        title = messages[0].get("content", "New Chat")
+
+                        Chats.update_chat_title_by_id(metadata["chat_id"], title)
+
+                        await event_emitter(
+                            {
+                                "type": "chat:title",
+                                "data": message.get("content", "New Chat"),
+                            }
+                        )
+
+                if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]:
+                    res = await generate_chat_tags(
+                        request,
+                        {
+                            "model": message["model"],
+                            "messages": messages,
+                            "chat_id": metadata["chat_id"],
+                        },
+                        user,
+                    )
+
+                    if res and isinstance(res, dict):
+                        tags_string = (
+                            res.get("choices", [])[0]
+                            .get("message", {})
+                            .get("content", "")
+                        )
+
+                        tags_string = tags_string[
+                            tags_string.find("{") : tags_string.rfind("}") + 1
+                        ]
+
+                        try:
+                            tags = json.loads(tags_string).get("tags", [])
+                            Chats.update_chat_tags_by_id(
+                                metadata["chat_id"], tags, user
+                            )
+
+                            await event_emitter(
+                                {
+                                    "type": "chat:tags",
+                                    "data": tags,
+                                }
+                            )
+                        except Exception as e:
+                            print(f"Error: {e}")
 
     event_emitter = None
     if (
@@ -768,6 +846,79 @@ async def process_chat_response(
     ):
         event_emitter = get_event_emitter(metadata)
 
+    if not isinstance(response, StreamingResponse):
+        if event_emitter:
+
+            if "selected_model_id" in response:
+                Chats.upsert_message_to_chat_by_id_and_message_id(
+                    metadata["chat_id"],
+                    metadata["message_id"],
+                    {
+                        "selectedModelId": response["selected_model_id"],
+                    },
+                )
+
+            if response.get("choices", [])[0].get("message", {}).get("content"):
+                content = response["choices"][0]["message"]["content"]
+
+                if content:
+
+                    await event_emitter(
+                        {
+                            "type": "chat:completion",
+                            "data": response,
+                        }
+                    )
+
+                    title = Chats.get_chat_title_by_id(metadata["chat_id"])
+
+                    await event_emitter(
+                        {
+                            "type": "chat:completion",
+                            "data": {
+                                "done": True,
+                                "content": content,
+                                "title": title,
+                            },
+                        }
+                    )
+
+                    # Save message in the database
+                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                        metadata["chat_id"],
+                        metadata["message_id"],
+                        {
+                            "content": content,
+                        },
+                    )
+
+                    # Send a webhook notification if the user is not active
+                    if get_user_id_from_session_pool(metadata["session_id"]) is None:
+                        webhook_url = Users.get_user_webhook_url_by_id(user.id)
+                        if webhook_url:
+                            post_webhook(
+                                webhook_url,
+                                f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
+                                {
+                                    "action": "chat",
+                                    "message": content,
+                                    "title": title,
+                                    "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
+                                },
+                            )
+
+                    await background_tasks_handler()
+
+            return response
+        else:
+            return response
+
+    if not any(
+        content_type in response.headers["Content-Type"]
+        for content_type in ["text/event-stream", "application/x-ndjson"]
+    ):
+        return response
+
     if event_emitter:
 
         task_id = str(uuid4())  # Create a unique task ID.
@@ -877,99 +1028,7 @@ async def process_chat_response(
                         }
                     )
 
-                message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
-                message = message_map.get(metadata["message_id"])
-
-                if message:
-                    messages = get_message_list(message_map, message.get("id"))
-
-                    if tasks:
-                        if TASKS.TITLE_GENERATION in tasks:
-                            if tasks[TASKS.TITLE_GENERATION]:
-                                res = await generate_title(
-                                    request,
-                                    {
-                                        "model": message["model"],
-                                        "messages": messages,
-                                        "chat_id": metadata["chat_id"],
-                                    },
-                                    user,
-                                )
-
-                                if res and isinstance(res, dict):
-                                    title = (
-                                        res.get("choices", [])[0]
-                                        .get("message", {})
-                                        .get(
-                                            "content",
-                                            message.get("content", "New Chat"),
-                                        )
-                                    )
-
-                                    Chats.update_chat_title_by_id(
-                                        metadata["chat_id"], title
-                                    )
-
-                                    await event_emitter(
-                                        {
-                                            "type": "chat:title",
-                                            "data": title,
-                                        }
-                                    )
-                            elif len(messages) == 2:
-                                title = messages[0].get("content", "New Chat")
-
-                                Chats.update_chat_title_by_id(
-                                    metadata["chat_id"], title
-                                )
-
-                                await event_emitter(
-                                    {
-                                        "type": "chat:title",
-                                        "data": message.get("content", "New Chat"),
-                                    }
-                                )
-
-                        if (
-                            TASKS.TAGS_GENERATION in tasks
-                            and tasks[TASKS.TAGS_GENERATION]
-                        ):
-                            res = await generate_chat_tags(
-                                request,
-                                {
-                                    "model": message["model"],
-                                    "messages": messages,
-                                    "chat_id": metadata["chat_id"],
-                                },
-                                user,
-                            )
-
-                            if res and isinstance(res, dict):
-                                tags_string = (
-                                    res.get("choices", [])[0]
-                                    .get("message", {})
-                                    .get("content", "")
-                                )
-
-                                tags_string = tags_string[
-                                    tags_string.find("{") : tags_string.rfind("}") + 1
-                                ]
-
-                                try:
-                                    tags = json.loads(tags_string).get("tags", [])
-                                    Chats.update_chat_tags_by_id(
-                                        metadata["chat_id"], tags, user
-                                    )
-
-                                    await event_emitter(
-                                        {
-                                            "type": "chat:tags",
-                                            "data": tags,
-                                        }
-                                    )
-                                except Exception as e:
-                                    print(f"Error: {e}")
-
+                await background_tasks_handler()
             except asyncio.CancelledError:
                 print("Task was cancelled!")
                 await event_emitter({"type": "task-cancelled"})

+ 31 - 27
src/lib/components/chat/Chat.svelte

@@ -1064,37 +1064,41 @@
 		}
 
 		if (choices) {
-			let value = choices[0]?.delta?.content ?? '';
-			if (message.content == '' && value == '\n') {
-				console.log('Empty response');
+			if (choices[0]?.message?.content) {
+				message.content += choices[0]?.message?.content;
 			} else {
-				message.content += value;
+				let value = choices[0]?.delta?.content ?? '';
+				if (message.content == '' && value == '\n') {
+					console.log('Empty response');
+				} else {
+					message.content += value;
 
-				if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
-					navigator.vibrate(5);
-				}
+					if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
+						navigator.vibrate(5);
+					}
 
-				// Emit chat event for TTS
-				const messageContentParts = getMessageContentParts(
-					message.content,
-					$config?.audio?.tts?.split_on ?? 'punctuation'
-				);
-				messageContentParts.pop();
-
-				// dispatch only last sentence and make sure it hasn't been dispatched before
-				if (
-					messageContentParts.length > 0 &&
-					messageContentParts[messageContentParts.length - 1] !== message.lastSentence
-				) {
-					message.lastSentence = messageContentParts[messageContentParts.length - 1];
-					eventTarget.dispatchEvent(
-						new CustomEvent('chat', {
-							detail: {
-								id: message.id,
-								content: messageContentParts[messageContentParts.length - 1]
-							}
-						})
+					// Emit chat event for TTS
+					const messageContentParts = getMessageContentParts(
+						message.content,
+						$config?.audio?.tts?.split_on ?? 'punctuation'
 					);
+					messageContentParts.pop();
+
+					// dispatch only last sentence and make sure it hasn't been dispatched before
+					if (
+						messageContentParts.length > 0 &&
+						messageContentParts[messageContentParts.length - 1] !== message.lastSentence
+					) {
+						message.lastSentence = messageContentParts[messageContentParts.length - 1];
+						eventTarget.dispatchEvent(
+							new CustomEvent('chat', {
+								detail: {
+									id: message.id,
+									content: messageContentParts[messageContentParts.length - 1]
+								}
+							})
+						);
+					}
 				}
 			}
 		}