소스 검색

Merge pull request #4674 from crizCraig/sanitize-11labs-voiceid

sec: Sanitize 11labs voice id to address semgrep security issue: tainted-path-traversal-stdlib-fastapi
Timothy Jaeryang Baek 8 달 전
부모
커밋
bd8df3583d
2개의 파일 변경됨: 72줄 추가, 51줄 삭제
  1. 70 49
      backend/apps/audio/main.py
  2. 2 2
      backend/config.py

+ 70 - 49
backend/apps/audio/main.py

@@ -1,5 +1,12 @@
-import os
+import hashlib
+import json
 import logging
+import os
+import uuid
+from functools import lru_cache
+from pathlib import Path
+
+import requests
 from fastapi import (
     FastAPI,
     Request,
@@ -8,34 +15,14 @@ from fastapi import (
     status,
     UploadFile,
     File,
-    Form,
 )
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
-
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
 from pydantic import BaseModel
 
-
-import uuid
-import requests
-import hashlib
-from pathlib import Path
-import json
-
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    decode_token,
-    get_current_user,
-    get_verified_user,
-    get_admin_user,
-)
-from utils.misc import calculate_sha256
-
-
 from config import (
     SRC_LOG_LEVELS,
     CACHE_DIR,
-    UPLOAD_DIR,
     WHISPER_MODEL,
     WHISPER_MODEL_DIR,
     WHISPER_MODEL_AUTO_UPDATE,
@@ -52,6 +39,12 @@ from config import (
     AUDIO_TTS_VOICE,
     AppConfig,
 )
+from constants import ERROR_MESSAGES
+from utils.utils import (
+    get_current_user,
+    get_verified_user,
+    get_admin_user,
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -261,6 +254,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             raise HTTPException(status_code=400, detail="Invalid JSON payload")
 
         voice_id = payload.get("voice", "")
+
+        if voice_id not in get_available_voices():
+            raise HTTPException(
+                status_code=400,
+                detail="Invalid voice id",
+            )
+
         url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
 
         headers = {
@@ -466,39 +466,60 @@ async def get_models(user=Depends(get_verified_user)):
     return {"models": get_available_models()}
 
 
-def get_available_voices() -> list[dict]:
+def get_available_voices() -> dict:
+    """Returns {voice_id: voice_name} dict"""
+    ret = {}
     if app.state.config.TTS_ENGINE == "openai":
-        return [
-            {"name": "alloy", "id": "alloy"},
-            {"name": "echo", "id": "echo"},
-            {"name": "fable", "id": "fable"},
-            {"name": "onyx", "id": "onyx"},
-            {"name": "nova", "id": "nova"},
-            {"name": "shimmer", "id": "shimmer"},
-        ]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
+        ret = {
+            "alloy": "alloy",
+            "echo": "echo",
+            "fable": "fable",
+            "onyx": "onyx",
+            "nova": "nova",
+            "shimmer": "shimmer",
         }
-
+    elif app.state.config.TTS_ENGINE == "elevenlabs":
         try:
-            response = requests.get(
-                "https://api.elevenlabs.io/v1/voices", headers=headers
-            )
-            response.raise_for_status()
-            voices_data = response.json()
+            ret = get_elevenlabs_voices()
+        except Exception as e:
+            # Avoided @lru_cache with exception
+            pass
 
-            voices = []
-            for voice in voices_data.get("voices", []):
-                voices.append({"name": voice["name"], "id": voice["voice_id"]})
-            return voices
-        except requests.RequestException as e:
-            log.error(f"Error fetching voices: {str(e)}")
+    return ret
+
+
+@lru_cache
+def get_elevenlabs_voices() -> dict:
+    """
+    Note, set the following in your .env file to use Elevenlabs:
+    AUDIO_TTS_ENGINE=elevenlabs
+    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
+    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
+    AUDIO_TTS_MODEL=eleven_multilingual_v2
+    """
+    headers = {
+        "xi-api-key": app.state.config.TTS_API_KEY,
+        "Content-Type": "application/json",
+    }
+    try:
+        # TODO: Add retries
+        response = requests.get(
+            "https://api.elevenlabs.io/v1/voices", headers=headers
+        )
+        response.raise_for_status()
+        voices_data = response.json()
 
-    return []
+        voices = {}
+        for voice in voices_data.get("voices", []):
+            voices[voice["voice_id"]] = voice["name"]
+    except requests.RequestException as e:
+        # Avoid @lru_cache with exception
+        log.error(f"Error fetching voices: {str(e)}")
+        raise RuntimeError(f"Error fetching voices: {str(e)}")
+
+    return voices
 
 
 @app.get("/voices")
 async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": get_available_voices()}
+    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}

+ 2 - 2
backend/config.py

@@ -1410,13 +1410,13 @@ AUDIO_TTS_ENGINE = PersistentConfig(
 AUDIO_TTS_MODEL = PersistentConfig(
     "AUDIO_TTS_MODEL",
     "audio.tts.model",
-    os.getenv("AUDIO_TTS_MODEL", "tts-1"),
+    os.getenv("AUDIO_TTS_MODEL", "tts-1"),  # OpenAI default model
 )
 
 AUDIO_TTS_VOICE = PersistentConfig(
     "AUDIO_TTS_VOICE",
     "audio.tts.voice",
-    os.getenv("AUDIO_TTS_VOICE", "alloy"),
+    os.getenv("AUDIO_TTS_VOICE", "alloy"),  # OpenAI default voice
 )