Просмотр исходного кода

refac: reuse stream_message_template

Michael Poluektov 9 месяцев назад
Родитель
Сommit
29a3b82336
2 измененных файлов с 25 добавлено и 38 удалено
  1. 9 28
      backend/apps/webui/main.py
  2. 16 10
      backend/utils/misc.py

+ 9 - 28
backend/apps/webui/main.py

@@ -19,7 +19,7 @@ from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
 
-from utils.misc import stream_message_template
+from utils.misc import stream_message_template, whole_message_template
 from utils.task import prompt_template
 
 
@@ -203,7 +203,7 @@ async def execute_pipe(pipe, params):
         return pipe(**params)
 
 
-async def get_message(res: str | Generator | AsyncGenerator) -> str:
+async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
     if isinstance(res, str):
         return res
     if isinstance(res, Generator):
@@ -212,28 +212,6 @@ async def get_message(res: str | Generator | AsyncGenerator) -> str:
         return "".join([str(stream) async for stream in res])
 
 
-def get_final_message(form_data: dict, message: str | None = None) -> dict:
-    choice = {
-        "index": 0,
-        "logprobs": None,
-        "finish_reason": "stop",
-    }
-
-    # If message is None, we're dealing with a chunk
-    if not message:
-        choice["delta"] = {}
-    else:
-        choice["message"] = {"role": "assistant", "content": message}
-
-    return {
-        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-        "created": int(time.time()),
-        "model": form_data["model"],
-        "object": "chat.completion" if message is not None else "chat.completion.chunk",
-        "choices": [choice],
-    }
-
-
 def process_line(form_data: dict, line):
     if isinstance(line, BaseModel):
         line = line.model_dump_json()
@@ -292,7 +270,9 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module):
 
 
 def get_extra_params(metadata: dict):
-    __event_emitter__ = __event_call__ = __task__ = None
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
 
     if metadata:
         if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
@@ -401,7 +381,8 @@ async def generate_function_chat_completion(form_data, user):
                         yield process_line(form_data, line)
 
                 if isinstance(res, str) or isinstance(res, Generator):
-                    finish_message = get_final_message(form_data)
+                    finish_message = stream_message_template(form_data, "")
+                    finish_message["choices"][0]["finish_reason"] = "stop"
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield "data: [DONE]"
 
@@ -419,7 +400,7 @@ async def generate_function_chat_completion(form_data, user):
             if isinstance(res, BaseModel):
                 return res.model_dump()
 
-            message = await get_message(res)
-            return get_final_message(form_data, message)
+            message = await get_message_content(res)
+            return whole_message_template(form_data["model"], message)
 
     return await job()

+ 16 - 10
backend/utils/misc.py

@@ -87,23 +87,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
-def stream_message_template(model: str, message: str):
+def message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": message},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
+        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
     }
 
 
+def stream_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion.chunk"
+    template["choices"][0]["delta"] = {"content": message}
+    return template
+
+
+def whole_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion"
+    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    template["choices"][0]["finish_reason"] = "stop"
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters