|
@@ -9,6 +9,7 @@ import json
|
|
|
import logging
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
+from starlette.background import BackgroundTask
|
|
|
|
|
|
from apps.webui.models.models import Models
|
|
|
from apps.webui.models.users import Users
|
|
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
|
|
|
return None
|
|
|
|
|
|
|
|
|
+async def cleanup_response(
|
|
|
+ response: Optional[aiohttp.ClientResponse],
|
|
|
+ session: Optional[aiohttp.ClientSession],
|
|
|
+):
|
|
|
+ if response:
|
|
|
+ response.close()
|
|
|
+ if session:
|
|
|
+ await session.close()
|
|
|
+
|
|
|
+
|
|
|
def merge_models_lists(model_lists):
|
|
|
log.debug(f"merge_models_lists {model_lists}")
|
|
|
merged_list = []
|
|
@@ -426,40 +437,45 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
|
|
headers["Content-Type"] = "application/json"
|
|
|
|
|
|
r = None
|
|
|
+ session = None
|
|
|
+ streaming = False
|
|
|
|
|
|
try:
|
|
|
- r = requests.request(
|
|
|
- method=request.method,
|
|
|
- url=target_url,
|
|
|
- data=payload if payload else body,
|
|
|
- headers=headers,
|
|
|
- stream=True,
|
|
|
+ session = aiohttp.ClientSession()
|
|
|
+ r = await session.request(
|
|
|
+ method=request.method, url=target_url, data=payload, headers=headers
|
|
|
)
|
|
|
|
|
|
r.raise_for_status()
|
|
|
|
|
|
# Check if response is SSE
|
|
|
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
|
|
+ streaming = True
|
|
|
return StreamingResponse(
|
|
|
- r.iter_content(chunk_size=8192),
|
|
|
- status_code=r.status_code,
|
|
|
+ r.content,
|
|
|
+ status_code=r.status,
|
|
|
headers=dict(r.headers),
|
|
|
+ background=BackgroundTask(
|
|
|
+ cleanup_response, response=r, session=session
|
|
|
+ ),
|
|
|
)
|
|
|
else:
|
|
|
- response_data = r.json()
|
|
|
+ response_data = await r.json()
|
|
|
return response_data
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
error_detail = "Open WebUI: Server Connection Error"
|
|
|
if r is not None:
|
|
|
try:
|
|
|
- res = r.json()
|
|
|
+ res = await r.json()
|
|
|
print(res)
|
|
|
if "error" in res:
|
|
|
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
|
|
except:
|
|
|
error_detail = f"External: {e}"
|
|
|
-
|
|
|
- raise HTTPException(
|
|
|
- status_code=r.status_code if r else 500, detail=error_detail
|
|
|
- )
|
|
|
+ raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
|
|
|
+ finally:
|
|
|
+ if not streaming and session:
|
|
|
+ if r:
|
|
|
+ r.close()
|
|
|
+ await session.close()
|