Timothy Jaeryang Baek 4 months ago
parent
commit
4c989808d6
1 changed files with 62 additions and 53 deletions
  1. 62 53
      backend/open_webui/utils/middleware.py

+ 62 - 53
backend/open_webui/utils/middleware.py

@@ -601,69 +601,78 @@ async def process_chat_response(request, response, user, events, metadata, tasks
                 if message:
                 if message:
                     messages = get_message_list(message_map, message.get("id"))
                     messages = get_message_list(message_map, message.get("id"))
 
 
-                    if TASKS.TITLE_GENERATION in tasks:
-                        res = await generate_title(
-                            request,
-                            {
-                                "model": message["model"],
-                                "messages": messages,
-                                "chat_id": metadata["chat_id"],
-                            },
-                            user,
-                        )
-
-                        if res:
-                            title = (
-                                res.get("choices", [])[0]
-                                .get("message", {})
-                                .get("content", message.get("content", "New Chat"))
-                            )
-
-                            Chats.update_chat_title_by_id(metadata["chat_id"], title)
-
-                            await event_emitter(
+                    if tasks:
+                        if (
+                            TASKS.TITLE_GENERATION in tasks
+                            and tasks[TASKS.TITLE_GENERATION]
+                        ):
+                            res = await generate_title(
+                                request,
                                 {
                                 {
-                                    "type": "chat:title",
-                                    "data": title,
-                                }
-                            )
-
-                    if TASKS.TAGS_GENERATION in tasks:
-                        res = await generate_chat_tags(
-                            request,
-                            {
-                                "model": message["model"],
-                                "messages": messages,
-                                "chat_id": metadata["chat_id"],
-                            },
-                            user,
-                        )
-
-                        if res:
-                            tags_string = (
-                                res.get("choices", [])[0]
-                                .get("message", {})
-                                .get("content", "")
+                                    "model": message["model"],
+                                    "messages": messages,
+                                    "chat_id": metadata["chat_id"],
+                                },
+                                user,
                             )
                             )
 
 
-                            tags_string = tags_string[
-                                tags_string.find("{") : tags_string.rfind("}") + 1
-                            ]
+                            if res:
+                                title = (
+                                    res.get("choices", [])[0]
+                                    .get("message", {})
+                                    .get("content", message.get("content", "New Chat"))
+                                )
 
 
-                            try:
-                                tags = json.loads(tags_string).get("tags", [])
-                                Chats.update_chat_tags_by_id(
-                                    metadata["chat_id"], tags, user
+                                Chats.update_chat_title_by_id(
+                                    metadata["chat_id"], title
                                 )
                                 )
 
 
                                 await event_emitter(
                                 await event_emitter(
                                     {
                                     {
-                                        "type": "chat:tags",
-                                        "data": tags,
+                                        "type": "chat:title",
+                                        "data": title,
                                     }
                                     }
                                 )
                                 )
-                            except Exception as e:
-                                print(f"Error: {e}")
+
+                        if (
+                            TASKS.TAGS_GENERATION in tasks
+                            and tasks[TASKS.TAGS_GENERATION]
+                        ):
+                            res = await generate_chat_tags(
+                                request,
+                                {
+                                    "model": message["model"],
+                                    "messages": messages,
+                                    "chat_id": metadata["chat_id"],
+                                },
+                                user,
+                            )
+
+                            if res:
+                                tags_string = (
+                                    res.get("choices", [])[0]
+                                    .get("message", {})
+                                    .get("content", "")
+                                )
+
+                                tags_string = tags_string[
+                                    tags_string.find("{") : tags_string.rfind("}") + 1
+                                ]
+
+                                try:
+                                    tags = json.loads(tags_string).get("tags", [])
+                                    Chats.update_chat_tags_by_id(
+                                        metadata["chat_id"], tags, user
+                                    )
+
+                                    await event_emitter(
+                                        {
+                                            "type": "chat:tags",
+                                            "data": tags,
+                                        }
+                                    )
+                                except Exception as e:
+                                    print(f"Error: {e}")
 
 
             except asyncio.CancelledError:
             except asyncio.CancelledError:
                 print("Task was cancelled!")
                 print("Task was cancelled!")