ソースを参照

feat: Kokoro-js TTS support

Timothy Jaeryang Baek 2 ヶ月 前
コミット
205ce635f6

+ 100 - 61
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -4,12 +4,18 @@
 
 	import { createEventDispatcher } from 'svelte';
 	import { onMount, tick, getContext } from 'svelte';
+	import type { Writable } from 'svelte/store';
+	import type { i18n as i18nType } from 'i18next';
 
 	const i18n = getContext<Writable<i18nType>>('i18n');
 
 	const dispatch = createEventDispatcher();
 
-	import { config, models, settings, user } from '$lib/stores';
+	import { createNewFeedback, getFeedbackById, updateFeedbackById } from '$lib/apis/evaluations';
+	import { getChatById } from '$lib/apis/chats';
+	import { generateTags } from '$lib/apis';
+
+	import { config, models, settings, TTSWorker, user } from '$lib/stores';
 	import { synthesizeOpenAISpeech } from '$lib/apis/audio';
 	import { imageGenerations } from '$lib/apis/images';
 	import {
@@ -34,13 +40,8 @@
 	import Error from './Error.svelte';
 	import Citations from './Citations.svelte';
 	import CodeExecutions from './CodeExecutions.svelte';
-
-	import type { Writable } from 'svelte/store';
-	import type { i18n as i18nType } from 'i18next';
 	import ContentRenderer from './ContentRenderer.svelte';
-	import { createNewFeedback, getFeedbackById, updateFeedbackById } from '$lib/apis/evaluations';
-	import { getChatById } from '$lib/apis/chats';
-	import { generateTags } from '$lib/apis';
+	import { KokoroWorker } from '$lib/workers/KokoroWorker';
 
 	interface MessageType {
 		id: string;
@@ -193,7 +194,42 @@
 
 		speaking = true;
 
-		if ($config.audio.tts.engine !== '') {
+		if ($config.audio.tts.engine === '') {
+			let voices = [];
+			const getVoicesLoop = setInterval(() => {
+				voices = speechSynthesis.getVoices();
+				if (voices.length > 0) {
+					clearInterval(getVoicesLoop);
+
+					const voice =
+						voices
+							?.filter(
+								(v) => v.voiceURI === ($settings?.audio?.tts?.voice ?? $config?.audio?.tts?.voice)
+							)
+							?.at(0) ?? undefined;
+
+					console.log(voice);
+
+					const speak = new SpeechSynthesisUtterance(message.content);
+					speak.rate = $settings.audio?.tts?.playbackRate ?? 1;
+
+					console.log(speak);
+
+					speak.onend = () => {
+						speaking = false;
+						if ($settings.conversationMode) {
+							document.getElementById('voice-input-button')?.click();
+						}
+					};
+
+					if (voice) {
+						speak.voice = voice;
+					}
+
+					speechSynthesis.speak(speak);
+				}
+			}, 100);
+		} else {
 			loadingSpeech = true;
 
 			const messageContentParts: string[] = getMessageContentParts(
@@ -222,67 +258,70 @@
 
 			let lastPlayedAudioPromise = Promise.resolve(); // Initialize a promise that resolves immediately
 
-			for (const [idx, sentence] of messageContentParts.entries()) {
-				const res = await synthesizeOpenAISpeech(
-					localStorage.token,
-					$settings?.audio?.tts?.defaultVoice === $config.audio.tts.voice
-						? ($settings?.audio?.tts?.voice ?? $config?.audio?.tts?.voice)
-						: $config?.audio?.tts?.voice,
-					sentence
-				).catch((error) => {
-					console.error(error);
-					toast.error(`${error}`);
+			if ($settings.audio?.tts?.engine === 'browser-kokoro') {
+				if (!$TTSWorker) {
+					await TTSWorker.set(
+						new KokoroWorker({
+							dtype: $settings.audio?.tts?.engineConfig?.dtype ?? 'fp32'
+						})
+					);
 
-					speaking = false;
-					loadingSpeech = false;
-				});
-
-				if (res) {
-					const blob = await res.blob();
-					const blobUrl = URL.createObjectURL(blob);
-					const audio = new Audio(blobUrl);
-					audio.playbackRate = $settings.audio?.tts?.playbackRate ?? 1;
-
-					audioParts[idx] = audio;
-					loadingSpeech = false;
-					lastPlayedAudioPromise = lastPlayedAudioPromise.then(() => playAudio(idx));
+					await $TTSWorker.init();
 				}
-			}
-		} else {
-			let voices = [];
-			const getVoicesLoop = setInterval(() => {
-				voices = speechSynthesis.getVoices();
-				if (voices.length > 0) {
-					clearInterval(getVoicesLoop);
-
-					const voice =
-						voices
-							?.filter(
-								(v) => v.voiceURI === ($settings?.audio?.tts?.voice ?? $config?.audio?.tts?.voice)
-							)
-							?.at(0) ?? undefined;
-
-					console.log(voice);
-
-					const speak = new SpeechSynthesisUtterance(message.content);
-					speak.rate = $settings.audio?.tts?.playbackRate ?? 1;
 
-					console.log(speak);
+				console.log($TTSWorker);
+
+				for (const [idx, sentence] of messageContentParts.entries()) {
+					const blob = await $TTSWorker
+						.generate({
+							text: sentence,
+							voice: $settings?.audio?.tts?.voice ?? $config?.audio?.tts?.voice
+						})
+						.catch((error) => {
+							console.error(error);
+							toast.error(`${error}`);
+
+							speaking = false;
+							loadingSpeech = false;
+						});
+
+					if (blob) {
+						const audio = new Audio(blob);
+						audio.playbackRate = $settings.audio?.tts?.playbackRate ?? 1;
+
+						audioParts[idx] = audio;
+						loadingSpeech = false;
+						lastPlayedAudioPromise = lastPlayedAudioPromise.then(() => playAudio(idx));
+					}
+				}
+			} else {
+				for (const [idx, sentence] of messageContentParts.entries()) {
+					const res = await synthesizeOpenAISpeech(
+						localStorage.token,
+						$settings?.audio?.tts?.defaultVoice === $config.audio.tts.voice
+							? ($settings?.audio?.tts?.voice ?? $config?.audio?.tts?.voice)
+							: $config?.audio?.tts?.voice,
+						sentence
+					).catch((error) => {
+						console.error(error);
+						toast.error(`${error}`);
 
-					speak.onend = () => {
 						speaking = false;
-						if ($settings.conversationMode) {
-							document.getElementById('voice-input-button')?.click();
-						}
-					};
+						loadingSpeech = false;
+					});
 
-					if (voice) {
-						speak.voice = voice;
-					}
+					if (res) {
+						const blob = await res.blob();
+						const blobUrl = URL.createObjectURL(blob);
+						const audio = new Audio(blobUrl);
+						audio.playbackRate = $settings.audio?.tts?.playbackRate ?? 1;
 
-					speechSynthesis.speak(speak);
+						audioParts[idx] = audio;
+						loadingSpeech = false;
+						lastPlayedAudioPromise = lastPlayedAudioPromise.then(() => playAudio(idx));
+					}
 				}
-			}, 100);
+			}
 		}
 	};
 

+ 162 - 16
src/lib/components/chat/Settings/Audio.svelte

@@ -1,11 +1,14 @@
 <script lang="ts">
 	import { toast } from 'svelte-sonner';
 	import { createEventDispatcher, onMount, getContext } from 'svelte';
+	import { KokoroTTS } from 'kokoro-js';
 
 	import { user, settings, config } from '$lib/stores';
 	import { getVoices as _getVoices } from '$lib/apis/audio';
 
 	import Switch from '$lib/components/common/Switch.svelte';
+	import { round } from '@huggingface/transformers';
+	import Spinner from '$lib/components/common/Spinner.svelte';
 	const dispatch = createEventDispatcher();
 
 	const i18n = getContext('i18n');
@@ -20,6 +23,13 @@
 
 	let STTEngine = '';
 
+	let TTSEngine = '';
+	let TTSEngineConfig = {};
+
+	let TTSModel = null;
+	let TTSModelProgress = null;
+	let TTSModelLoading = false;
+
 	let voices = [];
 	let voice = '';
 
@@ -28,23 +38,37 @@
 	const speedOptions = [2, 1.75, 1.5, 1.25, 1, 0.75, 0.5];
 
 	const getVoices = async () => {
-		if ($config.audio.tts.engine === '') {
-			const getVoicesLoop = setInterval(async () => {
-				voices = await speechSynthesis.getVoices();
+		if (TTSEngine === 'browser-kokoro') {
+			if (!TTSModel) {
+				await loadKokoro();
+			}
 
-				// do your loop
-				if (voices.length > 0) {
-					clearInterval(getVoicesLoop);
-				}
-			}, 100);
-		} else {
-			const res = await _getVoices(localStorage.token).catch((e) => {
-				toast.error(`${e}`);
+			voices = Object.entries(TTSModel.voices).map(([key, value]) => {
+				return {
+					id: key,
+					name: value.name,
+					localService: false
+				};
 			});
-
-			if (res) {
-				console.log(res);
-				voices = res.voices;
+		} else {
+			if ($config.audio.tts.engine === '') {
+				const getVoicesLoop = setInterval(async () => {
+					voices = await speechSynthesis.getVoices();
+
+					// do your loop
+					if (voices.length > 0) {
+						clearInterval(getVoicesLoop);
+					}
+				}, 100);
+			} else {
+				const res = await _getVoices(localStorage.token).catch((e) => {
+					toast.error(`${e}`);
+				});
+
+				if (res) {
+					console.log(res);
+					voices = res.voices;
+				}
 			}
 		}
 	};
@@ -67,6 +91,9 @@
 
 		STTEngine = $settings?.audio?.stt?.engine ?? '';
 
+		TTSEngine = $settings?.audio?.tts?.engine ?? '';
+		TTSEngineConfig = $settings?.audio?.tts?.engineConfig ?? {};
+
 		if ($settings?.audio?.tts?.defaultVoice === $config.audio.tts.voice) {
 			voice = $settings?.audio?.tts?.voice ?? $config.audio.tts.voice ?? '';
 		} else {
@@ -77,6 +104,51 @@
 
 		await getVoices();
 	});
+
+	$: if (TTSEngine && TTSEngineConfig) {
+		onTTSEngineChange();
+	}
+
+	const onTTSEngineChange = async () => {
+		if (TTSEngine === 'browser-kokoro') {
+			await loadKokoro();
+		}
+	};
+
+	const loadKokoro = async () => {
+		if (TTSEngine === 'browser-kokoro') {
+			voices = [];
+
+			if (TTSEngineConfig?.dtype) {
+				TTSModel = null;
+				TTSModelProgress = null;
+				TTSModelLoading = true;
+
+				const model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX';
+
+				TTSModel = await KokoroTTS.from_pretrained(model_id, {
+					dtype: TTSEngineConfig.dtype, // Options: "fp32", "fp16", "q8", "q4", "q4f16"
+					device: !!navigator?.gpu ? 'webgpu' : 'wasm', // Detect WebGPU
+					progress_callback: (e) => {
+						TTSModelProgress = e;
+						console.log(e);
+					}
+				});
+
+				await getVoices();
+
+				// const rawAudio = await tts.generate(inputText, {
+				// 	// Use `tts.list_voices()` to list all available voices
+				// 	voice: voice
+				// });
+
+				// const blobUrl = URL.createObjectURL(await rawAudio.toBlob());
+				// const audio = new Audio(blobUrl);
+
+				// audio.play();
+			}
+		}
+	};
 </script>
 
 <form
@@ -88,6 +160,8 @@
 					engine: STTEngine !== '' ? STTEngine : undefined
 				},
 				tts: {
+					engine: TTSEngine !== '' ? TTSEngine : undefined,
+					engineConfig: TTSEngineConfig,
 					playbackRate: playbackRate,
 					voice: voice !== '' ? voice : undefined,
 					defaultVoice: $config?.audio?.tts?.voice ?? '',
@@ -142,6 +216,39 @@
 		<div>
 			<div class=" mb-1 text-sm font-medium">{$i18n.t('TTS Settings')}</div>
 
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">{$i18n.t('Text-to-Speech Engine')}</div>
+				<div class="flex items-center relative">
+					<select
+						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+						bind:value={TTSEngine}
+						placeholder="Select an engine"
+					>
+						<option value="">{$i18n.t('Default')}</option>
+						<option value="browser-kokoro">{$i18n.t('Kokoro.js (Browser)')}</option>
+					</select>
+				</div>
+			</div>
+
+			{#if TTSEngine === 'browser-kokoro'}
+				<div class=" py-0.5 flex w-full justify-between">
+					<div class=" self-center text-xs font-medium">{$i18n.t('Kokoro.js Dtype')}</div>
+					<div class="flex items-center relative">
+						<select
+							class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+							bind:value={TTSEngineConfig.dtype}
+							placeholder="Select dtype"
+						>
+							<option value="" disabled selected>Select dtype</option>
+							<option value="fp32">fp32</option>
+							<option value="fp16">fp16</option>
+							<option value="q8">q8</option>
+							<option value="q4">q4</option>
+						</select>
+					</div>
+				</div>
+			{/if}
+
 			<div class=" py-0.5 flex w-full justify-between">
 				<div class=" self-center text-xs font-medium">{$i18n.t('Auto-playback response')}</div>
 
@@ -178,7 +285,46 @@
 
 		<hr class=" dark:border-gray-850" />
 
-		{#if $config.audio.tts.engine === ''}
+		{#if TTSEngine === 'browser-kokoro'}
+			{#if TTSModel}
+				<div>
+					<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Set Voice')}</div>
+					<div class="flex w-full">
+						<div class="flex-1">
+							<input
+								list="voice-list"
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								bind:value={voice}
+								placeholder="Select a voice"
+							/>
+
+							<datalist id="voice-list">
+								{#each voices as voice}
+									<option value={voice.id}>{voice.name}</option>
+								{/each}
+							</datalist>
+						</div>
+					</div>
+				</div>
+			{:else}
+				<div>
+					<div class=" mb-2.5 text-sm font-medium flex gap-2 items-center">
+						<Spinner className="size-4" />
+
+						<div class=" text-sm font-medium shimmer">
+							{$i18n.t('Loading Kokoro.js...')}
+							{TTSModelProgress && TTSModelProgress.status === 'progress'
+								? `(${Math.round(TTSModelProgress.progress * 10) / 10}%)`
+								: ''}
+						</div>
+					</div>
+
+					<div class="text-xs text-gray-500">
+						{$i18n.t('Please do not close the settings page while loading the model.')}
+					</div>
+				</div>
+			{/if}
+		{:else if $config.audio.tts.engine === ''}
 			<div>
 				<div class=" mb-2.5 text-sm font-medium">{$i18n.t('Set Voice')}</div>
 				<div class="flex w-full">

+ 2 - 0
src/lib/stores/index.ts

@@ -41,6 +41,8 @@ export const shortCodesToEmojis = writable(
 	}, {})
 );
 
+export const TTSWorker = writable(null);
+
 export const chatId = writable('');
 export const chatTitle = writable('');
 

+ 70 - 0
src/lib/workers/KokoroWorker.ts

@@ -0,0 +1,70 @@
+import WorkerInstance from '$lib/workers/kokoro.worker?worker';
+
+export class KokoroWorker {
+	private worker: Worker | null = null;
+	private initialized: boolean = false;
+	private dtype: string;
+
+	constructor(dtype: string = 'fp32') {
+		this.dtype = dtype;
+	}
+
+	public async init() {
+		if (this.worker) {
+			console.warn('KokoroWorker is already initialized.');
+			return;
+		}
+
+		this.worker = new WorkerInstance();
+
+		return new Promise<void>((resolve, reject) => {
+			this.worker!.onmessage = (event) => {
+				const { status, error } = event.data;
+
+				if (status === 'init:complete') {
+					this.initialized = true;
+					resolve();
+				} else if (status === 'init:error') {
+					console.error(error);
+					this.initialized = false;
+					reject(new Error(error));
+				}
+			};
+
+			this.worker!.postMessage({
+				type: 'init',
+				payload: { dtype: this.dtype }
+			});
+		});
+	}
+
+	public async generate({ text, voice }: { text: string; voice: string }): Promise<string> {
+		if (!this.initialized || !this.worker) {
+			throw new Error('KokoroTTS Worker is not initialized yet.');
+		}
+
+		return new Promise<string>((resolve, reject) => {
+			this.worker.postMessage({ type: 'generate', payload: { text, voice } });
+
+			const handleMessage = (event: MessageEvent) => {
+				if (event.data.status === 'generate:complete') {
+					this.worker!.removeEventListener('message', handleMessage);
+					resolve(event.data.audioUrl);
+				} else if (event.data.status === 'generate:error') {
+					this.worker!.removeEventListener('message', handleMessage);
+					reject(new Error(event.data.error));
+				}
+			};
+
+			this.worker.addEventListener('message', handleMessage);
+		});
+	}
+
+	public terminate() {
+		if (this.worker) {
+			this.worker.terminate();
+			this.worker = null;
+			this.initialized = false;
+		}
+	}
+}

+ 53 - 0
src/lib/workers/kokoro.worker.ts

@@ -0,0 +1,53 @@
+import { KokoroTTS } from 'kokoro-js';
+
+let tts;
+let isInitialized = false; // Flag to track initialization status
+const DEFAULT_MODEL_ID = 'onnx-community/Kokoro-82M-v1.0-ONNX'; // Default model
+
+self.onmessage = async (event) => {
+	const { type, payload } = event.data;
+
+	if (type === 'init') {
+		let { model_id, dtype } = payload;
+		model_id = model_id || DEFAULT_MODEL_ID; // Use default model if none provided
+
+		self.postMessage({ status: 'init:start' });
+
+		try {
+			tts = await KokoroTTS.from_pretrained(model_id, {
+				dtype,
+				device: !!navigator?.gpu ? 'webgpu' : 'wasm' // Detect WebGPU
+			});
+			isInitialized = true; // Mark as initialized after successful loading
+			self.postMessage({ status: 'init:complete' });
+		} catch (error) {
+			isInitialized = false; // Ensure it's marked as false on failure
+			self.postMessage({ status: 'init:error', error: error.message });
+		}
+	}
+
+	if (type === 'generate') {
+		if (!isInitialized || !tts) {
+			// Ensure model is initialized
+			self.postMessage({ status: 'generate:error', error: 'TTS model not initialized' });
+			return;
+		}
+
+		const { text, voice } = payload;
+		self.postMessage({ status: 'generate:start' });
+
+		try {
+			const rawAudio = await tts.generate(text, { voice });
+			const blob = await rawAudio.toBlob();
+			const blobUrl = URL.createObjectURL(blob);
+			self.postMessage({ status: 'generate:complete', audioUrl: blobUrl });
+		} catch (error) {
+			self.postMessage({ status: 'generate:error', error: error.message });
+		}
+	}
+
+	if (type === 'status') {
+		// Respond with the current initialization status
+		self.postMessage({ status: 'status:check', initialized: isInitialized });
+	}
+};