Kaynağa Gözat

feat: native speecht5 support

Timothy J. Baek 5 ay önce
ebeveyn
işleme
1fd67d7e5d

+ 56 - 0
backend/open_webui/apps/audio/main.py

@@ -74,6 +74,10 @@ app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
 app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
 app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON
 
+
+app.state.speech_synthesiser = None
+app.state.speech_speaker_embeddings_dataset = None
+
 app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
 app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT
 
@@ -231,6 +235,21 @@ async def update_audio_config(
     }
 
 
+def load_speech_pipeline():
+    from transformers import pipeline
+    from datasets import load_dataset
+
+    if app.state.speech_synthesiser is None:
+        app.state.speech_synthesiser = pipeline(
+            "text-to-speech", "microsoft/speecht5_tts"
+        )
+
+    if app.state.speech_speaker_embeddings_dataset is None:
+        app.state.speech_speaker_embeddings_dataset = load_dataset(
+            "Matthijs/cmu-arctic-xvectors", split="validation"
+        )
+
+
 @app.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
@@ -397,6 +416,43 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             raise HTTPException(
                 status_code=500, detail=f"Error synthesizing speech - {response.reason}"
             )
+    elif app.state.config.TTS_ENGINE == "transformers":
+        payload = None
+        try:
+            payload = json.loads(body.decode("utf-8"))
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(status_code=400, detail="Invalid JSON payload")
+
+        import torch
+        import soundfile as sf
+
+        load_speech_pipeline()
+
+        embeddings_dataset = app.state.speech_speaker_embeddings_dataset
+
+        speaker_index = 6799
+        try:
+            speaker_index = embeddings_dataset["filename"].index(
+                app.state.config.TTS_MODEL
+            )
+        except Exception:
+            pass
+
+        speaker_embedding = torch.tensor(
+            embeddings_dataset[speaker_index]["xvector"]
+        ).unsqueeze(0)
+
+        speech = app.state.speech_synthesiser(
+            payload["input"],
+            forward_params={"speaker_embeddings": speaker_embedding},
+        )
+
+        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
+        with open(file_body_path, "w") as f:
+            json.dump(json.loads(body.decode("utf-8")), f)
+
+        return FileResponse(file_path)
 
 
 def transcribe(file_path):

+ 6 - 0
package-lock.json

@@ -13,6 +13,7 @@
 				"@codemirror/language-data": "^6.5.1",
 				"@codemirror/theme-one-dark": "^6.1.2",
 				"@huggingface/transformers": "^3.0.0",
+				"@mediapipe/tasks-vision": "^0.10.17",
 				"@pyscript/core": "^0.4.32",
 				"@sveltejs/adapter-node": "^2.0.0",
 				"@xyflow/svelte": "^0.1.19",
@@ -1749,6 +1750,11 @@
 				"@lezer/lr": "^1.4.0"
 			}
 		},
+		"node_modules/@mediapipe/tasks-vision": {
+			"version": "0.10.17",
+			"resolved": "https://registry.npmjs.org/@mediapipe/tasks-vision/-/tasks-vision-0.10.17.tgz",
+			"integrity": "sha512-CZWV/q6TTe8ta61cZXjfnnHsfWIdFhms03M9T7Cnd5y2mdpylJM0rF1qRq+wsQVRMLz1OYPVEBU9ph2Bx8cxrg=="
+		},
 		"node_modules/@melt-ui/svelte": {
 			"version": "0.76.0",
 			"resolved": "https://registry.npmjs.org/@melt-ui/svelte/-/svelte-0.76.0.tgz",

+ 1 - 0
package.json

@@ -53,6 +53,7 @@
 		"@codemirror/language-data": "^6.5.1",
 		"@codemirror/theme-one-dark": "^6.1.2",
 		"@huggingface/transformers": "^3.0.0",
+		"@mediapipe/tasks-vision": "^0.10.17",
 		"@pyscript/core": "^0.4.32",
 		"@sveltejs/adapter-node": "^2.0.0",
 		"@xyflow/svelte": "^0.1.19",

+ 42 - 0
src/lib/components/admin/Settings/Audio.svelte

@@ -322,6 +322,7 @@
 							}}
 						>
 							<option value="">{$i18n.t('Web API')}</option>
+							<option value="transformers">{$i18n.t('Transformers')} ({$i18n.t('Local')})</option>
 							<option value="openai">{$i18n.t('OpenAI')}</option>
 							<option value="elevenlabs">{$i18n.t('ElevenLabs')}</option>
 							<option value="azure">{$i18n.t('Azure AI Speech')}</option>
@@ -396,6 +397,47 @@
 							</div>
 						</div>
 					</div>
+				{:else if TTS_ENGINE === 'transformers'}
+					<div>
+						<div class=" mb-1.5 text-sm font-medium">{$i18n.t('TTS Model')}</div>
+						<div class="flex w-full">
+							<div class="flex-1">
+								<input
+									list="model-list"
+									class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
+									bind:value={TTS_MODEL}
+									placeholder="CMU ARCTIC speaker embedding name"
+								/>
+
+								<datalist id="model-list">
+									<option value="tts-1" />
+								</datalist>
+							</div>
+						</div>
+						<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+							{$i18n.t(`Open WebUI uses SpeechT5 and CMU Arctic speaker embeddings.`)}
+
+							To learn more about SpeechT5,
+
+							<a
+								class=" hover:underline dark:text-gray-200 text-gray-800"
+								href="https://github.com/microsoft/SpeechT5"
+								target="_blank"
+							>
+								{$i18n.t(`click here`, {
+									name: 'SpeechT5'
+								})}.
+							</a>
+							To see the available CMU Arctic speaker embeddings,
+							<a
+								class=" hover:underline dark:text-gray-200 text-gray-800"
+								href="https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors"
+								target="_blank"
+							>
+								{$i18n.t(`click here`)}.
+							</a>
+						</div>
+					</div>
 				{:else if TTS_ENGINE === 'openai'}
 					<div class=" flex gap-2">
 						<div class="w-full">