
Merge pull request #5758 from kivvi3412/fix_o1_system_message

Fix: O1 does not support the system parameter
Timothy Jaeryang Baek committed 7 months ago
parent
commit c30c876659
1 file changed, 7 additions and 4 deletions:
  1. backend/open_webui/apps/openai/main.py  +7 −4

backend/open_webui/apps/openai/main.py  +7 −4

@@ -27,7 +27,6 @@ from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
 from starlette.background import BackgroundTask
 
-
 from open_webui.utils.payload import (
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
@@ -47,7 +46,6 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
 app.state.config = AppConfig()
 
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
@@ -407,20 +405,25 @@ async def generate_chat_completion(
 
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
+    is_o1 = payload["model"].lower().startswith("o1-")
 
     # Change max_completion_tokens to max_tokens (Backward compatible)
-    if "api.openai.com" not in url and not payload["model"].lower().startswith("o1-"):
+    if "api.openai.com" not in url and not is_o1:
         if "max_completion_tokens" in payload:
             # Remove "max_completion_tokens" from the payload
             payload["max_tokens"] = payload["max_completion_tokens"]
             del payload["max_completion_tokens"]
     else:
-        if payload["model"].lower().startswith("o1-") and "max_tokens" in payload:
+        if is_o1 and "max_tokens" in payload:
             payload["max_completion_tokens"] = payload["max_tokens"]
             del payload["max_tokens"]
         if "max_tokens" in payload and "max_completion_tokens" in payload:
             del payload["max_tokens"]
 
+    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
+    if is_o1 and payload["messages"][0]["role"] == "system":
+        payload["messages"][0]["role"] = "user"
+
     # Convert the modified body back to JSON
     payload = json.dumps(payload)
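
For readers skimming the diff, the new payload handling can be read as a small standalone sketch. This is not the project's code: the helper name adapt_payload_for_o1, its url argument, and the example values below are illustrative assumptions; only the branching logic mirrors the hunk above.

# Illustrative sketch of the o1-* payload rules applied by this commit; the
# function name and example values are hypothetical, the logic follows the diff.
def adapt_payload_for_o1(payload: dict, url: str) -> dict:
    is_o1 = payload["model"].lower().startswith("o1-")

    if "api.openai.com" not in url and not is_o1:
        # Non-OpenAI endpoints (for non-o1 models) still expect "max_tokens",
        # so "max_completion_tokens" is translated back for compatibility.
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload.pop("max_completion_tokens")
    else:
        # o1 models only accept "max_completion_tokens"; any leftover
        # "max_tokens" is renamed or dropped.
        if is_o1 and "max_tokens" in payload:
            payload["max_completion_tokens"] = payload.pop("max_tokens")
        if "max_tokens" in payload and "max_completion_tokens" in payload:
            del payload["max_tokens"]

    # The fix itself: o1 models reject the "system" role, so a leading
    # system message is re-tagged as a "user" message.
    if is_o1 and payload["messages"] and payload["messages"][0]["role"] == "system":
        payload["messages"][0]["role"] = "user"

    return payload

# Example: a system prompt aimed at an o1 model becomes a user message, and
# "max_tokens" is renamed to "max_completion_tokens".
example = {
    "model": "o1-preview",
    "max_tokens": 1024,
    "messages": [{"role": "system", "content": "You are concise."}],
}
adapt_payload_for_o1(example, "https://api.openai.com/v1")
# example is now:
# {"model": "o1-preview",
#  "messages": [{"role": "user", "content": "You are concise."}],
#  "max_completion_tokens": 1024}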