Timothy J. Baek 9 месяцев назад
Родитель
Сommit
b559bc84a7
3 измененных файлов с 175 добавлено и 102 удалено
  1. 65 39
      backend/apps/audio/main.py
  2. 56 0
      src/lib/apis/audio/index.ts
  3. 54 63
      src/lib/components/admin/Settings/Audio.svelte

+ 65 - 39
backend/apps/audio/main.py

@@ -10,12 +10,12 @@ from fastapi import (
     File,
     Form,
 )
-
 from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
+from typing import List
 import uuid
 import requests
 import hashlib
@@ -31,6 +31,7 @@ from utils.utils import (
 )
 from utils.misc import calculate_sha256
 
+
 from config import (
     SRC_LOG_LEVELS,
     CACHE_DIR,
@@ -134,35 +135,6 @@ def convert_mp4_to_wav(file_path, output_path):
     print(f"Converted {file_path} to {output_path}")
 
 
-async def get_available_voices():
-    if app.state.config.TTS_ENGINE != "elevenlabs":
-        return {}
-
-    base_url = "https://api.elevenlabs.io/v1"
-    headers = {
-        "xi-api-key": app.state.config.TTS_API_KEY,
-        "Content-Type": "application/json",
-    }
-
-    voices_url = f"{base_url}/voices"
-    try:
-        response = requests.get(voices_url, headers=headers)
-        response.raise_for_status()
-        voices_data = response.json()
-
-        voice_options = {}
-        for voice in voices_data.get("voices", []):
-            voice_name = voice["name"]
-            voice_id = voice["voice_id"]
-            voice_options[voice_name] = voice_id
-
-        return voice_options
-
-    except requests.RequestException as e:
-        log.error(f"Error fetching voices: {str(e)}")
-        return {}
-
-
 @app.get("/config")
 async def get_audio_config(user=Depends(get_admin_user)):
     return {
@@ -281,7 +253,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             )
 
     elif app.state.config.TTS_ENGINE == "elevenlabs":
-
         payload = None
         try:
             payload = json.loads(body.decode("utf-8"))
@@ -289,12 +260,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             log.exception(e)
             raise HTTPException(status_code=400, detail="Invalid JSON payload")
 
-        voice_options = await get_available_voices()
-        voice_id = voice_options.get(payload['voice'])
-
-        if not voice_id:
-            raise HTTPException(status_code=400, detail="Invalid voice name")
-
+        voice_id = payload.get("voice", "")
         url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
 
         headers = {
@@ -472,7 +438,67 @@ def transcribe(
         )
 
 
+def get_available_models() -> List[dict]:
+    if app.state.config.TTS_ENGINE == "openai":
+        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif app.state.config.TTS_ENGINE == "elevenlabs":
+        headers = {
+            "xi-api-key": app.state.config.TTS_API_KEY,
+            "Content-Type": "application/json",
+        }
+
+        try:
+            response = requests.get(
+                "https://api.elevenlabs.io/v1/models", headers=headers
+            )
+            response.raise_for_status()
+            models = response.json()
+            return [
+                {"name": model["name"], "id": model["model_id"]} for model in models
+            ]
+        except requests.RequestException as e:
+            log.error(f"Error fetching voices: {str(e)}")
+    return []
+
+
+@app.get("/models")
+async def get_models(user=Depends(get_verified_user)):
+    return {"models": get_available_models()}
+
+
+def get_available_voices() -> List[dict]:
+    if app.state.config.TTS_ENGINE == "openai":
+        return [
+            {"name": "alloy", "id": "alloy"},
+            {"name": "echo", "id": "echo"},
+            {"name": "fable", "id": "fable"},
+            {"name": "onyx", "id": "onyx"},
+            {"name": "nova", "id": "nova"},
+            {"name": "shimmer", "id": "shimmer"},
+        ]
+    elif app.state.config.TTS_ENGINE == "elevenlabs":
+        headers = {
+            "xi-api-key": app.state.config.TTS_API_KEY,
+            "Content-Type": "application/json",
+        }
+
+        try:
+            response = requests.get(
+                "https://api.elevenlabs.io/v1/voices", headers=headers
+            )
+            response.raise_for_status()
+            voices_data = response.json()
+
+            voices = []
+            for voice in voices_data.get("voices", []):
+                voices.append({"name": voice["name"], "id": voice["voice_id"]})
+            return voices
+        except requests.RequestException as e:
+            log.error(f"Error fetching voices: {str(e)}")
+
+    return []
+
+
 @app.get("/voices")
 async def get_voices(user=Depends(get_verified_user)):
-    voices = await get_available_voices()
-    return {"voices": list(voices.keys())}
+    return {"voices": get_available_voices()}

+ 56 - 0
src/lib/apis/audio/index.ts

@@ -131,3 +131,59 @@ export const synthesizeOpenAISpeech = async (
 
 	return res;
 };
+
+export const getModels = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${AUDIO_API_BASE_URL}/models`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getVoices = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${AUDIO_API_BASE_URL}/voices`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 54 - 63
src/lib/components/admin/Settings/Audio.svelte

@@ -1,12 +1,18 @@
 <script lang="ts">
-	import { getAudioConfig, updateAudioConfig } from '$lib/apis/audio';
-	import { user, settings, config } from '$lib/stores';
-	import { createEventDispatcher, onMount, getContext } from 'svelte';
 	import { toast } from 'svelte-sonner';
-	import Switch from '$lib/components/common/Switch.svelte';
+	import { createEventDispatcher, onMount, getContext } from 'svelte';
+	const dispatch = createEventDispatcher();
+
 	import { getBackendConfig } from '$lib/apis';
+	import {
+		getAudioConfig,
+		updateAudioConfig,
+		getModels as _getModels,
+		getVoices as _getVoices
+	} from '$lib/apis/audio';
+	import { user, settings, config } from '$lib/stores';
+
 	import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
-	const dispatch = createEventDispatcher();
 
 	const i18n = getContext('i18n');
 
@@ -30,49 +36,43 @@
 	let models = [];
 	let nonLocalVoices = false;
 
-	const getOpenAIVoices = () => {
-		voices = [
-			{ name: 'alloy' },
-			{ name: 'echo' },
-			{ name: 'fable' },
-			{ name: 'onyx' },
-			{ name: 'nova' },
-			{ name: 'shimmer' }
-		];
-	};
+	const getModels = async () => {
+		if (TTS_ENGINE === '') {
+			models = [];
+		} else {
+			const res = await _getModels(localStorage.token).catch((e) => {
+				toast.error(e);
+			});
 
-	const getOpenAIModels = () => {
-		models = [{ name: 'tts-1' }, { name: 'tts-1-hd' }];
+			if (res) {
+				console.log(res);
+				models = res.models;
+			}
+		}
 	};
 
-	const getWebAPIVoices = () => {
-		const getVoicesLoop = setInterval(async () => {
-			voices = await speechSynthesis.getVoices();
+	const getVoices = async () => {
+		if (TTS_ENGINE === '') {
+			const getVoicesLoop = setInterval(async () => {
+				voices = await speechSynthesis.getVoices();
 
-			// do your loop
-			if (voices.length > 0) {
-				clearInterval(getVoicesLoop);
+				// do your loop
+				if (voices.length > 0) {
+					clearInterval(getVoicesLoop);
+				}
+			}, 100);
+		} else {
+			const res = await _getVoices(localStorage.token).catch((e) => {
+				toast.error(e);
+			});
+
+			if (res) {
+				console.log(res);
+				voices = res.voices;
 			}
-		}, 100);
+		}
 	};
 
-    // Fetch available ElevenLabs voices
-    const getVoices = async () => {
-        const response = await fetch('/voices', {
-            method: 'GET',
-            headers: {
-                'Authorization': `Bearer ${localStorage.token}`
-            }
-        });
-
-        if (response.ok) {
-            const data = await response.json();
-            voices = data.voices.map(name => ({ name })); // Update voices array with fetched names
-        } else {
-            toast.error('Failed to fetch voices');
-        }
-    };
-
 	const updateConfigHandler = async () => {
 		const res = await updateAudioConfig(localStorage.token, {
 			tts: {
@@ -99,9 +99,6 @@
 	};
 
 	onMount(async () => {
-        // Fetch available voices on component mount
-        await getVoices(); 
-        
 		const res = await getAudioConfig(localStorage.token);
 
 		if (res) {
@@ -121,14 +118,8 @@
 			STT_MODEL = res.stt.MODEL;
 		}
 
-		if (TTS_ENGINE === 'openai') {
-			getOpenAIVoices();
-			getOpenAIModels();
-        } else if(TTS_ENGINE === 'elevenlabs') {
-            await getVoices(); //Get voices if TTS_ENGINE is ElevenLabs
-		} else {
-			getWebAPIVoices();
-		}
+		await getVoices();
+		await getModels();
 	});
 </script>
 
@@ -208,14 +199,14 @@
 							bind:value={TTS_ENGINE}
 							placeholder="Select a mode"
 							on:change={async (e) => {
+								await updateConfigHandler();
+								await getVoices();
+								await getModels();
+
 								if (e.target.value === 'openai') {
-									getOpenAIVoices();
 									TTS_VOICE = 'alloy';
 									TTS_MODEL = 'tts-1';
-								} else if(e.target.value === 'elevenlabs') {
-									await getVoices();
 								} else {
-									getWebAPIVoices();
 									TTS_VOICE = '';
 									TTS_MODEL = '';
 								}
@@ -256,7 +247,7 @@
 
 				<hr class=" dark:border-gray-850 my-2" />
 
-				{#if TTS_ENGINE !== ''}
+				{#if TTS_ENGINE === ''}
 					<div>
 						<div class=" mb-1.5 text-sm font-medium">{$i18n.t('TTS Voice')}</div>
 						<div class="flex w-full">
@@ -268,9 +259,9 @@
 									<option value="" selected={TTS_VOICE !== ''}>{$i18n.t('Default')}</option>
 									{#each voices as voice}
 										<option
-											value={voice.name}
+											value={voice.voiceURI}
 											class="bg-gray-100 dark:bg-gray-700"
-											selected={TTS_VOICE === voice.name}>{voice.name}</option
+											selected={TTS_VOICE === voice.voiceURI}>{voice.name}</option
 										>
 									{/each}
 								</select>
@@ -292,7 +283,7 @@
 
 									<datalist id="voice-list">
 										{#each voices as voice}
-											<option value={voice.name} />
+											<option value={voice.id}>{voice.name}</option>
 										{/each}
 									</datalist>
 								</div>
@@ -311,7 +302,7 @@
 
 									<datalist id="model-list">
 										{#each models as model}
-											<option value={model.name} />
+											<option value={model.id} />
 										{/each}
 									</datalist>
 								</div>
@@ -333,7 +324,7 @@
 
 									<datalist id="voice-list">
 										{#each voices as voice}
-											<option value={voice.name} />
+											<option value={voice.id}>{voice.name}</option>
 										{/each}
 									</datalist>
 								</div>
@@ -352,7 +343,7 @@
 
 									<datalist id="model-list">
 										{#each models as model}
-											<option value={model.name} />
+											<option value={model.id} />
 										{/each}
 									</datalist>
 								</div>