Timothy J. Baek пре 9 месеци
родитељ
комит
e6c64282fc
1 измењених фајлова са 26 додато и 27 уклоњено
  1. 26 27
      backend/apps/webui/main.py

+ 26 - 27
backend/apps/webui/main.py

@@ -239,10 +239,10 @@ def get_pipe_id(form_data: dict) -> str:
     return pipe_id
 
 
-def get_params_dict(pipe, form_data, user, extra_params, function_module):
+def get_function_params(function_module, form_data, user, extra_params={}):
     pipe_id = get_pipe_id(form_data)
     # Get the signature of the function
-    sig = inspect.signature(pipe)
+    sig = inspect.signature(function_module.pipe)
     params = {"body": form_data}
 
     for key, value in extra_params.items():
@@ -269,26 +269,8 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module):
     return params
 
 
-def get_extra_params(metadata: dict):
-    __event_emitter__ = None
-    __event_call__ = None
-    __task__ = None
-
-    if metadata:
-        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
-        __task__ = metadata.get("task", None)
-
-    return {
-        "__event_emitter__": __event_emitter__,
-        "__event_call__": __event_call__,
-        "__task__": __task__,
-    }
-
-
 # inplace function: form_data is modified
-def add_model_params(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
     if not params:
         return form_data
 
@@ -309,7 +291,7 @@ def add_model_params(params: dict, form_data: dict) -> dict:
 
 
 # inplace function: form_data is modified
-def populate_system_message(params: dict, form_data: dict, user) -> dict:
+def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
     system = params.get("system", None)
     if not system:
         return form_data
@@ -333,21 +315,38 @@ async def generate_function_chat_completion(form_data, user):
     model_info = Models.get_model_by_id(model_id)
     metadata = form_data.pop("metadata", None)
 
-    # Add extra params such as __event_emitter__
-    extra_params = get_extra_params(metadata)
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+
+    if metadata:
+        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
+            __event_emitter__ = get_event_emitter(metadata)
+            __event_call__ = get_event_call(metadata)
+        __task__ = metadata.get("task", None)
+
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        form_data = add_model_params(params, form_data)
-        form_data = populate_system_message(params, form_data, user)
+        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_system_prompt_to_body(params, form_data, user)
 
     pipe_id = get_pipe_id(form_data)
     function_module = get_function_module(pipe_id)
 
     pipe = function_module.pipe
-    params = get_params_dict(pipe, form_data, user, extra_params, function_module)
+    params = get_function_params(
+        function_module,
+        form_data,
+        user,
+        {
+            "__event_emitter__": __event_emitter__,
+            "__event_call__": __event_call__,
+            "__task__": __task__,
+        },
+    )
 
     if form_data["stream"]: