瀏覽代碼

fix: "stream" hook not working

Timothy Jaeryang Baek 2 月之前
父節點
當前提交
d7088efe73
共有 3 個文件被更改,包括 14 次插入10 次删除
  1. 3 3
      backend/open_webui/main.py
  2. 1 1
      backend/open_webui/utils/filter.py
  3. 10 6
      backend/open_webui/utils/middleware.py

+ 3 - 3
backend/open_webui/main.py

@@ -1021,7 +1021,7 @@ async def chat_completion(
             "files": form_data.get("files", None),
             "features": form_data.get("features", None),
             "variables": form_data.get("variables", None),
-            "model": model_info.model_dump() if model_info else model,
+            "model": model,
             "direct": model_item.get("direct", False),
             **(
                 {"function_calling": "native"}
@@ -1039,7 +1039,7 @@ async def chat_completion(
         form_data["metadata"] = metadata
 
         form_data, metadata, events = await process_chat_payload(
-            request, form_data, metadata, user, model
+            request, form_data, user, metadata, model
         )
 
     except Exception as e:
@@ -1053,7 +1053,7 @@ async def chat_completion(
         response = await chat_completion_handler(request, form_data, user)
 
         return await process_chat_response(
-            request, response, form_data, user, events, metadata, tasks
+            request, response, form_data, user, metadata, model, events, tasks
         )
     except Exception as e:
         raise HTTPException(

+ 1 - 1
backend/open_webui/utils/filter.py

@@ -9,7 +9,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
-def get_sorted_filter_ids(model):
+def get_sorted_filter_ids(model: dict):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):

+ 10 - 6
backend/open_webui/utils/middleware.py

@@ -68,7 +68,7 @@ from open_webui.utils.misc import (
     get_last_user_message,
     get_last_assistant_message,
     prepend_to_first_user_message_content,
-    convert_logit_bias_input_to_json
+    convert_logit_bias_input_to_json,
 )
 from open_webui.utils.tools import get_tools
 from open_webui.utils.plugin import load_function_module_by_id
@@ -613,14 +613,16 @@ def apply_params_to_form_data(form_data, model):
             form_data["reasoning_effort"] = params["reasoning_effort"]
         if "logit_bias" in params:
             try:
-                form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
+                form_data["logit_bias"] = json.loads(
+                    convert_logit_bias_input_to_json(params["logit_bias"])
+                )
             except Exception as e:
                 print(f"Error parsing logit_bias: {e}")
 
     return form_data
 
 
-async def process_chat_payload(request, form_data, metadata, user, model):
+async def process_chat_payload(request, form_data, user, metadata, model):
 
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
@@ -862,7 +864,7 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
 
 async def process_chat_response(
-    request, response, form_data, user, events, metadata, tasks
+    request, response, form_data, user, metadata, model, events, tasks
 ):
     async def background_tasks_handler():
         message_map = Chats.get_messages_by_chat_id(metadata["chat_id"])
@@ -1067,9 +1069,11 @@ async def process_chat_response(
         },
         "__metadata__": metadata,
         "__request__": request,
-        "__model__": metadata.get("model"),
+        "__model__": model,
     }
-    filter_ids = get_sorted_filter_ids(form_data.get("model"))
+    filter_ids = get_sorted_filter_ids(model)
+
+    print(f"{filter_ids=}")
 
     # Streaming response
     if event_emitter and event_caller: