Browse Source

refac: chat requests

Timothy Jaeryang Baek 4 tháng trước cách đây
mục cha
commit
2be9e55545

+ 23 - 5
backend/open_webui/main.py

@@ -30,7 +30,9 @@ from fastapi import (
     UploadFile,
     status,
     applications,
+    BackgroundTasks,
 )
+
 from fastapi.openapi.docs import get_swagger_ui_html
 
 from fastapi.middleware.cors import CORSMiddleware
@@ -295,6 +297,7 @@ from open_webui.utils.auth import (
 from open_webui.utils.oauth import oauth_manager
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 
+from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
 
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
@@ -822,11 +825,11 @@ async def chat_completion(
     request: Request,
     form_data: dict,
     user=Depends(get_verified_user),
-    bypass_filter: bool = False,
 ):
     if not request.app.state.MODELS:
         await get_all_models(request)
 
+    tasks = form_data.pop("background_tasks", None)
     try:
         model_id = form_data.get("model", None)
         if model_id not in request.app.state.MODELS:
@@ -834,13 +837,14 @@ async def chat_completion(
         model = request.app.state.MODELS[model_id]
 
         # Check if user has access to the model
-        if not bypass_filter and user.role == "user":
+        if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
             try:
                 check_model_access(user, model)
             except Exception as e:
                 raise e
 
         metadata = {
+            "user_id": user.id,
             "chat_id": form_data.pop("chat_id", None),
             "message_id": form_data.pop("id", None),
             "session_id": form_data.pop("session_id", None),
@@ -859,10 +863,10 @@ async def chat_completion(
         )
 
     try:
-        response = await chat_completion_handler(
-            request, form_data, user, bypass_filter
+        response = await chat_completion_handler(request, form_data, user)
+        return await process_chat_response(
+            request, response, user, events, metadata, tasks
         )
-        return await process_chat_response(response, events, metadata)
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -901,6 +905,20 @@ async def chat_action(
         )
 
 
+@app.post("/api/tasks/stop/{task_id}")
+async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)):
+    try:
+        result = await stop_task(task_id)  # Use the function from tasks.py
+        return result
+    except ValueError as e:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
+
+
+@app.get("/api/tasks")
+async def list_tasks_endpoint(user=Depends(get_verified_user)):
+    return {"tasks": list_tasks()}  # Use the function from tasks.py
+
+
 ##################################
 #
 # Config Endpoints

+ 60 - 0
backend/open_webui/models/chats.py

@@ -168,6 +168,66 @@ class ChatTable:
         except Exception:
             return None
 
+    def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        chat["title"] = title
+
+        return self.update_chat_by_id(id, chat)
+
+    def update_chat_tags_by_id(
+        self, id: str, tags: list[str], user
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        self.delete_all_tags_by_id_and_user_id(id, user.id)
+
+        for tag in chat.meta.get("tags", []):
+            if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
+        for tag_name in tags:
+            if tag_name.lower() == "none":
+                continue
+
+            self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
+        return self.get_chat_by_id(id)
+
+    def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        return chat.chat.get("history", {}).get("messages", {}) or {}
+
+    def upsert_message_to_chat_by_id_and_message_id(
+        self, id: str, message_id: str, message: dict
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        history = chat.get("history", {})
+
+        if message_id in history.get("messages", {}):
+            history["messages"][message_id] = {
+                **history["messages"][message_id],
+                **message,
+            }
+        else:
+            history["messages"][message_id] = message
+
+        history["currentId"] = message_id
+
+        chat["history"] = history
+        return self.update_chat_by_id(id, chat)
+
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_db() as db:
             # Get the existing chat to share

+ 10 - 8
backend/open_webui/routers/ollama.py

@@ -82,6 +82,16 @@ async def send_get_request(url, key=None):
         return None
 
 
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
 async def send_post_request(
     url: str,
     payload: Union[str, bytes],
@@ -89,14 +99,6 @@ async def send_post_request(
     key: Optional[str] = None,
     content_type: Optional[str] = None,
 ):
-    async def cleanup_response(
-        response: Optional[aiohttp.ClientResponse],
-        session: Optional[aiohttp.ClientSession],
-    ):
-        if response:
-            response.close()
-        if session:
-            await session.close()
 
     r = None
     try:

+ 13 - 9
backend/open_webui/socket/main.py

@@ -217,15 +217,19 @@ async def disconnect(sid):
 
 def get_event_emitter(request_info):
     async def __event_emitter__(event_data):
-        await sio.emit(
-            "chat-events",
-            {
-                "chat_id": request_info["chat_id"],
-                "message_id": request_info["message_id"],
-                "data": event_data,
-            },
-            to=request_info["session_id"],
-        )
+        user_id = request_info["user_id"]
+        session_ids = USER_POOL.get(user_id, [])
+
+        for session_id in session_ids:
+            await sio.emit(
+                "chat-events",
+                {
+                    "chat_id": request_info["chat_id"],
+                    "message_id": request_info["message_id"],
+                    "data": event_data,
+                },
+                to=session_id,
+            )
 
     return __event_emitter__
 

+ 61 - 0
backend/open_webui/tasks.py

@@ -0,0 +1,61 @@
+# tasks.py
+import asyncio
+from typing import Dict
+from uuid import uuid4
+
+# A dictionary to keep track of active tasks
+tasks: Dict[str, asyncio.Task] = {}
+
+
+def cleanup_task(task_id: str):
+    """
+    Remove a completed or canceled task from the global `tasks` dictionary.
+    """
+    tasks.pop(task_id, None)  # Remove the task if it exists
+
+
+def create_task(coroutine):
+    """
+    Create a new asyncio task and add it to the global task dictionary.
+    """
+    task_id = str(uuid4())  # Generate a unique ID for the task
+    task = asyncio.create_task(coroutine)  # Create the task
+
+    # Add a done callback for cleanup
+    task.add_done_callback(lambda t: cleanup_task(task_id))
+
+    tasks[task_id] = task
+    return task_id, task
+
+
+def get_task(task_id: str):
+    """
+    Retrieve a task by its task ID.
+    """
+    return tasks.get(task_id)
+
+
+def list_tasks():
+    """
+    List all currently active task IDs.
+    """
+    return list(tasks.keys())
+
+
+async def stop_task(task_id: str):
+    """
+    Cancel a running task and remove it from the global task list.
+    """
+    task = tasks.get(task_id)
+    if not task:
+        raise ValueError(f"Task with ID {task_id} not found.")
+
+    task.cancel()  # Request task cancellation
+    try:
+        await task  # Wait for the task to handle the cancellation
+    except asyncio.CancelledError:
+        # Task successfully canceled
+        tasks.pop(task_id, None)  # Remove it from the dictionary
+        return {"status": True, "message": f"Task {task_id} successfully stopped."}
+
+    return {"status": False, "message": f"Failed to stop task {task_id}."}

+ 4 - 1
backend/open_webui/utils/chat.py

@@ -117,7 +117,9 @@ async def generate_chat_completion(
                 form_data, user, bypass_filter=True
             )
             return StreamingResponse(
-                stream_wrapper(response.body_iterator), media_type="text/event-stream"
+                stream_wrapper(response.body_iterator),
+                media_type="text/event-stream",
+                background=response.background,
             )
         else:
             return {
@@ -141,6 +143,7 @@ async def generate_chat_completion(
             return StreamingResponse(
                 convert_streaming_response_ollama_to_openai(response),
                 headers=dict(response.headers),
+                background=response.background,
             )
         else:
             return convert_response_ollama_to_openai(response)

+ 181 - 18
backend/open_webui/utils/middleware.py

@@ -2,21 +2,31 @@ import time
 import logging
 import sys
 
+import asyncio
 from aiocache import cached
 from typing import Any, Optional
 import random
 import json
 import inspect
+from uuid import uuid4
+
 
 from fastapi import Request
+from fastapi import BackgroundTasks
+
 from starlette.responses import Response, StreamingResponse
 
 
+from open_webui.models.chats import Chats
 from open_webui.socket.main import (
     get_event_call,
     get_event_emitter,
 )
-from open_webui.routers.tasks import generate_queries
+from open_webui.routers.tasks import (
+    generate_queries,
+    generate_title,
+    generate_chat_tags,
+)
 
 
 from open_webui.models.users import UserModel
@@ -33,6 +43,7 @@ from open_webui.utils.task import (
     tools_function_calling_generation_template,
 )
 from open_webui.utils.misc import (
+    get_message_list,
     add_or_update_system_message,
     get_last_user_message,
     prepend_to_first_user_message_content,
@@ -41,6 +52,8 @@ from open_webui.utils.tools import get_tools
 from open_webui.utils.plugin import load_function_module_by_id
 
 
+from open_webui.tasks import create_task
+
 from open_webui.config import DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
 from open_webui.constants import TASKS
@@ -504,28 +517,178 @@ async def process_chat_payload(request, form_data, metadata, user, model):
     return form_data, events
 
 
-async def process_chat_response(response, events, metadata):
+async def process_chat_response(request, response, user, events, metadata, tasks):
     if not isinstance(response, StreamingResponse):
         return response
 
-    content_type = response.headers["Content-Type"]
-    is_openai = "text/event-stream" in content_type
-    is_ollama = "application/x-ndjson" in content_type
-
-    if not is_openai and not is_ollama:
+    if not any(
+        content_type in response.headers["Content-Type"]
+        for content_type in ["text/event-stream", "application/x-ndjson"]
+    ):
         return response
 
-    async def stream_wrapper(original_generator, events):
-        def wrap_item(item):
-            return f"data: {item}\n\n" if is_openai else f"{item}\n"
+    event_emitter = None
+    if "session_id" in metadata:
+        event_emitter = get_event_emitter(metadata)
 
-        for event in events:
-            yield wrap_item(json.dumps(event))
+    if event_emitter:
 
-        async for data in original_generator:
-            yield data
+        task_id = str(uuid4())  # Create a unique task ID.
 
-    return StreamingResponse(
-        stream_wrapper(response.body_iterator, events),
-        headers=dict(response.headers),
-    )
+        # Handle as a background task
+        async def post_response_handler(response, events):
+            try:
+                for event in events:
+                    await event_emitter(
+                        {
+                            "type": "chat-completion",
+                            "data": event,
+                        }
+                    )
+
+                content = ""
+                async for line in response.body_iterator:
+                    line = line.decode("utf-8") if isinstance(line, bytes) else line
+                    data = line
+
+                    # Skip empty lines
+                    if not data.strip():
+                        continue
+
+                    # "data: " is the prefix for each event
+                    if not data.startswith("data: "):
+                        continue
+
+                    # Remove the prefix
+                    data = data[len("data: ") :]
+
+                    try:
+                        data = json.loads(data)
+                        value = (
+                            data.get("choices", [])[0].get("delta", {}).get("content")
+                        )
+
+                        if value:
+                            content = f"{content}{value}"
+
+                            # Save message in the database
+                            Chats.upsert_message_to_chat_by_id_and_message_id(
+                                metadata["chat_id"],
+                                metadata["message_id"],
+                                {
+                                    "content": content,
+                                },
+                            )
+
+                    except Exception as e:
+                        done = "data: [DONE]" in line
+
+                        if done:
+                            data = {"done": True}
+                        else:
+                            continue
+
+                    await event_emitter(
+                        {
+                            "type": "chat-completion",
+                            "data": data,
+                        }
+                    )
+
+                message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
+                message = message_map.get(metadata["message_id"])
+
+                if message:
+                    messages = get_message_list(message_map, message.get("id"))
+
+                    if TASKS.TITLE_GENERATION in tasks:
+                        res = await generate_title(
+                            request,
+                            {
+                                "model": message["model"],
+                                "messages": messages,
+                                "chat_id": metadata["chat_id"],
+                            },
+                            user,
+                        )
+
+                        if res:
+                            title = (
+                                res.get("choices", [])[0]
+                                .get("message", {})
+                                .get("content", message.get("content", "New Chat"))
+                            )
+
+                            Chats.update_chat_title_by_id(metadata["chat_id"], title)
+
+                            await event_emitter(
+                                {
+                                    "type": "chat-title",
+                                    "data": title,
+                                }
+                            )
+
+                    if TASKS.TAGS_GENERATION in tasks:
+                        res = await generate_chat_tags(
+                            request,
+                            {
+                                "model": message["model"],
+                                "messages": messages,
+                                "chat_id": metadata["chat_id"],
+                            },
+                            user,
+                        )
+
+                        if res:
+                            tags_string = (
+                                res.get("choices", [])[0]
+                                .get("message", {})
+                                .get("content", "")
+                            )
+
+                            tags_string = tags_string[
+                                tags_string.find("{") : tags_string.rfind("}") + 1
+                            ]
+
+                            try:
+                                tags = json.loads(tags_string).get("tags", [])
+                                Chats.update_chat_tags_by_id(
+                                    metadata["chat_id"], tags, user
+                                )
+
+                                await event_emitter(
+                                    {
+                                        "type": "chat-tags",
+                                        "data": tags,
+                                    }
+                                )
+                            except Exception as e:
+                                print(f"Error: {e}")
+
+            except asyncio.CancelledError:
+                print("Task was cancelled!")
+                await event_emitter({"type": "task-cancelled"})
+
+            if response.background is not None:
+                await response.background()
+
+        # background_tasks.add_task(post_response_handler, response, events)
+        task_id, _ = create_task(post_response_handler(response, events))
+        return {"status": True, "task_id": task_id}
+
+    else:
+        # Fallback to the original response
+        async def stream_wrapper(original_generator, events):
+            def wrap_item(item):
+                return f"data: {item}\n\n"
+
+            for event in events:
+                yield wrap_item(json.dumps(event))
+
+            async for data in original_generator:
+                yield data
+
+        return StreamingResponse(
+            stream_wrapper(response.body_iterator, events),
+            headers=dict(response.headers),
+        )

+ 28 - 0
backend/open_webui/utils/misc.py

@@ -7,6 +7,34 @@ from pathlib import Path
 from typing import Callable, Optional
 
 
+def get_message_list(messages, message_id):
+    """
+    Reconstructs a list of messages in order up to the specified message_id.
+
+    :param message_id: ID of the message to reconstruct the chain
+    :param messages: Message history dict containing all messages
+    :return: List of ordered messages starting from the root to the given message
+    """
+
+    # Find the message by its id
+    current_message = messages.get(message_id)
+
+    if not current_message:
+        return f"Message ID {message_id} not found in the history."
+
+    # Reconstruct the chain by following the parentId links
+    message_list = []
+
+    while current_message:
+        message_list.insert(
+            0, current_message
+        )  # Insert the message at the beginning of the list
+        parent_id = current_message["parentId"]
+        current_message = messages.get(parent_id) if parent_id else None
+
+    return message_list
+
+
 def get_messages_content(messages: list[dict]) -> str:
     return "\n".join(
         [

+ 36 - 0
src/lib/apis/index.ts

@@ -107,6 +107,42 @@ export const chatAction = async (token: string, action_id: string, body: ChatAct
 	return res;
 };
 
+
+
+export const stopTask = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/tasks/stop/${id}`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+
+
 export const getTaskConfig = async (token: string = '') => {
 	let error = null;
 

+ 9 - 8
src/lib/apis/openai/index.ts

@@ -277,29 +277,30 @@ export const generateOpenAIChatCompletion = async (
 	token: string = '',
 	body: object,
 	url: string = OPENAI_API_BASE_URL
-): Promise<[Response | null, AbortController]> => {
-	const controller = new AbortController();
+) => {
 	let error = null;
 
 	const res = await fetch(`${url}/chat/completions`, {
-		signal: controller.signal,
 		method: 'POST',
 		headers: {
 			Authorization: `Bearer ${token}`,
 			'Content-Type': 'application/json'
 		},
 		body: JSON.stringify(body)
-	}).catch((err) => {
-		console.log(err);
-		error = err;
-		return null;
+	}).then(async (res) => {
+		if (!res.ok) throw await res.json();
+		return res.json();
+	})
+	.catch((err) => {
+		error = `OpenAI: ${err?.detail ?? 'Network Problem'}`;
+		return null; 
 	});
 
 	if (error) {
 		throw error;
 	}
 
-	return [res, controller];
+	return res;
 };
 
 export const synthesizeOpenAISpeech = async (

+ 322 - 370
src/lib/components/chat/Chat.svelte

@@ -69,7 +69,8 @@
 		generateQueries,
 		chatAction,
 		generateMoACompletion,
-		generateTags
+		generateTags,
+		stopTask
 	} from '$lib/apis';
 
 	import Banner from '../common/Banner.svelte';
@@ -88,7 +89,6 @@
 	let controlPane;
 	let controlPaneComponent;
 
-	let stopResponseFlag = false;
 	let autoScroll = true;
 	let processing = '';
 	let messagesContainerElement: HTMLDivElement;
@@ -121,6 +121,8 @@
 		currentId: null
 	};
 
+	let taskId = null;
+
 	// Chat Input
 	let prompt = '';
 	let chatFiles = [];
@@ -202,95 +204,107 @@
 	};
 
 	const chatEventHandler = async (event, cb) => {
+		console.log(event);
+
 		if (event.chat_id === $chatId) {
 			await tick();
-			console.log(event);
 			let message = history.messages[event.message_id];
 
-			const type = event?.data?.type ?? null;
-			const data = event?.data?.data ?? null;
+			if (message) {
+				const type = event?.data?.type ?? null;
+				const data = event?.data?.data ?? null;
 
-			if (type === 'status') {
-				if (message?.statusHistory) {
-					message.statusHistory.push(data);
-				} else {
-					message.statusHistory = [data];
-				}
-			} else if (type === 'source' || type === 'citation') {
-				if (data?.type === 'code_execution') {
-					// Code execution; update existing code execution by ID, or add new one.
-					if (!message?.code_executions) {
-						message.code_executions = [];
+				if (type === 'status') {
+					if (message?.statusHistory) {
+						message.statusHistory.push(data);
+					} else {
+						message.statusHistory = [data];
 					}
+				} else if (type === 'source' || type === 'citation') {
+					if (data?.type === 'code_execution') {
+						// Code execution; update existing code execution by ID, or add new one.
+						if (!message?.code_executions) {
+							message.code_executions = [];
+						}
 
-					const existingCodeExecutionIndex = message.code_executions.findIndex(
-						(execution) => execution.id === data.id
-					);
+						const existingCodeExecutionIndex = message.code_executions.findIndex(
+							(execution) => execution.id === data.id
+						);
 
-					if (existingCodeExecutionIndex !== -1) {
-						message.code_executions[existingCodeExecutionIndex] = data;
-					} else {
-						message.code_executions.push(data);
-					}
+						if (existingCodeExecutionIndex !== -1) {
+							message.code_executions[existingCodeExecutionIndex] = data;
+						} else {
+							message.code_executions.push(data);
+						}
 
-					message.code_executions = message.code_executions;
-				} else {
-					// Regular source.
-					if (message?.sources) {
-						message.sources.push(data);
+						message.code_executions = message.code_executions;
 					} else {
-						message.sources = [data];
+						// Regular source.
+						if (message?.sources) {
+							message.sources.push(data);
+						} else {
+							message.sources = [data];
+						}
 					}
-				}
-			} else if (type === 'message') {
-				message.content += data.content;
-			} else if (type === 'replace') {
-				message.content = data.content;
-			} else if (type === 'action') {
-				if (data.action === 'continue') {
-					const continueButton = document.getElementById('continue-response-button');
-
-					if (continueButton) {
-						continueButton.click();
+				} else if (type === 'chat-completion') {
+					chatCompletionEventHandler(data, message, event.chat_id);
+				} else if (type === 'chat-title') {
+					chatTitle.set(data);
+					currentChatPage.set(1);
+					await chats.set(await getChatList(localStorage.token, $currentChatPage));
+				} else if (type === 'chat-tags') {
+					chat = await getChatById(localStorage.token, $chatId);
+					allTags.set(await getAllTags(localStorage.token));
+				} else if (type === 'message') {
+					message.content += data.content;
+				} else if (type === 'replace') {
+					message.content = data.content;
+				} else if (type === 'action') {
+					if (data.action === 'continue') {
+						const continueButton = document.getElementById('continue-response-button');
+
+						if (continueButton) {
+							continueButton.click();
+						}
 					}
-				}
-			} else if (type === 'confirmation') {
-				eventCallback = cb;
+				} else if (type === 'confirmation') {
+					eventCallback = cb;
 
-				eventConfirmationInput = false;
-				showEventConfirmation = true;
+					eventConfirmationInput = false;
+					showEventConfirmation = true;
 
-				eventConfirmationTitle = data.title;
-				eventConfirmationMessage = data.message;
-			} else if (type === 'execute') {
-				eventCallback = cb;
+					eventConfirmationTitle = data.title;
+					eventConfirmationMessage = data.message;
+				} else if (type === 'execute') {
+					eventCallback = cb;
 
-				try {
-					// Use Function constructor to evaluate code in a safer way
-					const asyncFunction = new Function(`return (async () => { ${data.code} })()`);
-					const result = await asyncFunction(); // Await the result of the async function
+					try {
+						// Use Function constructor to evaluate code in a safer way
+						const asyncFunction = new Function(`return (async () => { ${data.code} })()`);
+						const result = await asyncFunction(); // Await the result of the async function
 
-					if (cb) {
-						cb(result);
+						if (cb) {
+							cb(result);
+						}
+					} catch (error) {
+						console.error('Error executing code:', error);
 					}
-				} catch (error) {
-					console.error('Error executing code:', error);
-				}
-			} else if (type === 'input') {
-				eventCallback = cb;
+				} else if (type === 'input') {
+					eventCallback = cb;
 
-				eventConfirmationInput = true;
-				showEventConfirmation = true;
+					eventConfirmationInput = true;
+					showEventConfirmation = true;
 
-				eventConfirmationTitle = data.title;
-				eventConfirmationMessage = data.message;
-				eventConfirmationInputPlaceholder = data.placeholder;
-				eventConfirmationInputValue = data?.value ?? '';
-			} else {
-				console.log('Unknown message type', data);
-			}
+					eventConfirmationTitle = data.title;
+					eventConfirmationMessage = data.message;
+					eventConfirmationInputPlaceholder = data.placeholder;
+					eventConfirmationInputValue = data?.value ?? '';
+				} else {
+					console.log('Unknown message type', data);
+				}
 
-			history.messages[event.message_id] = message;
+				history.messages[event.message_id] = message;
+			}
 		}
 	};
 
@@ -956,6 +970,119 @@
 		}
 	};
 
+	const chatCompletionEventHandler = async (data, message, chatId) => {
+		const { id, done, choices, sources, selectedModelId, error, usage } = data;
+
+		if (error) {
+			await handleOpenAIError(error, message);
+		}
+
+		if (sources) {
+			message.sources = sources;
+			// Only remove status if it was initially set
+			if (model?.info?.meta?.knowledge ?? false) {
+				message.statusHistory = message.statusHistory.filter(
+					(status) => status.action !== 'knowledge_search'
+				);
+			}
+		}
+
+		if (choices) {
+			const value = choices[0]?.delta?.content ?? '';
+			if (message.content == '' && value == '\n') {
+				console.log('Empty response');
+			} else {
+				message.content += value;
+
+				if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
+					navigator.vibrate(5);
+				}
+
+				// Emit chat event for TTS
+				const messageContentParts = getMessageContentParts(
+					message.content,
+					$config?.audio?.tts?.split_on ?? 'punctuation'
+				);
+				messageContentParts.pop();
+				// dispatch only last sentence and make sure it hasn't been dispatched before
+				if (
+					messageContentParts.length > 0 &&
+					messageContentParts[messageContentParts.length - 1] !== message.lastSentence
+				) {
+					message.lastSentence = messageContentParts[messageContentParts.length - 1];
+					eventTarget.dispatchEvent(
+						new CustomEvent('chat', {
+							detail: {
+								id: message.id,
+								content: messageContentParts[messageContentParts.length - 1]
+							}
+						})
+					);
+				}
+			}
+		}
+
+		if (selectedModelId) {
+			message.selectedModelId = selectedModelId;
+			message.arena = true;
+		}
+
+		if (usage) {
+			message.usage = usage;
+		}
+
+		if (done) {
+			message.done = true;
+
+			if ($settings.notificationEnabled && !document.hasFocus()) {
+				new Notification(`${message.model}`, {
+					body: message.content,
+					icon: `${WEBUI_BASE_URL}/static/favicon.png`
+				});
+			}
+
+			if ($settings.responseAutoCopy) {
+				copyToClipboard(message.content);
+			}
+
+			if ($settings.responseAutoPlayback && !$showCallOverlay) {
+				await tick();
+				document.getElementById(`speak-button-${message.id}`)?.click();
+			}
+
+			// Emit chat event for TTS
+			let lastMessageContentPart =
+				getMessageContentParts(message.content, $config?.audio?.tts?.split_on ?? 'punctuation')?.at(
+					-1
+				) ?? '';
+			if (lastMessageContentPart) {
+				eventTarget.dispatchEvent(
+					new CustomEvent('chat', {
+						detail: { id: message.id, content: lastMessageContentPart }
+					})
+				);
+			}
+			eventTarget.dispatchEvent(
+				new CustomEvent('chat:finish', {
+					detail: {
+						id: message.id,
+						content: message.content
+					}
+				})
+			);
+
+			history.messages[message.id] = message;
+			await chatCompletedHandler(chatId, message.model, message.id, createMessagesList(message.id));
+		}
+
+		history.messages[message.id] = message;
+
+		console.log(data);
+		if (autoScroll) {
+			scrollToBottom();
+		}
+	};
+
 	//////////////////////////
 	// Chat functions
 	//////////////////////////
@@ -1061,6 +1188,7 @@
 		chatInput?.focus();
 
 		saveSessionSelectedModels();
+
 		await sendPrompt(userPrompt, userMessageId, { newChat: true });
 	};
 
@@ -1076,6 +1204,8 @@
 			history.messages[history.currentId].role === 'user'
 		) {
 			await initChatHandler();
+		} else {
+			await saveChatHandler($chatId);
 		}
 
 		// If modelId is provided, use it, else use selected model
@@ -1122,6 +1252,9 @@
 		}
 		await tick();
 
+		// Save chat after all messages have been created
+		await saveChatHandler($chatId);
+
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 		await Promise.all(
 			selectedModelIds.map(async (modelId, _modelIdx) => {
@@ -1178,7 +1311,7 @@
 						await getWebSearchResults(model.id, parentId, responseMessageId);
 					}
 
-					await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
+					await sendPromptSocket(model, responseMessageId, _chatId);
 					if (chatEventEmitter) clearInterval(chatEventEmitter);
 				} else {
 					toast.error($i18n.t(`Model {{modelId}} not found`, { modelId }));
@@ -1190,9 +1323,7 @@
 		chats.set(await getChatList(localStorage.token, $currentChatPage));
 	};
 
-	const sendPromptOpenAI = async (model, userPrompt, responseMessageId, _chatId) => {
-		let _response = null;
-
+	const sendPromptSocket = async (model, responseMessageId, _chatId) => {
 		const responseMessage = history.messages[responseMessageId];
 		const userMessage = history.messages[responseMessage.parentId];
 
@@ -1243,7 +1374,6 @@
 		);
 
 		scrollToBottom();
-
 		eventTarget.dispatchEvent(
 			new CustomEvent('chat:start', {
 				detail: {
@@ -1253,278 +1383,133 @@
 		);
 		await tick();
 
-		try {
-			const stream =
-				model?.info?.params?.stream_response ??
-				$settings?.params?.stream_response ??
-				params?.stream_response ??
-				true;
-
-			const [res, controller] = await generateOpenAIChatCompletion(
-				localStorage.token,
-				{
-					stream: stream,
-					model: model.id,
-					messages: [
-						params?.system || $settings.system || (responseMessage?.userContext ?? null)
-							? {
-									role: 'system',
-									content: `${promptTemplate(
-										params?.system ?? $settings?.system ?? '',
-										$user.name,
-										$settings?.userLocation
-											? await getAndUpdateUserLocation(localStorage.token)
-											: undefined
-									)}${
-										(responseMessage?.userContext ?? null)
-											? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
-											: ''
-									}`
-								}
-							: undefined,
-						...createMessagesList(responseMessageId)
-					]
-						.filter((message) => message?.content?.trim())
-						.map((message, idx, arr) => ({
-							role: message.role,
-							...((message.files?.filter((file) => file.type === 'image').length > 0 ?? false) &&
-							message.role === 'user'
-								? {
-										content: [
-											{
-												type: 'text',
-												text: message?.merged?.content ?? message.content
-											},
-											...message.files
-												.filter((file) => file.type === 'image')
-												.map((file) => ({
-													type: 'image_url',
-													image_url: {
-														url: file.url
-													}
-												}))
-										]
-									}
-								: {
-										content: message?.merged?.content ?? message.content
-									})
-						})),
-
-					params: {
-						...$settings?.params,
-						...params,
-
-						format: $settings.requestFormat ?? undefined,
-						keep_alive: $settings.keepAlive ?? undefined,
-						stop:
-							(params?.stop ?? $settings?.params?.stop ?? undefined)
-								? (
-										params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop
-									).map((str) =>
-										decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
-									)
+		const stream =
+			model?.info?.params?.stream_response ??
+			$settings?.params?.stream_response ??
+			params?.stream_response ??
+			true;
+
+		const messages = [
+			params?.system || $settings.system || (responseMessage?.userContext ?? null)
+				? {
+						role: 'system',
+						content: `${promptTemplate(
+							params?.system ?? $settings?.system ?? '',
+							$user.name,
+							$settings?.userLocation
+								? await getAndUpdateUserLocation(localStorage.token)
 								: undefined
-					},
-
-					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
-					files: files.length > 0 ? files : undefined,
-					session_id: $socket?.id,
-					chat_id: $chatId,
-					id: responseMessageId,
-
-					...(stream && (model.info?.meta?.capabilities?.usage ?? false)
-						? {
-								stream_options: {
-									include_usage: true
-								}
-							}
-						: {})
-				},
-				`${WEBUI_BASE_URL}/api`
-			);
-
-			// Wait until history/message have been updated
-			await tick();
-
-			scrollToBottom();
-
-			if (res && res.ok && res.body) {
-				if (!stream) {
-					const response = await res.json();
-					console.log(response);
-
-					responseMessage.content = response.choices[0].message.content;
-					responseMessage.info = { ...response.usage, openai: true };
-					responseMessage.done = true;
-				} else {
-					const textStream = await createOpenAITextStream(res.body, $settings.splitLargeChunks);
-
-					for await (const update of textStream) {
-						const { value, done, sources, selectedModelId, error, usage } = update;
-						if (error) {
-							await handleOpenAIError(error, null, model, responseMessage);
-							break;
-						}
-
-						if (done || stopResponseFlag || _chatId !== $chatId) {
-							responseMessage.done = true;
-							history.messages[responseMessageId] = responseMessage;
-
-							if (stopResponseFlag) {
-								controller.abort('User: Stop Response');
-							}
-							_response = responseMessage.content;
-							break;
+						)}${
+							(responseMessage?.userContext ?? null)
+								? `\n\nUser Context:\n${responseMessage?.userContext ?? ''}`
+								: ''
+						}`
+					}
+				: undefined,
+			...createMessagesList(responseMessageId)
+		]
+			.filter((message) => message?.content?.trim())
+			.map((message, idx, arr) => ({
+				role: message.role,
+				...((message.files?.filter((file) => file.type === 'image').length > 0 ?? false) &&
+				message.role === 'user'
+					? {
+							content: [
+								{
+									type: 'text',
+									text: message?.merged?.content ?? message.content
+								},
+								...message.files
+									.filter((file) => file.type === 'image')
+									.map((file) => ({
+										type: 'image_url',
+										image_url: {
+											url: file.url
+										}
+									}))
+							]
 						}
+					: {
+							content: message?.merged?.content ?? message.content
+						})
+			}));
 
-						if (usage) {
-							responseMessage.usage = usage;
-						}
+		const res = await generateOpenAIChatCompletion(
+			localStorage.token,
+			{
+				stream: stream,
+				model: model.id,
+				messages: messages,
+				params: {
+					...$settings?.params,
+					...params,
+
+					format: $settings.requestFormat ?? undefined,
+					keep_alive: $settings.keepAlive ?? undefined,
+					stop:
+						(params?.stop ?? $settings?.params?.stop ?? undefined)
+							? (params?.stop.split(',').map((token) => token.trim()) ?? $settings.params.stop).map(
+									(str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"'))
+								)
+							: undefined
+				},
 
-						if (selectedModelId) {
-							responseMessage.selectedModelId = selectedModelId;
-							responseMessage.arena = true;
-							continue;
-						}
+				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
+				files: files.length > 0 ? files : undefined,
+				session_id: $socket?.id,
+				chat_id: $chatId,
+				id: responseMessageId,
 
-						if (sources) {
-							responseMessage.sources = sources;
-							// Only remove status if it was initially set
-							if (model?.info?.meta?.knowledge ?? false) {
-								responseMessage.statusHistory = responseMessage.statusHistory.filter(
-									(status) => status.action !== 'knowledge_search'
-								);
+				...(!$temporaryChatEnabled && messages.length == 1 && selectedModels[0] === model.id
+					? {
+							background_tasks: {
+								title_generation: $settings?.title?.auto ?? true,
+								tags_generation: $settings?.autoTags ?? true
 							}
-							continue;
 						}
+					: {}),
 
-						if (responseMessage.content == '' && value == '\n') {
-							continue;
-						} else {
-							responseMessage.content += value;
-
-							if (navigator.vibrate && ($settings?.hapticFeedback ?? false)) {
-								navigator.vibrate(5);
-							}
-
-							const messageContentParts = getMessageContentParts(
-								responseMessage.content,
-								$config?.audio?.tts?.split_on ?? 'punctuation'
-							);
-							messageContentParts.pop();
-
-							// dispatch only last sentence and make sure it hasn't been dispatched before
-							if (
-								messageContentParts.length > 0 &&
-								messageContentParts[messageContentParts.length - 1] !== responseMessage.lastSentence
-							) {
-								responseMessage.lastSentence = messageContentParts[messageContentParts.length - 1];
-								eventTarget.dispatchEvent(
-									new CustomEvent('chat', {
-										detail: {
-											id: responseMessageId,
-											content: messageContentParts[messageContentParts.length - 1]
-										}
-									})
-								);
+				...(stream && (model.info?.meta?.capabilities?.usage ?? false)
+					? {
+							stream_options: {
+								include_usage: true
 							}
-
-							history.messages[responseMessageId] = responseMessage;
 						}
+					: {})
+			},
+			`${WEBUI_BASE_URL}/api`
+		).catch((error) => {
+			responseMessage.error = {
+				content: error
+			};
+			responseMessage.done = true;
+			return null;
+		});
 
-						if (autoScroll) {
-							scrollToBottom();
-						}
-					}
-				}
-
-				if ($settings.notificationEnabled && !document.hasFocus()) {
-					const notification = new Notification(`${model.id}`, {
-						body: responseMessage.content,
-						icon: `${WEBUI_BASE_URL}/static/favicon.png`
-					});
-				}
-
-				if ($settings.responseAutoCopy) {
-					copyToClipboard(responseMessage.content);
-				}
-
-				if ($settings.responseAutoPlayback && !$showCallOverlay) {
-					await tick();
+		console.log(res);
 
-					document.getElementById(`speak-button-${responseMessage.id}`)?.click();
-				}
-			} else {
-				await handleOpenAIError(null, res, model, responseMessage);
-			}
-		} catch (error) {
-			await handleOpenAIError(error, null, model, responseMessage);
+		if (res) {
+			taskId = res.task_id;
 		}
 
-		await saveChatHandler(_chatId);
-
-		history.messages[responseMessageId] = responseMessage;
-
-		await chatCompletedHandler(
-			_chatId,
-			model.id,
-			responseMessageId,
-			createMessagesList(responseMessageId)
-		);
-
-		stopResponseFlag = false;
+		// Wait until history/message have been updated
 		await tick();
+		scrollToBottom();
 
-		let lastMessageContentPart =
-			getMessageContentParts(
-				responseMessage.content,
-				$config?.audio?.tts?.split_on ?? 'punctuation'
-			)?.at(-1) ?? '';
-		if (lastMessageContentPart) {
-			eventTarget.dispatchEvent(
-				new CustomEvent('chat', {
-					detail: { id: responseMessageId, content: lastMessageContentPart }
-				})
-			);
-		}
-
-		eventTarget.dispatchEvent(
-			new CustomEvent('chat:finish', {
-				detail: {
-					id: responseMessageId,
-					content: responseMessage.content
-				}
-			})
-		);
-
-		if (autoScroll) {
-			scrollToBottom();
-		}
-
-		const messages = createMessagesList(responseMessageId);
-		if (messages.length == 2 && selectedModels[0] === model.id) {
-			window.history.replaceState(history.state, '', `/c/${_chatId}`);
-
-			const title = await generateChatTitle(messages);
-			await setChatTitle(_chatId, title);
-
-			if ($settings?.autoTags ?? true) {
-				await setChatTags(messages);
-			}
-		}
-
-		return _response;
+		// 	if ($settings?.autoTags ?? true) {
+		// 		await setChatTags(messages);
+		// 	}
+		// }
 	};
 
-	const handleOpenAIError = async (error, res: Response | null, model, responseMessage) => {
+	const handleOpenAIError = async (error, responseMessage) => {
 		let errorMessage = '';
 		let innerError;
 
 		if (error) {
 			innerError = error;
-		} else if (res !== null) {
-			innerError = await res.json();
 		}
+
 		console.error(innerError);
 		if ('detail' in innerError) {
 			toast.error(innerError.detail);
@@ -1543,12 +1528,7 @@
 		}
 
 		responseMessage.error = {
-			content:
-				$i18n.t(`Uh-oh! There was an issue connecting to {{provider}}.`, {
-					provider: model.name ?? model.id
-				}) +
-				'\n' +
-				errorMessage
+			content: $i18n.t(`Uh-oh! There was an issue with the response.`) + '\n' + errorMessage
 		};
 		responseMessage.done = true;
 
@@ -1562,8 +1542,15 @@
 	};
 
 	const stopResponse = () => {
-		stopResponseFlag = true;
-		console.log('stopResponse');
+		if (taskId) {
+			const res = stopTask(localStorage.token, taskId).catch((error) => {
+				return null;
+			});
+
+			if (res) {
+				taskId = null;
+			}
+		}
 	};
 
 	const submitMessage = async (parentId, prompt) => {
@@ -1628,12 +1615,7 @@
 				.at(0);
 
 			if (model) {
-				await sendPromptOpenAI(
-					model,
-					history.messages[responseMessage.parentId].content,
-					responseMessage.id,
-					_chatId
-				);
+				await sendPromptSocket(model, responseMessage.id, _chatId);
 			}
 		}
 	};
@@ -1685,38 +1667,6 @@
 		}
 	};
 
-	const generateChatTitle = async (messages) => {
-		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
-
-		if ($settings?.title?.auto ?? true) {
-			const modelId = selectedModels[0];
-
-			const title = await generateTitle(localStorage.token, modelId, messages, $chatId).catch(
-				(error) => {
-					console.error(error);
-					return lastUserMessage?.content ?? 'New Chat';
-				}
-			);
-
-			return title ? title : (lastUserMessage?.content ?? 'New Chat');
-		} else {
-			return lastUserMessage?.content ?? 'New Chat';
-		}
-	};
-
-	const setChatTitle = async (_chatId, title) => {
-		if (_chatId === $chatId) {
-			chatTitle.set(title);
-		}
-
-		if (!$temporaryChatEnabled) {
-			chat = await updateChatById(localStorage.token, _chatId, { title: title });
-
-			currentChatPage.set(1);
-			await chats.set(await getChatList(localStorage.token, $currentChatPage));
-		}
-	};
-
 	const setChatTags = async (messages) => {
 		if (!$temporaryChatEnabled) {
 			const currentTags = await getTagsById(localStorage.token, $chatId);
@@ -1856,6 +1806,8 @@
 			currentChatPage.set(1);
 			await chats.set(await getChatList(localStorage.token, $currentChatPage));
 			await chatId.set(chat.id);
+
+			window.history.replaceState(history.state, '', `/c/${chat.id}`);
 		} else {
 			await chatId.set('local');
 		}