Browse Source

refac: reasoning tag

Timothy Jaeryang Baek 3 months ago
parent
commit
eb1ede119e
1 changed files with 154 additions and 83 deletions
  1. 154 83
      backend/open_webui/utils/middleware.py

+ 154 - 83
backend/open_webui/utils/middleware.py

@@ -8,6 +8,8 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import re
+
 from uuid import uuid4
 from concurrent.futures import ThreadPoolExecutor
 
@@ -987,6 +989,7 @@ async def process_chat_response(
                             pass
 
     event_emitter = None
+    event_caller = None
     if (
         "session_id" in metadata
         and metadata["session_id"]
@@ -996,10 +999,11 @@ async def process_chat_response(
         and metadata["message_id"]
     ):
         event_emitter = get_event_emitter(metadata)
+        event_caller = get_event_call(metadata)
 
+    # Non-streaming response
     if not isinstance(response, StreamingResponse):
         if event_emitter:
-
             if "selected_model_id" in response:
                 Chats.upsert_message_to_chat_by_id_and_message_id(
                     metadata["chat_id"],
@@ -1064,22 +1068,136 @@ async def process_chat_response(
         else:
             return response
 
+    # Non standard response
     if not any(
         content_type in response.headers["Content-Type"]
         for content_type in ["text/event-stream", "application/x-ndjson"]
     ):
         return response
 
-    if event_emitter:
-
+    # Streaming response
+    if event_emitter and event_caller:
         task_id = str(uuid4())  # Create a unique task ID.
 
         # Handle as a background task
         async def post_response_handler(response, events):
+            def serialize_content_blocks(content_blocks):
+                content = ""
+
+                for block in content_blocks:
+                    if block["type"] == "text":
+                        content = f"{content}{block['content'].strip()}\n"
+                    elif block["type"] == "reasoning":
+                        reasoning_display_content = "\n".join(
+                            (f"> {line}" if not line.startswith(">") else line)
+                            for line in block["content"].splitlines()
+                        )
+
+                        reasoning_duration = block.get("duration", None)
+
+                        if reasoning_duration:
+                            content = f'{content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
+                        else:
+                            content = f'{content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
+
+                    else:
+                        content = f"{content}{block['type']}: {block['content']}\n"
+
+                return content
+
+            def tag_content_handler(content_type, tags, content, content_blocks):
+                def extract_attributes(tag_content):
+                    """Extract attributes from a tag if they exist."""
+                    attributes = {}
+                    # Match attributes in the format: key="value" (ignores single quotes for simplicity)
+                    matches = re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content)
+                    for key, value in matches:
+                        attributes[key] = value
+                    return attributes
+
+                if content_blocks[-1]["type"] == "text":
+                    for tag in tags:
+                        # Match start tag e.g., <tag> or <tag attr="value">
+                        start_tag_pattern = rf"<{tag}(.*?)>"
+                        match = re.search(start_tag_pattern, content)
+                        if match:
+                            # Extract attributes in the tag (if present)
+                            attributes = extract_attributes(match.group(1))
+                            # Remove the start tag from the currently handling text block
+                            content_blocks[-1]["content"] = content_blocks[-1][
+                                "content"
+                            ].replace(match.group(0), "")
+                            if not content_blocks[-1]["content"]:
+                                content_blocks.pop()
+                            # Append the new block
+                            content_blocks.append(
+                                {
+                                    "type": content_type,
+                                    "tag": tag,
+                                    "attributes": attributes,
+                                    "content": "",
+                                    "started_at": time.time(),
+                                }
+                            )
+                            break
+                elif content_blocks[-1]["type"] == content_type:
+                    tag = content_blocks[-1]["tag"]
+                    # Match end tag e.g., </tag>
+                    end_tag_pattern = rf"</{tag}>"
+                    if re.search(end_tag_pattern, content):
+                        block_content = content_blocks[-1]["content"]
+                        # Strip start and end tags from the content
+                        start_tag_pattern = rf"<{tag}(.*?)>"
+                        block_content = re.sub(
+                            start_tag_pattern, "", block_content
+                        ).strip()
+                        block_content = re.sub(
+                            end_tag_pattern, "", block_content
+                        ).strip()
+                        if block_content:
+                            content_blocks[-1]["content"] = block_content
+                            content_blocks[-1]["ended_at"] = time.time()
+                            content_blocks[-1]["duration"] = int(
+                                content_blocks[-1]["ended_at"]
+                                - content_blocks[-1]["started_at"]
+                            )
+                            # Reset the content_blocks by appending a new text block
+                            content_blocks.append(
+                                {
+                                    "type": "text",
+                                    "content": "",
+                                }
+                            )
+                            # Clean processed content
+                            content = re.sub(
+                                rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
+                                "",
+                                content,
+                                flags=re.DOTALL,
+                            )
+                        else:
+                            # Remove the block if content is empty
+                            content_blocks.pop()
+                return content, content_blocks
+
             message = Chats.get_message_by_id_and_message_id(
                 metadata["chat_id"], metadata["message_id"]
             )
+
             content = message.get("content", "") if message else ""
+            content_blocks = [
+                {
+                    "type": "text",
+                    "content": content,
+                }
+            ]
+
+            # We might want to disable this by default
+            DETECT_REASONING = True
+            DETECT_CODE_INTERPRETER = True
+
+            reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
+            code_interpreter_tags = ["oi::code_interpreter"]
 
             try:
                 for event in events:
@@ -1099,16 +1217,6 @@ async def process_chat_response(
                         },
                     )
 
-                # We might want to disable this by default
-                detect_reasoning = True
-                reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
-                current_tag = None
-
-                reasoning_start_time = None
-
-                reasoning_content = ""
-                ongoing_content = ""
-
                 async for line in response.body_iterator:
                     line = line.decode("utf-8") if isinstance(line, bytes) else line
                     data = line
@@ -1144,73 +1252,28 @@ async def process_chat_response(
 
                             if value:
                                 content = f"{content}{value}"
+                                content_blocks[-1]["content"] = (
+                                    content_blocks[-1]["content"] + value
+                                )
+
+                                print(f"Content: {content}")
+                                print(f"Content Blocks: {content_blocks}")
+
+                                if DETECT_REASONING:
+                                    content, content_blocks = tag_content_handler(
+                                        "reasoning",
+                                        reasoning_tags,
+                                        content,
+                                        content_blocks,
+                                    )
 
-                                if detect_reasoning:
-                                    for tag in reasoning_tags:
-                                        start_tag = f"<{tag}>\n"
-                                        end_tag = f"</{tag}>\n"
-
-                                        if start_tag in content:
-                                            # Remove the start tag
-                                            content = content.replace(start_tag, "")
-                                            ongoing_content = content
-
-                                            reasoning_start_time = time.time()
-                                            reasoning_content = ""
-
-                                            current_tag = tag
-                                            break
-
-                                    if reasoning_start_time is not None:
-                                        # Remove the last value from the content
-                                        content = content[: -len(value)]
-
-                                        reasoning_content += value
-
-                                        end_tag = f"</{current_tag}>\n"
-                                        if end_tag in reasoning_content:
-                                            reasoning_end_time = time.time()
-                                            reasoning_duration = int(
-                                                reasoning_end_time
-                                                - reasoning_start_time
-                                            )
-                                            reasoning_content = (
-                                                reasoning_content.strip(
-                                                    f"<{current_tag}>\n"
-                                                )
-                                                .strip(end_tag)
-                                                .strip()
-                                            )
-
-                                            if reasoning_content:
-                                                reasoning_display_content = "\n".join(
-                                                    (
-                                                        f"> {line}"
-                                                        if not line.startswith(">")
-                                                        else line
-                                                    )
-                                                    for line in reasoning_content.splitlines()
-                                                )
-
-                                                # Format reasoning with <details> tag
-                                                content = f'{ongoing_content}<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
-                                            else:
-                                                content = ""
-
-                                            reasoning_start_time = None
-                                        else:
-
-                                            reasoning_display_content = "\n".join(
-                                                (
-                                                    f"> {line}"
-                                                    if not line.startswith(">")
-                                                    else line
-                                                )
-                                                for line in reasoning_content.splitlines()
-                                            )
-
-                                            # Show ongoing thought process
-                                            content = f'{ongoing_content}<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
+                                if DETECT_CODE_INTERPRETER:
+                                    content, content_blocks = tag_content_handler(
+                                        "code_interpreter",
+                                        code_interpreter_tags,
+                                        content,
+                                        content_blocks,
+                                    )
 
                                 if ENABLE_REALTIME_CHAT_SAVE:
                                     # Save message in the database
@@ -1218,12 +1281,16 @@ async def process_chat_response(
                                         metadata["chat_id"],
                                         metadata["message_id"],
                                         {
-                                            "content": content,
+                                            "content": serialize_content_blocks(
+                                                content_blocks
+                                            ),
                                         },
                                     )
                                 else:
                                     data = {
-                                        "content": content,
+                                        "content": serialize_content_blocks(
+                                            content_blocks
+                                        ),
                                     }
 
                         await event_emitter(
@@ -1240,7 +1307,11 @@ async def process_chat_response(
                             continue
 
                 title = Chats.get_chat_title_by_id(metadata["chat_id"])
-                data = {"done": True, "content": content, "title": title}
+                data = {
+                    "done": True,
+                    "content": serialize_content_blocks(content_blocks),
+                    "title": title,
+                }
 
                 if not ENABLE_REALTIME_CHAT_SAVE:
                     # Save message in the database
@@ -1248,7 +1319,7 @@ async def process_chat_response(
                         metadata["chat_id"],
                         metadata["message_id"],
                         {
-                            "content": content,
+                            "content": serialize_content_blocks(content_blocks),
                         },
                     )