Browse Source

fix: tag support

Timothy Jaeryang Baek 2 months ago
parent
commit
90cd6f272f
1 changed files with 49 additions and 20 deletions
  1. 49 20
      backend/open_webui/utils/middleware.py

+ 49 - 20
backend/open_webui/utils/middleware.py

@@ -1231,18 +1231,10 @@ async def process_chat_response(
                             content_blocks[-1]["content"] = content_blocks[-1][
                                 "content"
                             ].replace(match.group(0), "")
+
                             if not content_blocks[-1]["content"]:
                                 content_blocks.pop()
 
-                                if not content_blocks:
-                                    # Append the new block
-                                    content_blocks.append(
-                                        {
-                                            "type": "text",
-                                            "content": "",
-                                        }
-                                    )
-
                             # Append the new block
                             content_blocks.append(
                                 {
@@ -1258,6 +1250,7 @@ async def process_chat_response(
                     tag = content_blocks[-1]["tag"]
                     # Match end tag e.g., </tag>
                     end_tag_pattern = rf"</{tag}>"
+
                     if re.search(end_tag_pattern, content):
                         block_content = content_blocks[-1]["content"]
                         # Strip start and end tags from the content
@@ -1265,9 +1258,23 @@ async def process_chat_response(
                         block_content = re.sub(
                             start_tag_pattern, "", block_content
                         ).strip()
-                        block_content = re.sub(
-                            end_tag_pattern, "", block_content
-                        ).strip()
+
+                        end_tag_regex = re.compile(end_tag_pattern, re.DOTALL)
+                        split_content = end_tag_regex.split(block_content, maxsplit=1)
+
+                        # Content inside the tag
+                        block_content = (
+                            split_content[0].strip() if split_content else ""
+                        )
+
+                        # Leftover content (everything after `</tag>`)
+                        leftover_content = (
+                            split_content[1].strip() if len(split_content) > 1 else ""
+                        )
+
+                        print(f"block_content: {block_content}")
+                        print(f"leftover_content: {leftover_content}")
+
                         if block_content:
                             end_flag = True
                             content_blocks[-1]["content"] = block_content
@@ -1280,19 +1287,31 @@ async def process_chat_response(
                             content_blocks.append(
                                 {
                                     "type": "text",
-                                    "content": "",
+                                    "content": leftover_content,
                                 }
                             )
-                            # Clean processed content
-                            content = re.sub(
-                                rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
-                                "",
-                                content,
-                                flags=re.DOTALL,
-                            )
+
                         else:
+                            end_flag = True
                             # Remove the block if content is empty
                             content_blocks.pop()
+
+                            if leftover_content:
+                                content_blocks.append(
+                                    {
+                                        "type": "text",
+                                        "content": leftover_content,
+                                    }
+                                )
+
+                        # Clean processed content
+                        content = re.sub(
+                            rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
+                            "",
+                            content,
+                            flags=re.DOTALL,
+                        )
+
                 return content, content_blocks, end_flag
 
             message = Chats.get_message_by_id_and_message_id(
@@ -1358,6 +1377,7 @@ async def process_chat_response(
 
                         try:
                             data = json.loads(data)
+                            print(data)
 
                             if "selected_model_id" in data:
                                 model_id = data["selected_model_id"]
@@ -1412,6 +1432,15 @@ async def process_chat_response(
 
                                 if value:
                                     content = f"{content}{value}"
+
+                                    if not content_blocks:
+                                        content_blocks.append(
+                                            {
+                                                "type": "text",
+                                                "content": "",
+                                            }
+                                        )
+
                                     content_blocks[-1]["content"] = (
                                         content_blocks[-1]["content"] + value
                                     )