瀏覽代碼

fix: openai streaming cancellation using aiohttp

Jun Siang Cheah 11 月之前
父節點
當前提交
7f74426a22
共有 2 個文件被更改,包括 31 次插入和 15 次删除
  1. 1 1
      backend/apps/ollama/main.py
  2. 30 14
      backend/apps/openai/main.py

+ 1 - 1
backend/apps/ollama/main.py

@@ -153,7 +153,7 @@ async def cleanup_response(
         await session.close()
 
 
-async def post_streaming_url(url, payload):
+async def post_streaming_url(url: str, payload: str):
     r = None
     try:
         session = aiohttp.ClientSession()

+ 30 - 14
backend/apps/openai/main.py

@@ -9,6 +9,7 @@ import json
 import logging
 
 from pydantic import BaseModel
+from starlette.background import BackgroundTask
 
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
         return None
 
 
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
 def merge_models_lists(model_lists):
     log.debug(f"merge_models_lists {model_lists}")
     merged_list = []
@@ -426,40 +437,45 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     headers["Content-Type"] = "application/json"
 
     r = None
+    session = None
+    streaming = False
 
     try:
-        r = requests.request(
-            method=request.method,
-            url=target_url,
-            data=payload if payload else body,
-            headers=headers,
-            stream=True,
+        session = aiohttp.ClientSession()
+        r = await session.request(
+            method=request.method, url=target_url, data=payload, headers=headers
         )
 
         r.raise_for_status()
 
         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
             return StreamingResponse(
-                r.iter_content(chunk_size=8192),
-                status_code=r.status_code,
+                r.content,
+                status_code=r.status,
                 headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
             )
         else:
-            response_data = r.json()
+            response_data = await r.json()
             return response_data
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
-                res = r.json()
+                res = await r.json()
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except:
                 error_detail = f"External: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500, detail=error_detail
-        )
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()