Browse Source

Update audio.py

Timothy Jaeryang Baek 4 months ago
parent
commit
87d695caad
1 changed files with 89 additions and 68 deletions
  1. 89 68
      backend/open_webui/routers/audio.py

+ 89 - 68
backend/open_webui/routers/audio.py

@@ -113,6 +113,13 @@ def set_faster_whisper_model(model: str, auto_update: bool = False):
     return whisper_model
 
 
+##########################################
+#
+# Audio API
+#
+##########################################
+
+
 class TTSConfigForm(BaseModel):
     OPENAI_API_BASE_URL: str
     OPENAI_API_KEY: str
@@ -238,35 +245,38 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     if file_path.is_file():
         return FileResponse(file_path)
 
-    if request.app.state.config.TTS_ENGINE == "openai":
-        headers = {}
-        headers["Authorization"] = (
-            f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}"
-        )
-        headers["Content-Type"] = "application/json"
-
-        if ENABLE_FORWARD_USER_INFO_HEADERS:
-            headers["X-OpenWebUI-User-Name"] = user.name
-            headers["X-OpenWebUI-User-Id"] = user.id
-            headers["X-OpenWebUI-User-Email"] = user.email
-            headers["X-OpenWebUI-User-Role"] = user.role
+    payload = None
+    try:
+        payload = json.loads(body.decode("utf-8"))
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(status_code=400, detail="Invalid JSON payload")
 
-        try:
-            body = body.decode("utf-8")
-            body = json.loads(body)
-            body["model"] = request.app.state.config.TTS_MODEL
-            body = json.dumps(body).encode("utf-8")
-        except Exception:
-            pass
+    if request.app.state.config.TTS_ENGINE == "openai":
+        payload["model"] = request.app.state.config.TTS_MODEL
 
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
                     url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-                    data=body,
-                    headers=headers,
+                    data=payload,
+                    headers={
+                        "Content-Type": "application/json",
+                        "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
+                        **(
+                            {
+                                "X-OpenWebUI-User-Name": user.name,
+                                "X-OpenWebUI-User-Id": user.id,
+                                "X-OpenWebUI-User-Email": user.email,
+                                "X-OpenWebUI-User-Role": user.role,
+                            }
+                            if ENABLE_FORWARD_USER_INFO_HEADERS
+                            else {}
+                        ),
+                    },
                 ) as r:
                     r.raise_for_status()
+
                     async with aiofiles.open(file_path, "wb") as f:
                         await f.write(await r.read())
 
@@ -277,50 +287,47 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 
         except Exception as e:
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+            detail = None
+
             try:
                 if r.status != 200:
                     res = await r.json()
                     if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
+                        detail = f"External: {res['error'].get('message', '')}"
             except Exception:
-                error_detail = f"External: {e}"
+                detail = f"External: {e}"
 
             raise HTTPException(
                 status_code=getattr(r, "status", 500),
-                detail=error_detail,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
             )
 
     elif request.app.state.config.TTS_ENGINE == "elevenlabs":
-        try:
-            payload = json.loads(body.decode("utf-8"))
-        except Exception as e:
-            log.exception(e)
-            raise HTTPException(status_code=400, detail="Invalid JSON payload")
-
         voice_id = payload.get("voice", "")
+
         if voice_id not in get_available_voices():
             raise HTTPException(
                 status_code=400,
                 detail="Invalid voice id",
             )
 
-        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
-        headers = {
-            "Accept": "audio/mpeg",
-            "Content-Type": "application/json",
-            "xi-api-key": request.app.state.config.TTS_API_KEY,
-        }
-        data = {
-            "text": payload["input"],
-            "model_id": request.app.state.config.TTS_MODEL,
-            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
-        }
-
         try:
             async with aiohttp.ClientSession() as session:
-                async with session.post(url, json=data, headers=headers) as r:
+                async with session.post(
+                    f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
+                    json={
+                        "text": payload["input"],
+                        "model_id": request.app.state.config.TTS_MODEL,
+                        "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
+                    },
+                    headers={
+                        "Accept": "audio/mpeg",
+                        "Content-Type": "application/json",
+                        "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    },
+                ) as r:
                     r.raise_for_status()
+
                     async with aiofiles.open(file_path, "wb") as f:
                         await f.write(await r.read())
 
@@ -331,18 +338,19 @@ async def speech(request: Request, user=Depends(get_verified_user)):
 
         except Exception as e:
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+            detail = None
+
             try:
                 if r.status != 200:
                     res = await r.json()
                     if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
+                        detail = f"External: {res['error'].get('message', '')}"
             except Exception:
-                error_detail = f"External: {e}"
+                detail = f"External: {e}"
 
             raise HTTPException(
                 status_code=getattr(r, "status", 500),
-                detail=error_detail,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
             )
 
     elif request.app.state.config.TTS_ENGINE == "azure":
@@ -356,32 +364,45 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         language = request.app.state.config.TTS_VOICE
         locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
         output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
-        url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
-
-        headers = {
-            "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
-            "Content-Type": "application/ssml+xml",
-            "X-Microsoft-OutputFormat": output_format,
-        }
 
-        data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
+        try:
+            data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
                 <voice name="{language}">{payload["input"]}</voice>
             </speak>"""
-
-        try:
             async with aiohttp.ClientSession() as session:
-                async with session.post(url, headers=headers, data=data) as response:
-                    if response.status == 200:
-                        async with aiofiles.open(file_path, "wb") as f:
-                            await f.write(await response.read())
-                        return FileResponse(file_path)
-                    else:
-                        error_msg = f"Error synthesizing speech - {response.reason}"
-                        log.error(error_msg)
-                        raise HTTPException(status_code=500, detail=error_msg)
+                async with session.post(
+                    f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
+                    headers={
+                        "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
+                        "Content-Type": "application/ssml+xml",
+                        "X-Microsoft-OutputFormat": output_format,
+                    },
+                    data=data,
+                ) as r:
+                    r.raise_for_status()
+
+                    async with aiofiles.open(file_path, "wb") as f:
+                        await f.write(await r.read())
+
+                    return FileResponse(file_path)
+
         except Exception as e:
             log.exception(e)
-            raise HTTPException(status_code=500, detail=str(e))
+            detail = None
+
+            try:
+                if r.status != 200:
+                    res = await r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=getattr(r, "status", 500),
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
     elif request.app.state.config.TTS_ENGINE == "transformers":
         payload = None
         try: