|
@@ -19,8 +19,13 @@ from apps.webui.routers import (
|
|
functions,
|
|
functions,
|
|
)
|
|
)
|
|
from apps.webui.models.functions import Functions
|
|
from apps.webui.models.functions import Functions
|
|
|
|
+from apps.webui.models.models import Models
|
|
|
|
+
|
|
from apps.webui.utils import load_function_module_by_id
|
|
from apps.webui.utils import load_function_module_by_id
|
|
|
|
+
|
|
from utils.misc import stream_message_template
|
|
from utils.misc import stream_message_template
|
|
|
|
+from utils.task import prompt_template
|
|
|
|
+
|
|
|
|
|
|
from config import (
|
|
from config import (
|
|
WEBUI_BUILD_HASH,
|
|
WEBUI_BUILD_HASH,
|
|
@@ -186,6 +191,77 @@ async def get_pipe_models():
|
|
|
|
|
|
|
|
|
|
async def generate_function_chat_completion(form_data, user):
|
|
async def generate_function_chat_completion(form_data, user):
|
|
|
|
+ model_id = form_data.get("model")
|
|
|
|
+ model_info = Models.get_model_by_id(model_id)
|
|
|
|
+
|
|
|
|
+ if model_info:
|
|
|
|
+ if model_info.base_model_id:
|
|
|
|
+ form_data["model"] = model_info.base_model_id
|
|
|
|
+
|
|
|
|
+ model_info.params = model_info.params.model_dump()
|
|
|
|
+
|
|
|
|
+ if model_info.params:
|
|
|
|
+ if model_info.params.get("temperature", None) is not None:
|
|
|
|
+ form_data["temperature"] = float(model_info.params.get("temperature"))
|
|
|
|
+
|
|
|
|
+ if model_info.params.get("top_p", None):
|
|
|
|
+ form_data["top_p"] = int(model_info.params.get("top_p", None))
|
|
|
|
+
|
|
|
|
+ if model_info.params.get("max_tokens", None):
|
|
|
|
+ form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
|
|
|
|
+
|
|
|
|
+ if model_info.params.get("frequency_penalty", None):
|
|
|
|
+ form_data["frequency_penalty"] = int(
|
|
|
|
+ model_info.params.get("frequency_penalty", None)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ if model_info.params.get("seed", None):
|
|
|
|
+ form_data["seed"] = model_info.params.get("seed", None)
|
|
|
|
+
|
|
|
|
+ if model_info.params.get("stop", None):
|
|
|
|
+ form_data["stop"] = (
|
|
|
|
+ [
|
|
|
|
+ bytes(stop, "utf-8").decode("unicode_escape")
|
|
|
|
+ for stop in model_info.params["stop"]
|
|
|
|
+ ]
|
|
|
|
+ if model_info.params.get("stop", None)
|
|
|
|
+ else None
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ system = model_info.params.get("system", None)
|
|
|
|
+ if system:
|
|
|
|
+ system = prompt_template(
|
|
|
|
+ system,
|
|
|
|
+ **(
|
|
|
|
+ {
|
|
|
|
+ "user_name": user.name,
|
|
|
|
+ "user_location": (
|
|
|
|
+ user.info.get("location") if user.info else None
|
|
|
|
+ ),
|
|
|
|
+ }
|
|
|
|
+ if user
|
|
|
|
+ else {}
|
|
|
|
+ ),
|
|
|
|
+ )
|
|
|
|
+ # Check if the payload already has a system message
|
|
|
|
+ # If not, add a system message to the payload
|
|
|
|
+ if form_data.get("messages"):
|
|
|
|
+ for message in form_data["messages"]:
|
|
|
|
+ if message.get("role") == "system":
|
|
|
|
+ message["content"] = system + message["content"]
|
|
|
|
+ break
|
|
|
|
+ else:
|
|
|
|
+ form_data["messages"].insert(
|
|
|
|
+ 0,
|
|
|
|
+ {
|
|
|
|
+ "role": "system",
|
|
|
|
+ "content": system,
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ else:
|
|
|
|
+ pass
|
|
|
|
+
|
|
async def job():
|
|
async def job():
|
|
pipe_id = form_data["model"]
|
|
pipe_id = form_data["model"]
|
|
if "." in pipe_id:
|
|
if "." in pipe_id:
|