Browse Source

fix: stream_message_template

Michael Poluektov 9 months ago
parent
commit
006fc3495e
1 changed files with 23 additions and 18 deletions
  1. 23 18
      backend/apps/webui/main.py

+ 23 - 18
backend/apps/webui/main.py

@@ -287,6 +287,26 @@ def get_extra_params(metadata: dict):
     }
 
 
+def add_model_params(params: dict, form_data: dict) -> dict:
+    if not params:
+        return form_data
+
+    mappings = {
+        "temperature": float,
+        "top_p": int,
+        "max_tokens": int,
+        "frequency_penalty": int,
+        "seed": lambda x: x,
+        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+    }
+
+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
 async def generate_function_chat_completion(form_data, user):
     print("entry point")
     model_id = form_data.get("model")
@@ -300,24 +320,9 @@ async def generate_function_chat_completion(form_data, user):
             form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-
-        if params:
-            mappings = {
-                "temperature": float,
-                "top_p": int,
-                "max_tokens": int,
-                "frequency_penalty": int,
-                "seed": lambda x: x,
-                "stop": lambda x: [
-                    bytes(s, "utf-8").decode("unicode_escape") for s in x
-                ],
-            }
-
-            for key, cast_func in mappings.items():
-                if (value := params.get(key)) is not None:
-                    form_data[key] = cast_func(value)
-
         system = params.get("system", None)
+        form_data = add_model_params(params, form_data)
+
         if system:
             if user:
                 template_params = {
@@ -381,7 +386,7 @@ async def generate_function_chat_completion(form_data, user):
                         yield process_line(form_data, line)
 
                 if isinstance(res, str) or isinstance(res, Generator):
-                    finish_message = stream_message_template(form_data, "")
+                    finish_message = stream_message_template(form_data["model"], "")
                     finish_message["choices"][0]["finish_reason"] = "stop"
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield "data: [DONE]"