Browse Source

Fix: O1 does not support the system parameter

kivvi 7 months ago
parent
commit
be74a4c9c1
1 changed files with 24 additions and 21 deletions
  1. 24 21
      backend/open_webui/apps/openai/main.py

+ 24 - 21
backend/open_webui/apps/openai/main.py

@@ -27,7 +27,6 @@ from fastapi.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
 from pydantic import BaseModel
 from starlette.background import BackgroundTask
 from starlette.background import BackgroundTask
 
 
-
 from open_webui.utils.payload import (
 from open_webui.utils.payload import (
     apply_model_params_to_body_openai,
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
     apply_model_system_prompt_to_body,
@@ -47,7 +46,6 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
-
 app.state.config = AppConfig()
 app.state.config = AppConfig()
 
 
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
@@ -193,8 +191,8 @@ async def fetch_url(url, key):
 
 
 
 
 async def cleanup_response(
 async def cleanup_response(
-    response: Optional[aiohttp.ClientResponse],
-    session: Optional[aiohttp.ClientSession],
+        response: Optional[aiohttp.ClientResponse],
+        session: Optional[aiohttp.ClientSession],
 ):
 ):
     if response:
     if response:
         response.close()
         response.close()
@@ -219,18 +217,18 @@ def merge_models_lists(model_lists):
                     }
                     }
                     for model in models
                     for model in models
                     if "api.openai.com"
                     if "api.openai.com"
-                    not in app.state.config.OPENAI_API_BASE_URLS[idx]
-                    or not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
-                        ]
-                    )
+                       not in app.state.config.OPENAI_API_BASE_URLS[idx]
+                       or not any(
+                    name in model["id"]
+                    for name in [
+                        "babbage",
+                        "dall-e",
+                        "davinci",
+                        "embedding",
+                        "tts",
+                        "whisper",
+                    ]
+                )
                 ]
                 ]
             )
             )
 
 
@@ -373,9 +371,9 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
 @app.post("/chat/completions")
 @app.post("/chat/completions")
 @app.post("/chat/completions/{url_idx}")
 @app.post("/chat/completions/{url_idx}")
 async def generate_chat_completion(
 async def generate_chat_completion(
-    form_data: dict,
-    url_idx: Optional[int] = None,
-    user=Depends(get_verified_user),
+        form_data: dict,
+        url_idx: Optional[int] = None,
+        user=Depends(get_verified_user),
 ):
 ):
     idx = 0
     idx = 0
     payload = {**form_data}
     payload = {**form_data}
@@ -407,20 +405,25 @@ async def generate_chat_completion(
 
 
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
+    is_o1 = payload["model"].lower().startswith("o1")
 
 
     # Change max_completion_tokens to max_tokens (Backward compatible)
     # Change max_completion_tokens to max_tokens (Backward compatible)
-    if "api.openai.com" not in url and not payload["model"].lower().startswith("o1-"):
+    if "api.openai.com" not in url and not is_o1:
         if "max_completion_tokens" in payload:
         if "max_completion_tokens" in payload:
             # Remove "max_completion_tokens" from the payload
             # Remove "max_completion_tokens" from the payload
             payload["max_tokens"] = payload["max_completion_tokens"]
             payload["max_tokens"] = payload["max_completion_tokens"]
             del payload["max_completion_tokens"]
             del payload["max_completion_tokens"]
     else:
     else:
-        if payload["model"].lower().startswith("o1-") and "max_tokens" in payload:
+        if is_o1 and "max_tokens" in payload:
             payload["max_completion_tokens"] = payload["max_tokens"]
             payload["max_completion_tokens"] = payload["max_tokens"]
             del payload["max_tokens"]
             del payload["max_tokens"]
         if "max_tokens" in payload and "max_completion_tokens" in payload:
         if "max_tokens" in payload and "max_completion_tokens" in payload:
             del payload["max_tokens"]
             del payload["max_tokens"]
 
 
+    # Fix: O1 does not support the "system" parameter, Modify "system" to "user"
+    if is_o1 and payload["messages"][0]["role"] == "system":
+        payload["messages"][0]["role"] = "user"
+
     # Convert the modified body back to JSON
     # Convert the modified body back to JSON
     payload = json.dumps(payload)
     payload = json.dumps(payload)