Selaa lähdekoodia

Support thinking tags used by Openthinker

Bob McElrath 2 kuukautta sitten
vanhempi
commit
c4b441de65
1 muutettua tiedostoa jossa 37 lisäystä ja 16 poistoa
  1. 37 16
      backend/open_webui/utils/middleware.py

+ 37 - 16
backend/open_webui/utils/middleware.py

@@ -1127,12 +1127,12 @@ async def process_chat_response(
 
                         if reasoning_duration is not None:
                             if raw:
-                                content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
+                                content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
                             else:
                                 content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                         else:
                             if raw:
-                                content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
+                                content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
                             else:
                                 content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
 
@@ -1228,9 +1228,9 @@ async def process_chat_response(
                     return attributes
 
                 if content_blocks[-1]["type"] == "text":
-                    for tag in tags:
+                    for start_tag, end_tag in tags:
                         # Match start tag e.g., <tag> or <tag attr="value">
-                        start_tag_pattern = rf"<{tag}(\s.*?)?>"
+                        start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
                         match = re.search(start_tag_pattern, content)
                         if match:
                             attr_content = (
@@ -1263,7 +1263,8 @@ async def process_chat_response(
                             content_blocks.append(
                                 {
                                     "type": content_type,
-                                    "tag": tag,
+                                    "start_tag": start_tag,
+                                    "end_tag": end_tag,
                                     "attributes": attributes,
                                     "content": "",
                                     "started_at": time.time(),
@@ -1275,9 +1276,10 @@ async def process_chat_response(
 
                             break
                 elif content_blocks[-1]["type"] == content_type:
-                    tag = content_blocks[-1]["tag"]
+                    start_tag = content_blocks[-1]["start_tag"]
+                    end_tag = content_blocks[-1]["end_tag"]
                     # Match end tag e.g., </tag>
-                    end_tag_pattern = rf"</{tag}>"
+                    end_tag_pattern = rf"<{re.escape(end_tag)}>"
 
                     # Check if the content has the end tag
                     if re.search(end_tag_pattern, content):
@@ -1285,7 +1287,7 @@ async def process_chat_response(
 
                         block_content = content_blocks[-1]["content"]
                         # Strip start and end tags from the content
-                        start_tag_pattern = rf"<{tag}(.*?)>"
+                        start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
                         block_content = re.sub(
                             start_tag_pattern, "", block_content
                         ).strip()
@@ -1350,7 +1352,7 @@ async def process_chat_response(
 
                         # Clean processed content
                         content = re.sub(
-                            rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
+                            rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
                             "",
                             content,
                             flags=re.DOTALL,
@@ -1388,19 +1390,28 @@ async def process_chat_response(
 
             # We might want to disable this by default
             DETECT_REASONING = True
+            DETECT_SOLUTION = True
             DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
                 "code_interpreter", False
             )
 
             reasoning_tags = [
-                "think",
-                "thinking",
-                "reason",
-                "reasoning",
-                "thought",
-                "Thought",
+                ("think", "/think"),
+                ("thinking", "/thinking"),
+                ("reason", "/reason"),
+                ("reasoning", "/reasoning"),
+                ("thought", "/thought"),
+                ("Thought", "/Thought"),
+                ("|begin_of_thought|", "|end_of_thought|")
+            ]
+
+            code_interpreter_tags = [
+                ("code_interpreter", "/code_interpreter")
+            ]
+
+            solution_tags = [
+                ("|begin_of_solution|", "|end_of_solution|")
             ]
-            code_interpreter_tags = ["code_interpreter"]
 
             try:
                 for event in events:
@@ -1533,6 +1544,16 @@ async def process_chat_response(
                                         if end:
                                             break
 
+                                    if DETECT_SOLUTION:
+                                        content, content_blocks, _ = (
+                                            tag_content_handler(
+                                                "solution",
+                                                solution_tags,
+                                                content,
+                                                content_blocks,
+                                            )
+                                        )
+
                                     if ENABLE_REALTIME_CHAT_SAVE:
                                         # Save message in the database
                                         Chats.upsert_message_to_chat_by_id_and_message_id(