浏览代码

refac: use add_or_update_system_message

Michael Poluektov 9 月之前
父节点
当前提交
baf58ef396
共有 1 个文件被更改,包括 29 次插入30 次删除
  1. 29 30
      backend/apps/webui/main.py

+ 29 - 30
backend/apps/webui/main.py

@@ -19,7 +19,11 @@ from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
 
-from utils.misc import stream_message_template, whole_message_template
+from utils.misc import (
+    stream_message_template,
+    whole_message_template,
+    add_or_update_system_message,
+)
 from utils.task import prompt_template
 
 
@@ -47,8 +51,6 @@ from config import (
 from apps.socket.main import get_event_call, get_event_emitter
 
 import inspect
-import uuid
-import time
 import json
 
 from typing import Iterator, Generator, AsyncGenerator
@@ -287,6 +289,7 @@ def get_extra_params(metadata: dict):
     }
 
 
+# inplace function: form_data is modified
 def add_model_params(params: dict, form_data: dict) -> dict:
     if not params:
         return form_data
@@ -307,44 +310,40 @@ def add_model_params(params: dict, form_data: dict) -> dict:
     return form_data
 
 
+# inplace function: form_data is modified
+def populate_system_message(params: dict, form_data: dict, user) -> dict:
+    system = params.get("system", None)
+    if not system:
+        return form_data
+
+    if user:
+        template_params = {
+            "user_name": user.name,
+            "user_location": user.info.get("location") if user.info else None,
+        }
+    else:
+        template_params = {}
+    system = prompt_template(system, **template_params)
+    form_data["messages"] = add_or_update_system_message(
+        system, form_data.get("messages", [])
+    )
+    return form_data
+
+
 async def generate_function_chat_completion(form_data, user):
-    print("entry point")
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
-
     metadata = form_data.pop("metadata", None)
-    extra_params = get_extra_params(metadata)
 
+    # Add extra params such as __event_emitter__
+    extra_params = get_extra_params(metadata)
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        system = params.get("system", None)
         form_data = add_model_params(params, form_data)
-
-        if system:
-            if user:
-                template_params = {
-                    "user_name": user.name,
-                    "user_location": user.info.get("location") if user.info else None,
-                }
-            else:
-                template_params = {}
-
-            system = prompt_template(system, **template_params)
-
-            # Check if the payload already has a system message
-            # If not, add a system message to the payload
-            for message in form_data.get("messages", []):
-                if message.get("role") == "system":
-                    message["content"] = system + message["content"]
-                    break
-            else:
-                if form_data.get("messages"):
-                    form_data["messages"].insert(
-                        0, {"role": "system", "content": system}
-                    )
+        form_data = populate_system_message(params, form_data, user)
 
     async def job():
         pipe_id = get_pipe_id(form_data)