Timothy J. Baek 10 miesięcy temu
rodzic
commit
c44fc82ecd
1 zmienionych plików z 129 dodań i 81 usunięć
  1. 129 81
      backend/apps/openai/main.py

+ 129 - 81
backend/apps/openai/main.py

@@ -345,108 +345,156 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
             )
 
 
-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+@app.post("/chat/completions")
+@app.post("/chat/completions/{url_idx}")
+async def generate_chat_completion(
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
     idx = 0
+    payload = {**form_data}
 
-    body = await request.body()
-    # TODO: Remove below after gpt-4-vision fix from Open AI
-    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
 
-    payload = None
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
 
-    try:
-        if "chat/completions" in path:
-            body = body.decode("utf-8")
-            body = json.loads(body)
+        model_info.params = model_info.params.model_dump()
 
-            payload = {**body}
+        if model_info.params:
+            if model_info.params.get("temperature", None) is not None:
+                payload["temperature"] = float(model_info.params.get("temperature"))
 
-            model_id = body.get("model")
-            model_info = Models.get_model_by_id(model_id)
+            if model_info.params.get("top_p", None):
+                payload["top_p"] = int(model_info.params.get("top_p", None))
 
-            if model_info:
-                print(model_info)
-                if model_info.base_model_id:
-                    payload["model"] = model_info.base_model_id
+            if model_info.params.get("max_tokens", None):
+                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
 
-                model_info.params = model_info.params.model_dump()
+            if model_info.params.get("frequency_penalty", None):
+                payload["frequency_penalty"] = int(
+                    model_info.params.get("frequency_penalty", None)
+                )
+
+            if model_info.params.get("seed", None):
+                payload["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
 
-                if model_info.params:
-                    if model_info.params.get("temperature", None) is not None:
-                        payload["temperature"] = float(
-                            model_info.params.get("temperature")
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
                         )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
 
-                    if model_info.params.get("top_p", None):
-                        payload["top_p"] = int(model_info.params.get("top_p", None))
+    else:
+        pass
 
-                    if model_info.params.get("max_tokens", None):
-                        payload["max_tokens"] = int(
-                            model_info.params.get("max_tokens", None)
-                        )
+    model = app.state.MODELS[payload.get("model")]
+    idx = model["urlIdx"]
 
-                    if model_info.params.get("frequency_penalty", None):
-                        payload["frequency_penalty"] = int(
-                            model_info.params.get("frequency_penalty", None)
-                        )
+    if "pipeline" in model and model.get("pipeline"):
+        payload["user"] = {"name": user.name, "id": user.id}
 
-                    if model_info.params.get("seed", None):
-                        payload["seed"] = model_info.params.get("seed", None)
-
-                    if model_info.params.get("stop", None):
-                        payload["stop"] = (
-                            [
-                                bytes(stop, "utf-8").decode("unicode_escape")
-                                for stop in model_info.params["stop"]
-                            ]
-                            if model_info.params.get("stop", None)
-                            else None
-                        )
+    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+    # This is a workaround until OpenAI fixes the issue with this model
+    if payload.get("model") == "gpt-4-vision-preview":
+        if "max_tokens" not in payload:
+            payload["max_tokens"] = 4000
+        log.debug("Modified payload:", payload)
 
-                if model_info.params.get("system", None):
-                    # Check if the payload already has a system message
-                    # If not, add a system message to the payload
-                    if payload.get("messages"):
-                        for message in payload["messages"]:
-                            if message.get("role") == "system":
-                                message["content"] = (
-                                    model_info.params.get("system", None)
-                                    + message["content"]
-                                )
-                                break
-                        else:
-                            payload["messages"].insert(
-                                0,
-                                {
-                                    "role": "system",
-                                    "content": model_info.params.get("system", None),
-                                },
-                            )
-            else:
-                pass
+    # Convert the modified body back to JSON
+    payload = json.dumps(payload)
+
+    print(payload)
 
-            model = app.state.MODELS[payload.get("model")]
+    url = app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = app.state.config.OPENAI_API_KEYS[idx]
 
-            idx = model["urlIdx"]
+    print(payload)
 
-            if "pipeline" in model and model.get("pipeline"):
-                payload["user"] = {"name": user.name, "id": user.id}
+    headers = {}
+    headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"
 
-            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-            # This is a workaround until OpenAI fixes the issue with this model
-            if payload.get("model") == "gpt-4-vision-preview":
-                if "max_tokens" not in payload:
-                    payload["max_tokens"] = 4000
-                log.debug("Modified payload:", payload)
+    r = None
+    session = None
+    streaming = False
 
-            # Convert the modified body back to JSON
-            payload = json.dumps(payload)
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method="POST",
+            url=f"{url}/chat/completions",
+            data=payload,
+            headers=headers,
+        )
 
-    except json.JSONDecodeError as e:
-        log.error("Error loading request body into a dictionary:", e)
+        r.raise_for_status()
 
-    print(payload)
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                print(res)
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except:
+                error_detail = f"External: {e}"
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+    idx = 0
+
+    body = await request.body()
 
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         r = await session.request(
             method=request.method,
             url=target_url,
-            data=payload if payload else body,
+            data=body,
             headers=headers,
         )