Sfoglia il codice sorgente

fix: pipe custom model

Timothy J. Baek 10 mesi fa
parent
commit
67c2ab006d
2 ha cambiato i file con 81 aggiunte e 1 eliminazioni
  1. 76 0
      backend/apps/webui/main.py
  2. 5 1
      backend/main.py

+ 76 - 0
backend/apps/webui/main.py

@@ -19,8 +19,13 @@ from apps.webui.routers import (
     functions,
 )
 from apps.webui.models.functions import Functions
+from apps.webui.models.models import Models
+
 from apps.webui.utils import load_function_module_by_id
+
 from utils.misc import stream_message_template
+from utils.task import prompt_template
+
 
 from config import (
     WEBUI_BUILD_HASH,
@@ -186,6 +191,77 @@ async def get_pipe_models():
 
 
 async def generate_function_chat_completion(form_data, user):
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        if model_info.base_model_id:
+            form_data["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            if model_info.params.get("temperature", None) is not None:
+                form_data["temperature"] = float(model_info.params.get("temperature"))
+
+            if model_info.params.get("top_p", None):
+                form_data["top_p"] = int(model_info.params.get("top_p", None))
+
+            if model_info.params.get("max_tokens", None):
+                form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
+
+            if model_info.params.get("frequency_penalty", None):
+                form_data["frequency_penalty"] = int(
+                    model_info.params.get("frequency_penalty", None)
+                )
+
+            if model_info.params.get("seed", None):
+                form_data["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                form_data["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
+
+        system = model_info.params.get("system", None)
+        if system:
+            system = prompt_template(
+                system,
+                **(
+                    {
+                        "user_name": user.name,
+                        "user_location": (
+                            user.info.get("location") if user.info else None
+                        ),
+                    }
+                    if user
+                    else {}
+                ),
+            )
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if form_data.get("messages"):
+                for message in form_data["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = system + message["content"]
+                        break
+                else:
+                    form_data["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": system,
+                        },
+                    )
+
+    else:
+        pass
+
     async def job():
         pipe_id = form_data["model"]
         if "." in pipe_id:

+ 5 - 1
backend/main.py

@@ -975,12 +975,16 @@ async def get_all_models():
                     model["info"] = custom_model.model_dump()
         else:
             owned_by = "openai"
+            pipe = None
+
             for model in models:
                 if (
                     custom_model.base_model_id == model["id"]
                     or custom_model.base_model_id == model["id"].split(":")[0]
                 ):
                     owned_by = model["owned_by"]
+                    if "pipe" in model:
+                        pipe = model["pipe"]
                     break
 
             models.append(
@@ -992,11 +996,11 @@ async def get_all_models():
                     "owned_by": owned_by,
                     "info": custom_model.model_dump(),
                     "preset": True,
+                    **({"pipe": pipe} if pipe is not None else {}),
                 }
             )
 
     app.state.MODELS = {model["id"]: model for model in models}
-
     webui_app.state.MODELS = app.state.MODELS
 
     return models