|
@@ -19,7 +19,11 @@ from apps.webui.models.functions import Functions
|
|
from apps.webui.models.models import Models
|
|
from apps.webui.models.models import Models
|
|
from apps.webui.utils import load_function_module_by_id
|
|
from apps.webui.utils import load_function_module_by_id
|
|
|
|
|
|
-from utils.misc import stream_message_template, whole_message_template
|
|
|
|
|
|
+from utils.misc import (
|
|
|
|
+ stream_message_template,
|
|
|
|
+ whole_message_template,
|
|
|
|
+ add_or_update_system_message,
|
|
|
|
+)
|
|
from utils.task import prompt_template
|
|
from utils.task import prompt_template
|
|
|
|
|
|
|
|
|
|
@@ -47,8 +51,6 @@ from config import (
|
|
from apps.socket.main import get_event_call, get_event_emitter
|
|
from apps.socket.main import get_event_call, get_event_emitter
|
|
|
|
|
|
import inspect
|
|
import inspect
|
|
-import uuid
|
|
|
|
-import time
|
|
|
|
import json
|
|
import json
|
|
|
|
|
|
from typing import Iterator, Generator, AsyncGenerator
|
|
from typing import Iterator, Generator, AsyncGenerator
|
|
@@ -287,6 +289,7 @@ def get_extra_params(metadata: dict):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
+# inplace function: form_data is modified
|
|
def add_model_params(params: dict, form_data: dict) -> dict:
|
|
def add_model_params(params: dict, form_data: dict) -> dict:
|
|
if not params:
|
|
if not params:
|
|
return form_data
|
|
return form_data
|
|
@@ -307,44 +310,40 @@ def add_model_params(params: dict, form_data: dict) -> dict:
|
|
return form_data
|
|
return form_data
|
|
|
|
|
|
|
|
|
|
|
|
+# inplace function: form_data is modified
|
|
|
|
+def populate_system_message(params: dict, form_data: dict, user) -> dict:
|
|
|
|
+ system = params.get("system", None)
|
|
|
|
+ if not system:
|
|
|
|
+ return form_data
|
|
|
|
+
|
|
|
|
+ if user:
|
|
|
|
+ template_params = {
|
|
|
|
+ "user_name": user.name,
|
|
|
|
+ "user_location": user.info.get("location") if user.info else None,
|
|
|
|
+ }
|
|
|
|
+ else:
|
|
|
|
+ template_params = {}
|
|
|
|
+ system = prompt_template(system, **template_params)
|
|
|
|
+ form_data["messages"] = add_or_update_system_message(
|
|
|
|
+ system, form_data.get("messages", [])
|
|
|
|
+ )
|
|
|
|
+ return form_data
|
|
|
|
+
|
|
|
|
+
|
|
async def generate_function_chat_completion(form_data, user):
|
|
async def generate_function_chat_completion(form_data, user):
|
|
- print("entry point")
|
|
|
|
model_id = form_data.get("model")
|
|
model_id = form_data.get("model")
|
|
model_info = Models.get_model_by_id(model_id)
|
|
model_info = Models.get_model_by_id(model_id)
|
|
-
|
|
|
|
metadata = form_data.pop("metadata", None)
|
|
metadata = form_data.pop("metadata", None)
|
|
- extra_params = get_extra_params(metadata)
|
|
|
|
|
|
|
|
|
|
+ # Add extra params such as __event_emitter__
|
|
|
|
+ extra_params = get_extra_params(metadata)
|
|
if model_info:
|
|
if model_info:
|
|
if model_info.base_model_id:
|
|
if model_info.base_model_id:
|
|
form_data["model"] = model_info.base_model_id
|
|
form_data["model"] = model_info.base_model_id
|
|
|
|
|
|
params = model_info.params.model_dump()
|
|
params = model_info.params.model_dump()
|
|
- system = params.get("system", None)
|
|
|
|
form_data = add_model_params(params, form_data)
|
|
form_data = add_model_params(params, form_data)
|
|
-
|
|
|
|
- if system:
|
|
|
|
- if user:
|
|
|
|
- template_params = {
|
|
|
|
- "user_name": user.name,
|
|
|
|
- "user_location": user.info.get("location") if user.info else None,
|
|
|
|
- }
|
|
|
|
- else:
|
|
|
|
- template_params = {}
|
|
|
|
-
|
|
|
|
- system = prompt_template(system, **template_params)
|
|
|
|
-
|
|
|
|
- # Check if the payload already has a system message
|
|
|
|
- # If not, add a system message to the payload
|
|
|
|
- for message in form_data.get("messages", []):
|
|
|
|
- if message.get("role") == "system":
|
|
|
|
- message["content"] = system + message["content"]
|
|
|
|
- break
|
|
|
|
- else:
|
|
|
|
- if form_data.get("messages"):
|
|
|
|
- form_data["messages"].insert(
|
|
|
|
- 0, {"role": "system", "content": system}
|
|
|
|
- )
|
|
|
|
|
|
+ form_data = populate_system_message(params, form_data, user)
|
|
|
|
|
|
async def job():
|
|
async def job():
|
|
pipe_id = get_pipe_id(form_data)
|
|
pipe_id = get_pipe_id(form_data)
|