浏览代码

enh: "stream" hook

Timothy Jaeryang Baek 2 月之前
父节点
当前提交
46c4da4864
共有 2 个文件被更改,包括 46 次插入9 次删除
  1. 6 1
      backend/open_webui/utils/filter.py
  2. 40 8
      backend/open_webui/utils/middleware.py

+ 6 - 1
backend/open_webui/utils/filter.py

@@ -61,7 +61,12 @@ async def process_filter_functions(
         try:
             # Prepare parameters
             sig = inspect.signature(handler)
-            params = {"body": form_data} | {
+
+            params = {"body": form_data}
+            if filter_type == "stream":
+                params = {"event": form_data}
+
+            params = params | {
                 k: v
                 for k, v in {
                     **extra_params,

+ 40 - 8
backend/open_webui/utils/middleware.py

@@ -1048,6 +1048,21 @@ async def process_chat_response(
     ):
         return response
 
+    extra_params = {
+        "__event_emitter__": event_emitter,
+        "__event_call__": event_caller,
+        "__user__": {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        },
+        "__metadata__": metadata,
+        "__request__": request,
+        "__model__": metadata.get("model"),
+    }
+    filter_ids = get_sorted_filter_ids(form_data.get("model"))
+
     # Streaming response
     if event_emitter and event_caller:
         task_id = str(uuid4())  # Create a unique task ID.
@@ -1402,16 +1417,12 @@ async def process_chat_response(
                 ("reasoning", "/reasoning"),
                 ("thought", "/thought"),
                 ("Thought", "/Thought"),
-                ("|begin_of_thought|", "|end_of_thought|")
+                ("|begin_of_thought|", "|end_of_thought|"),
             ]
 
-            code_interpreter_tags = [
-                ("code_interpreter", "/code_interpreter")
-            ]
+            code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
 
-            solution_tags = [
-                ("|begin_of_solution|", "|end_of_solution|")
-            ]
+            solution_tags = [("|begin_of_solution|", "|end_of_solution|")]
 
             try:
                 for event in events:
@@ -1455,6 +1466,14 @@ async def process_chat_response(
                         try:
                             data = json.loads(data)
 
+                            data, _ = await process_filter_functions(
+                                request=request,
+                                filter_ids=filter_ids,
+                                filter_type="stream",
+                                form_data=data,
+                                extra_params=extra_params,
+                            )
+
                             if "selected_model_id" in data:
                                 model_id = data["selected_model_id"]
                                 Chats.upsert_message_to_chat_by_id_and_message_id(
@@ -1968,16 +1987,29 @@ async def process_chat_response(
         return {"status": True, "task_id": task_id}
 
     else:
-
         # Fallback to the original response
         async def stream_wrapper(original_generator, events):
             def wrap_item(item):
                 return f"data: {item}\n\n"
 
             for event in events:
+                event, _ = await process_filter_functions(
+                    request=request,
+                    filter_ids=filter_ids,
+                    filter_type="stream",
+                    form_data=event,
+                    extra_params=extra_params,
+                )
                 yield wrap_item(json.dumps(event))
 
             async for data in original_generator:
+                data, _ = await process_filter_functions(
+                    request=request,
+                    filter_ids=filter_ids,
+                    filter_type="stream",
+                    form_data=data,
+                    extra_params=extra_params,
+                )
                 yield data
 
         return StreamingResponse(