Forráskód Böngészése

fix: function early returns

Michael Poluektov 9 hónapja
szülő
commit
deec41d29a
1 módosított fájl, 56 hozzáadás és 52 törlés
  1. 56 52
      backend/apps/webui/main.py

+ 56 - 52
backend/apps/webui/main.py

@@ -291,12 +291,7 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module):
     return params
 
 
-async def generate_function_chat_completion(form_data, user):
-    model_id = form_data.get("model")
-    model_info = Models.get_model_by_id(model_id)
-
-    metadata = form_data.pop("metadata", None)
-
+def get_extra_params(metadata: dict):
     __event_emitter__ = __event_call__ = __task__ = None
 
     if metadata:
@@ -305,57 +300,66 @@ async def generate_function_chat_completion(form_data, user):
             __event_call__ = get_event_call(metadata)
         __task__ = metadata.get("task", None)
 
-    if not model_info:
-        return
-
-    if model_info.base_model_id:
-        form_data["model"] = model_info.base_model_id
-
-    params = model_info.params.model_dump()
-
-    if params:
-        mappings = {
-            "temperature": float,
-            "top_p": int,
-            "max_tokens": int,
-            "frequency_penalty": int,
-            "seed": lambda x: x,
-            "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-        }
-
-        for key, cast_func in mappings.items():
-            if (value := params.get(key)) is not None:
-                form_data[key] = cast_func(value)
+    return {
+        "__event_emitter__": __event_emitter__,
+        "__event_call__": __event_call__,
+        "__task__": __task__,
+    }
 
-    system = params.get("system", None)
-    if not system:
-        return
 
-    if user:
-        template_params = {
-            "user_name": user.name,
-            "user_location": user.info.get("location") if user.info else None,
-        }
-    else:
-        template_params = {}
+async def generate_function_chat_completion(form_data, user):
+    print("entry point")
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
 
-    system = prompt_template(system, **template_params)
+    metadata = form_data.pop("metadata", None)
+    extra_params = get_extra_params(metadata)
+
+    if model_info:
+        if model_info.base_model_id:
+            form_data["model"] = model_info.base_model_id
+
+        params = model_info.params.model_dump()
+
+        if params:
+            mappings = {
+                "temperature": float,
+                "top_p": int,
+                "max_tokens": int,
+                "frequency_penalty": int,
+                "seed": lambda x: x,
+                "stop": lambda x: [
+                    bytes(s, "utf-8").decode("unicode_escape") for s in x
+                ],
+            }
+
+            for key, cast_func in mappings.items():
+                if (value := params.get(key)) is not None:
+                    form_data[key] = cast_func(value)
+
+        system = params.get("system", None)
+        if system:
+            if user:
+                template_params = {
+                    "user_name": user.name,
+                    "user_location": user.info.get("location") if user.info else None,
+                }
+            else:
+                template_params = {}
 
-    # Check if the payload already has a system message
-    # If not, add a system message to the payload
-    for message in form_data.get("messages", []):
-        if message.get("role") == "system":
-            message["content"] = system + message["content"]
-            break
-    else:
-        if form_data.get("messages"):
-            form_data["messages"].insert(0, {"role": "system", "content": system})
+            system = prompt_template(system, **template_params)
 
-    extra_params = {
-        "__event_emitter__": __event_emitter__,
-        "__event_call__": __event_call__,
-        "__task__": __task__,
-    }
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            for message in form_data.get("messages", []):
+                if message.get("role") == "system":
+                    message["content"] = system + message["content"]
+                    break
+            else:
+                if form_data.get("messages"):
+                    form_data["messages"].insert(
+                        0, {"role": "system", "content": system}
+                    )
 
     async def job():
         pipe_id = get_pipe_id(form_data)