@@ -11,25 +11,27 @@ from pydub.silence import split_on_silence
 import aiohttp
 import aiofiles
 import requests
+
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+    APIRouter,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import (
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
-    AUDIO_STT_OPENAI_API_BASE_URL,
-    AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_API_KEY,
-    AUDIO_TTS_ENGINE,
-    AUDIO_TTS_MODEL,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
-    AUDIO_TTS_SPLIT_ON,
-    AUDIO_TTS_VOICE,
-    AUDIO_TTS_AZURE_SPEECH_REGION,
-    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
-    WHISPER_MODEL,
     WHISPER_MODEL_AUTO_UPDATE,
     WHISPER_MODEL_DIR,
    CACHE_DIR,
-    AppConfig,
 )

 from open_webui.constants import ERROR_MESSAGES
@@ -40,52 +42,75 @@ from open_webui.env import (
     ENABLE_FORWARD_USER_INFO_HEADERS,
 )

-from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-from open_webui.utils.auth import get_admin_user, get_verified_user
+
+router = APIRouter()

 # Constants
 MAX_FILE_SIZE_MB = 25
 MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes

-
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])

-
-# setting device type for whisper model
-whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
-log.info(f"whisper_device_type: {whisper_device_type}")
-
 SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


+##########################################
+#
+# Utility functions
+#
+##########################################
+
+from pydub import AudioSegment
+from pydub.utils import mediainfo
+
+
+def is_mp4_audio(file_path):
+    """Check if the given file is an MP4 audio file."""
+    if not os.path.isfile(file_path):
+        print(f"File not found: {file_path}")
+        return False
+
+    info = mediainfo(file_path)
+    if (
+        info.get("codec_name") == "aac"
+        and info.get("codec_type") == "audio"
+        and info.get("codec_tag_string") == "mp4a"
+    ):
+        return True
+    return False
+
+
+def convert_mp4_to_wav(file_path, output_path):
+    """Convert MP4 audio file to WAV format."""
+    audio = AudioSegment.from_file(file_path, format="mp4")
+    audio.export(output_path, format="wav")
+    print(f"Converted {file_path} to {output_path}")
+
+
 def set_faster_whisper_model(model: str, auto_update: bool = False):
-    if model and app.state.config.STT_ENGINE == "":
+    whisper_model = None
+    if model:
         from faster_whisper import WhisperModel

         faster_whisper_kwargs = {
             "model_size_or_path": model,
-            "device": whisper_device_type,
+            "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
             "compute_type": "int8",
             "download_root": WHISPER_MODEL_DIR,
             "local_files_only": not auto_update,
         }

         try:
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
         except Exception:
             log.warning(
                 "WhisperModel initialization failed, attempting download with local_files_only=False"
             )
             faster_whisper_kwargs["local_files_only"] = False
-            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
-
-    else:
-        app.state.faster_whisper_model = None
+            whisper_model = WhisperModel(**faster_whisper_kwargs)
+    return whisper_model


 class TTSConfigForm(BaseModel):
@@ -113,98 +138,75 @@ class AudioConfigUpdateForm(BaseModel):
     stt: STTConfigForm


-from pydub import AudioSegment
-from pydub.utils import mediainfo
-
-
-def is_mp4_audio(file_path):
-    """Check if the given file is an MP4 audio file."""
-    if not os.path.isfile(file_path):
-        print(f"File not found: {file_path}")
-        return False
-
-    info = mediainfo(file_path)
-    if (
-        info.get("codec_name") == "aac"
-        and info.get("codec_type") == "audio"
-        and info.get("codec_tag_string") == "mp4a"
-    ):
-        return True
-    return False
-
-
-def convert_mp4_to_wav(file_path, output_path):
-    """Convert MP4 audio file to WAV format."""
-    audio = AudioSegment.from_file(file_path, format="mp4")
-    audio.export(output_path, format="wav")
-    print(f"Converted {file_path} to {output_path}")
-
-
-@app.get("/config")
-async def get_audio_config(user=Depends(get_admin_user)):
+@router.get("/config")
+async def get_audio_config(request: Request, user=Depends(get_admin_user)):
     return {
         "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
         },
         "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
         },
     }


-@app.post("/config/update")
+@router.post("/config/update")
 async def update_audio_config(
-    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
+    request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
 ):
-    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
-    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
-    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
-    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
-    app.state.config.TTS_MODEL = form_data.tts.MODEL
-    app.state.config.TTS_VOICE = form_data.tts.VOICE
-    app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
-    app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
-    app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
+    request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
+    request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
+    request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
+    request.app.state.config.TTS_MODEL = form_data.tts.MODEL
+    request.app.state.config.TTS_VOICE = form_data.tts.VOICE
+    request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
+    request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
+    request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
         form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
     )

-    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
-    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
-    app.state.config.STT_ENGINE = form_data.stt.ENGINE
-    app.state.config.STT_MODEL = form_data.stt.MODEL
-    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
-    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
+    request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
+    request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
+    request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
+    request.app.state.config.STT_MODEL = form_data.stt.MODEL
+    request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+
+    if request.app.state.config.STT_ENGINE == "":
+        request.app.state.faster_whisper_model = set_faster_whisper_model(
+            form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
+        )

     return {
         "tts": {
-            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
-            "API_KEY": app.state.config.TTS_API_KEY,
-            "ENGINE": app.state.config.TTS_ENGINE,
-            "MODEL": app.state.config.TTS_MODEL,
-            "VOICE": app.state.config.TTS_VOICE,
-            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
-            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
-            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
+            "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
+            "API_KEY": request.app.state.config.TTS_API_KEY,
+            "ENGINE": request.app.state.config.TTS_ENGINE,
+            "MODEL": request.app.state.config.TTS_MODEL,
+            "VOICE": request.app.state.config.TTS_VOICE,
+            "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
+            "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
+            "AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
         },
         "stt": {
-            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
-            "ENGINE": app.state.config.STT_ENGINE,
-            "MODEL": app.state.config.STT_MODEL,
-            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
+            "OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": request.app.state.config.STT_ENGINE,
+            "MODEL": request.app.state.config.STT_MODEL,
+            "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
         },
     }

@@ -212,19 +214,19 @@
-def load_speech_pipeline():
+def load_speech_pipeline(request):
     from transformers import pipeline
     from datasets import load_dataset

-    if app.state.speech_synthesiser is None:
-        app.state.speech_synthesiser = pipeline(
+    if request.app.state.speech_synthesiser is None:
+        request.app.state.speech_synthesiser = pipeline(
             "text-to-speech", "microsoft/speecht5_tts"
         )

-    if app.state.speech_speaker_embeddings_dataset is None:
-        app.state.speech_speaker_embeddings_dataset = load_dataset(
+    if request.app.state.speech_speaker_embeddings_dataset is None:
+        request.app.state.speech_speaker_embeddings_dataset = load_dataset(
             "Matthijs/cmu-arctic-xvectors", split="validation"
         )


-@app.post("/speech")
+@router.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     name = hashlib.sha256(body).hexdigest()
@@ -236,9 +238,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     if file_path.is_file():
         return FileResponse(file_path)

-    if app.state.config.TTS_ENGINE == "openai":
+    if request.app.state.config.TTS_ENGINE == "openai":
         headers = {}
-        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
+        headers["Authorization"] = (
+            f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}"
+        )
         headers["Content-Type"] = "application/json"

         if ENABLE_FORWARD_USER_INFO_HEADERS:
@@ -250,7 +254,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         try:
             body = body.decode("utf-8")
             body = json.loads(body)
-            body["model"] = app.state.config.TTS_MODEL
+            body["model"] = request.app.state.config.TTS_MODEL
             body = json.dumps(body).encode("utf-8")
         except Exception:
             pass
@@ -258,7 +262,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         try:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
-                    url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
+                    url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                     data=body,
                     headers=headers,
                 ) as r:
@@ -287,7 +291,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 detail=error_detail,
             )

-    elif app.state.config.TTS_ENGINE == "elevenlabs":
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
@@ -305,11 +309,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         headers = {
             "Accept": "audio/mpeg",
             "Content-Type": "application/json",
-            "xi-api-key": app.state.config.TTS_API_KEY,
+            "xi-api-key": request.app.state.config.TTS_API_KEY,
         }
         data = {
             "text": payload["input"],
-            "model_id": app.state.config.TTS_MODEL,
+            "model_id": request.app.state.config.TTS_MODEL,
             "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
         }

@@ -341,21 +345,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                 detail=error_detail,
             )

-    elif app.state.config.TTS_ENGINE == "azure":
+    elif request.app.state.config.TTS_ENGINE == "azure":
         try:
             payload = json.loads(body.decode("utf-8"))
         except Exception as e:
             log.exception(e)
             raise HTTPException(status_code=400, detail="Invalid JSON payload")

-        region = app.state.config.TTS_AZURE_SPEECH_REGION
-        language = app.state.config.TTS_VOICE
-        locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
-        output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
+        region = request.app.state.config.TTS_AZURE_SPEECH_REGION
+        language = request.app.state.config.TTS_VOICE
+        locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
+        output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
         url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"

         headers = {
-            "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
+            "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
             "Content-Type": "application/ssml+xml",
             "X-Microsoft-OutputFormat": output_format,
         }
@@ -378,7 +382,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         except Exception as e:
             log.exception(e)
             raise HTTPException(status_code=500, detail=str(e))
-    elif app.state.config.TTS_ENGINE == "transformers":
+    elif request.app.state.config.TTS_ENGINE == "transformers":
         payload = None
         try:
             payload = json.loads(body.decode("utf-8"))
@@ -391,12 +395,12 @@ async def speech(request: Request, user=Depends(get_verified_user)):

-        load_speech_pipeline()
+        load_speech_pipeline(request)

-        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
+        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

         speaker_index = 6799
         try:
             speaker_index = embeddings_dataset["filename"].index(
-                app.state.config.TTS_MODEL
+                request.app.state.config.TTS_MODEL
             )
         except Exception:
             pass
@@ -405,7 +409,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             embeddings_dataset[speaker_index]["xvector"]
         ).unsqueeze(0)

-        speech = app.state.speech_synthesiser(
+        speech = request.app.state.speech_synthesiser(
             payload["input"],
             forward_params={"speaker_embeddings": speaker_embedding},
         )
@@ -417,17 +421,19 @@ async def speech(request: Request, user=Depends(get_verified_user)):
     return FileResponse(file_path)


-def transcribe(file_path):
+def transcribe(request: Request, file_path):
     print("transcribe", file_path)
     filename = os.path.basename(file_path)
     file_dir = os.path.dirname(file_path)
     id = filename.split(".")[0]

-    if app.state.config.STT_ENGINE == "":
-        if app.state.faster_whisper_model is None:
-            set_faster_whisper_model(app.state.config.WHISPER_MODEL)
+    if request.app.state.config.STT_ENGINE == "":
+        if request.app.state.faster_whisper_model is None:
+            request.app.state.faster_whisper_model = set_faster_whisper_model(
+                request.app.state.config.WHISPER_MODEL
+            )

-        model = app.state.faster_whisper_model
+        model = request.app.state.faster_whisper_model
         segments, info = model.transcribe(file_path, beam_size=5)
         log.info(
             "Detected language '%s' with probability %f"
@@ -444,31 +450,24 @@ def transcribe(file_path):

         log.debug(data)
         return data
-    elif app.state.config.STT_ENGINE == "openai":
+    elif request.app.state.config.STT_ENGINE == "openai":
         if is_mp4_audio(file_path):
-            print("is_mp4_audio")
             os.rename(file_path, file_path.replace(".wav", ".mp4"))
             # Convert MP4 audio file to WAV format
             convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)

-        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
-
-        files = {"file": (filename, open(file_path, "rb"))}
-        data = {"model": app.state.config.STT_MODEL}
-
-        log.debug(files, data)
-
         r = None
         try:
             r = requests.post(
-                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                headers=headers,
-                files=files,
-                data=data,
+                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers={
+                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
+                },
+                files={"file": (filename, open(file_path, "rb"))},
+                data={"model": request.app.state.config.STT_MODEL},
             )

             r.raise_for_status()
-
             data = r.json()

             # save the transcript to a json file
@@ -476,24 +475,43 @@ def transcribe(file_path):
             with open(transcript_file, "w") as f:
                 json.dump(data, f)

-            print(data)
             return data
         except Exception as e:
             log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
+
+            detail = None
             if r is not None:
                 try:
                     res = r.json()
                     if "error" in res:
-                        error_detail = f"External: {res['error']['message']}"
+                        detail = f"External: {res['error'].get('message', '')}"
                 except Exception:
-                    error_detail = f"External: {e}"
+                    detail = f"External: {e}"
+
+            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

-            raise Exception(error_detail)

+def compress_audio(file_path):
+    if os.path.getsize(file_path) > MAX_FILE_SIZE:
+        id = os.path.splitext(file_path)[0]
+        audio = AudioSegment.from_file(file_path)
+        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+        compressed_path = f"{id}_compressed.opus"
+        audio.export(compressed_path, format="opus", bitrate="32k")
+        log.debug(f"Compressed audio to {compressed_path}")

-@app.post("/transcriptions")
+        if (
+            os.path.getsize(compressed_path) > MAX_FILE_SIZE
+        ):  # Still larger than MAX_FILE_SIZE after compression
+            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
+        return compressed_path
+    else:
+        return file_path
+
+
+@router.post("/transcriptions")
 def transcription(
+    request: Request,
     file: UploadFile = File(...),
     user=Depends(get_verified_user),
 ):
@@ -520,36 +538,22 @@ def transcription(
             f.write(contents)

         try:
-            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
-                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
-                audio = AudioSegment.from_file(file_path)
-                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
-                compressed_path = f"{file_dir}/{id}_compressed.opus"
-                audio.export(compressed_path, format="opus", bitrate="32k")
-                log.debug(f"Compressed audio to {compressed_path}")
-                file_path = compressed_path
-
-                if (
-                    os.path.getsize(file_path) > MAX_FILE_SIZE
-                ):  # Still larger than 25MB after compression
-                    log.debug(
-                        f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
-                    )
-                    raise HTTPException(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        detail=ERROR_MESSAGES.FILE_TOO_LARGE(
-                            size=f"{MAX_FILE_SIZE_MB}MB"
-                        ),
-                    )
-
-                data = transcribe(file_path)
-            else:
-                data = transcribe(file_path)
+            try:
+                file_path = compress_audio(file_path)
+            except Exception as e:
+                log.exception(e)
+
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT(e),
+                )

+            data = transcribe(request, file_path)
             file_path = file_path.split("/")[-1]
             return {**data, "filename": file_path}
         except Exception as e:
             log.exception(e)
+
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 detail=ERROR_MESSAGES.DEFAULT(e),
@@ -564,39 +568,41 @@ def transcription(
         )


-def get_available_models() -> list[dict]:
-    if app.state.config.TTS_ENGINE == "openai":
-        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
-        headers = {
-            "xi-api-key": app.state.config.TTS_API_KEY,
-            "Content-Type": "application/json",
-        }
-
+def get_available_models(request: Request) -> list[dict]:
+    available_models = []
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             response = requests.get(
-                "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
+                "https://api.elevenlabs.io/v1/models",
+                headers={
+                    "xi-api-key": request.app.state.config.TTS_API_KEY,
+                    "Content-Type": "application/json",
+                },
+                timeout=5,
             )
             response.raise_for_status()
             models = response.json()
-            return [
+
+            available_models = [
                 {"name": model["name"], "id": model["model_id"]} for model in models
             ]
         except requests.RequestException as e:
             log.error(f"Error fetching voices: {str(e)}")
-    return []
+    return available_models


-@app.get("/models")
-async def get_models(user=Depends(get_verified_user)):
-    return {"models": get_available_models()}
+@router.get("/models")
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    return {"models": get_available_models(request)}


-def get_available_voices() -> dict:
+def get_available_voices(request) -> dict:
     """Returns {voice_id: voice_name} dict"""
-    ret = {}
-    if app.state.config.TTS_ENGINE == "openai":
-        ret = {
+    available_voices = {}
+    if request.app.state.config.TTS_ENGINE == "openai":
+        available_voices = {
             "alloy": "alloy",
             "echo": "echo",
             "fable": "fable",
@@ -604,33 +610,38 @@ def get_available_voices() -> dict:
             "nova": "nova",
             "shimmer": "shimmer",
         }
-    elif app.state.config.TTS_ENGINE == "elevenlabs":
+    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
         try:
-            ret = get_elevenlabs_voices()
+            available_voices = get_elevenlabs_voices(
+                api_key=request.app.state.config.TTS_API_KEY
+            )
         except Exception:
             # Avoided @lru_cache with exception
             pass
-    elif app.state.config.TTS_ENGINE == "azure":
+    elif request.app.state.config.TTS_ENGINE == "azure":
         try:
-            region = app.state.config.TTS_AZURE_SPEECH_REGION
+            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
             url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
-            headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}
+            headers = {
+                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
+            }

             response = requests.get(url, headers=headers)
             response.raise_for_status()
             voices = response.json()
+
             for voice in voices:
-                ret[voice["ShortName"]] = (
+                available_voices[voice["ShortName"]] = (
                     f"{voice['DisplayName']} ({voice['ShortName']})"
                 )
         except requests.RequestException as e:
             log.error(f"Error fetching voices: {str(e)}")

-    return ret
+    return available_voices


 @lru_cache
-def get_elevenlabs_voices() -> dict:
+def get_elevenlabs_voices(api_key: str) -> dict:
     """
     Note, set the following in your .env file to use Elevenlabs:
     AUDIO_TTS_ENGINE=elevenlabs
@@ -638,13 +649,16 @@ def get_elevenlabs_voices() -> dict:
     AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
     AUDIO_TTS_MODEL=eleven_multilingual_v2
     """
-    headers = {
-        "xi-api-key": app.state.config.TTS_API_KEY,
-        "Content-Type": "application/json",
-    }
+
     try:
         # TODO: Add retries
-        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
+        response = requests.get(
+            "https://api.elevenlabs.io/v1/voices",
+            headers={
+                "xi-api-key": api_key,
+                "Content-Type": "application/json",
+            },
+        )
         response.raise_for_status()
         voices_data = response.json()

@@ -659,6 +673,10 @@ def get_elevenlabs_voices() -> dict:
     return voices


-@app.get("/voices")
-async def get_voices(user=Depends(get_verified_user)):
-    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
+@router.get("/voices")
+async def get_voices(request: Request, user=Depends(get_verified_user)):
+    return {
+        "voices": [
+            {"id": k, "name": v} for k, v in get_available_voices(request).items()
+        ]
+    }
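
Since every endpoint above now hangs off `router` instead of a module-level `app`, the patch assumes the main application mounts this router and owns the shared state the handlers read through `request.app.state`. The following is a minimal sketch of that wiring; the module path, URL prefix, and config setup are assumptions and are not part of this diff.

# Sketch only: assumed wiring in the main application.
from fastapi import FastAPI

from open_webui.routers import audio  # assumed location of the refactored module

app = FastAPI()

# The audio handlers read persistent settings via request.app.state.config and
# lazily-built objects via request.app.state.*, so the main app must own them.
app.state.config = ...  # config object exposing the AUDIO_TTS_* / AUDIO_STT_* values (omitted here)
app.state.faster_whisper_model = None
app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None

app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])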