Browse Source

fix: prevent TTS blocking using aiohttp and aiofiles

houcheng 5 months ago
parent
commit
a83f89d430
2 changed files with 49 additions and 58 deletions
  1. 48 58
      backend/open_webui/apps/audio/main.py
  2. 1 0
      backend/requirements.txt

+ 48 - 58
backend/open_webui/apps/audio/main.py

@@ -8,6 +8,8 @@ from pathlib import Path
 from pydub import AudioSegment
 from pydub.silence import split_on_silence
 
+import aiohttp
+import aiofiles
 import requests
 from open_webui.config import (
     AUDIO_STT_ENGINE,
@@ -292,46 +294,39 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         except Exception:
             pass
 
-        r = None
         try:
-            r = requests.post(
-                url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-                data=body,
-                headers=headers,
-                stream=True,
-            )
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                    url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                    data=body,
+                    headers=headers
+                ) as r:
+                    r.raise_for_status()
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+                    
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
 
-            r.raise_for_status()
-
-            # Save the streaming content to a file
-            with open(file_path, "wb") as f:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-            with open(file_body_path, "w") as f:
-                json.dump(json.loads(body.decode("utf-8")), f)
-
-            # Return the saved file
             return FileResponse(file_path)
 
         except Exception as e:
             log.exception(e)
             error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
+            try:
+                if r.status != 200:
+                    res = await r.json()
                     if "error" in res:
                         error_detail = f"External: {res['error']['message']}"
-                except Exception:
-                    error_detail = f"External: {e}"
+            except Exception:
+                error_detail = f"External: {e}"
 
             raise HTTPException(
-                status_code=r.status_code if r != None else 500,
+                status_code=getattr(r, 'status', 500),
                 detail=error_detail,
             )
 
     elif app.state.config.TTS_ENGINE == "elevenlabs":
-        payload = None
         try:
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
@@ -339,7 +334,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             raise HTTPException(status_code=400, detail="Invalid JSON payload")
 
         voice_id = payload.get("voice", "")
-
         if voice_id not in get_available_voices():
             raise HTTPException(
                 status_code=400,
@@ -347,13 +341,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             )
 
         url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
-
         headers = {
             "Accept": "audio/mpeg",
             "Content-Type": "application/json",
             "xi-api-key": app.state.config.TTS_API_KEY,
         }
-
         data = {
             "text": payload["input"],
             "model_id": app.state.config.TTS_MODEL,
@@ -361,39 +353,34 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         }
 
         try:
-            r = requests.post(url, json=data, headers=headers)
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, json=data, headers=headers) as r:
+                    r.raise_for_status()
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+                    
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
 
-            r.raise_for_status()
-
-            # Save the streaming content to a file
-            with open(file_path, "wb") as f:
-                for chunk in r.iter_content(chunk_size=8192):
-                    f.write(chunk)
-
-            with open(file_body_path, "w") as f:
-                json.dump(json.loads(body.decode("utf-8")), f)
-
-            # Return the saved file
             return FileResponse(file_path)
 
         except Exception as e:
             log.exception(e)
             error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
+            try:
+                if r.status != 200:
+                    res = await r.json()
                     if "error" in res:
                         error_detail = f"External: {res['error']['message']}"
-                except Exception:
-                    error_detail = f"External: {e}"
+            except Exception:
+                error_detail = f"External: {e}"
 
             raise HTTPException(
-                status_code=r.status_code if r != None else 500,
+                status_code=getattr(r, 'status', 500),
                 detail=error_detail,
             )
 
     elif app.state.config.TTS_ENGINE == "azure":
-        payload = None
         try:
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
@@ -416,17 +403,20 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 <voice name="{language}">{payload["input"]}</voice>
             </speak>"""
 
-        response = requests.post(url, headers=headers, data=data)
-
-        if response.status_code == 200:
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-            return FileResponse(file_path)
-        else:
-            log.error(f"Error synthesizing speech - {response.reason}")
-            raise HTTPException(
-                status_code=500, detail=f"Error synthesizing speech - {response.reason}"
-            )
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url, headers=headers, data=data) as response:
+                    if response.status == 200:
+                        async with aiofiles.open(file_path, "wb") as f:
+                            await f.write(await response.read())
+                        return FileResponse(file_path)
+                    else:
+                        error_msg = f"Error synthesizing speech - {response.reason}"
+                        log.error(error_msg)
+                        raise HTTPException(status_code=500, detail=error_msg)
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
     elif app.state.config.TTS_ENGINE == "transformers":
         payload = None
         try:

+ 1 - 0
backend/requirements.txt

@@ -14,6 +14,7 @@ requests==2.32.3
 aiohttp==3.10.8
 async-timeout
 aiocache
+aiofiles
 
 sqlalchemy==2.0.32
 alembic==1.13.2