|
@@ -287,6 +287,26 @@ def get_extra_params(metadata: dict):
|
|
|
}
|
|
|
|
|
|
|
|
|
def _unescape_stop_sequence(sequence: str) -> str:
    """Turn backslash escapes written literally in *sequence* into real characters.

    Encoding to Latin-1 with ``backslashreplace`` first keeps every non-ASCII
    character intact: characters outside Latin-1 become ``\\uXXXX`` escapes
    that ``unicode_escape`` decodes back. Encoding to UTF-8 instead would turn
    e.g. "é" into mojibake, because ``unicode_escape`` reads bytes as Latin-1.
    """
    return sequence.encode("latin-1", "backslashreplace").decode("unicode_escape")


def add_model_params(params: dict, form_data: dict) -> dict:
    """Copy the supported model parameters from *params* into *form_data*.

    Only the keys listed in ``mappings`` are considered; a key is copied when
    its value in *params* is not ``None``, after being passed through the
    associated cast function. Other keys in *params* are ignored.

    Args:
        params: Model-level parameter overrides. May be empty or ``None``.
        form_data: Request payload; it is updated in place.

    Returns:
        The same *form_data* object that was passed in.

    Raises:
        ValueError, TypeError: If a value cannot be converted by its cast
            function (for example a non-numeric ``max_tokens``).
    """
    if not params:
        return form_data

    mappings = {
        "temperature": float,
        # top_p and frequency_penalty are fractional values; casting them
        # with int() would truncate e.g. 0.9 to 0.
        "top_p": float,
        "max_tokens": int,
        "frequency_penalty": float,
        "seed": lambda x: x,
        "stop": lambda x: [_unescape_stop_sequence(s) for s in x],
    }

    for key, cast_func in mappings.items():
        if (value := params.get(key)) is not None:
            form_data[key] = cast_func(value)

    return form_data
|
|
|
+
|
|
|
+
|
|
|
async def generate_function_chat_completion(form_data, user):
|
|
|
print("entry point")
|
|
|
model_id = form_data.get("model")
|
|
@@ -300,24 +320,9 @@ async def generate_function_chat_completion(form_data, user):
|
|
|
form_data["model"] = model_info.base_model_id
|
|
|
|
|
|
params = model_info.params.model_dump()
|
|
|
-
|
|
|
- if params:
|
|
|
- mappings = {
|
|
|
- "temperature": float,
|
|
|
- "top_p": int,
|
|
|
- "max_tokens": int,
|
|
|
- "frequency_penalty": int,
|
|
|
- "seed": lambda x: x,
|
|
|
- "stop": lambda x: [
|
|
|
- bytes(s, "utf-8").decode("unicode_escape") for s in x
|
|
|
- ],
|
|
|
- }
|
|
|
-
|
|
|
- for key, cast_func in mappings.items():
|
|
|
- if (value := params.get(key)) is not None:
|
|
|
- form_data[key] = cast_func(value)
|
|
|
-
|
|
|
system = params.get("system", None)
|
|
|
+ form_data = add_model_params(params, form_data)
|
|
|
+
|
|
|
if system:
|
|
|
if user:
|
|
|
template_params = {
|
|
@@ -381,7 +386,7 @@ async def generate_function_chat_completion(form_data, user):
|
|
|
yield process_line(form_data, line)
|
|
|
|
|
|
if isinstance(res, str) or isinstance(res, Generator):
|
|
|
- finish_message = stream_message_template(form_data, "")
|
|
|
+ finish_message = stream_message_template(form_data["model"], "")
|
|
|
finish_message["choices"][0]["finish_reason"] = "stop"
|
|
|
yield f"data: {json.dumps(finish_message)}\n\n"
|
|
|
yield "data: [DONE]"
|