Explorar o código

Merge pull request #707 from ollama-webui/whisper

feat: whisper support
Timothy Jaeryang Baek hai 1 ano
pai
achega
e1a6ccd1aa

+ 80 - 0
backend/apps/audio/main.py

@@ -0,0 +1,80 @@
+from fastapi import (
+    FastAPI,
+    Request,
+    Depends,
+    HTTPException,
+    status,
+    UploadFile,
+    File,
+    Form,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from faster_whisper import WhisperModel
+
+from constants import ERROR_MESSAGES
+from utils.utils import (
+    decode_token,
+    get_current_user,
+    get_verified_user,
+    get_admin_user,
+)
+from utils.misc import calculate_sha256
+
+from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.post("/transcribe")
+def transcribe(
+    file: UploadFile = File(...),
+    user=Depends(get_current_user),
+):
+    print(file.content_type)
+
+    if file.content_type not in ["audio/mpeg", "audio/wav"]:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+        )
+
+    try:
+        filename = file.filename
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        contents = file.file.read()
+        with open(file_path, "wb") as f:
+            f.write(contents)
+            f.close()
+
+        model_name = WHISPER_MODEL_NAME
+        model = WhisperModel(
+            model_name,
+            device="cpu",
+            compute_type="int8",
+            download_root=f"{CACHE_DIR}/whisper/models",
+        )
+
+        segments, info = model.transcribe(file_path, beam_size=5)
+        print(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
+        )
+
+        transcript = "".join([segment.text for segment in list(segments)])
+
+        return {"text": transcript.strip()}
+
+    except Exception as e:
+        print(e)
+
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )

+ 5 - 0
backend/config.py

@@ -132,3 +132,8 @@ CHROMA_CLIENT = chromadb.PersistentClient(
 )
 CHUNK_SIZE = 1500
 CHUNK_OVERLAP = 100
+
+####################################
+# Transcribe
+####################################
+WHISPER_MODEL_NAME = "base"

+ 4 - 0
backend/main.py

@@ -10,6 +10,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
+from apps.audio.main import app as audio_app
+
 
 from apps.web.main import app as webui_app
 from apps.rag.main import app as rag_app
@@ -55,6 +57,8 @@ app.mount("/api/v1", webui_app)
 
 app.mount("/ollama/api", ollama_app)
 app.mount("/openai/api", openai_app)
+
+app.mount("/audio/api/v1", audio_app)
 app.mount("/rag/api/v1", rag_app)
 
 

+ 2 - 0
backend/requirements.txt

@@ -30,6 +30,8 @@ openpyxl
 pyxlsb
 xlrd
 
+faster-whisper
+
 PyJWT
 pyjwt[crypto]
 

+ 31 - 0
src/lib/apis/audio/index.ts

@@ -0,0 +1,31 @@
+import { AUDIO_API_BASE_URL } from '$lib/constants';
+
+export const transcribeAudio = async (token: string, file: File) => {
+	const data = new FormData();
+	data.append('file', file);
+
+	let error = null;
+	const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: data
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 180 - 54
src/lib/components/chat/MessageInput.svelte

@@ -2,7 +2,7 @@
 	import toast from 'svelte-french-toast';
 	import { onMount, tick } from 'svelte';
 	import { settings } from '$lib/stores';
-	import { calculateSHA256, findWordIndices } from '$lib/utils';
+	import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
 
 	import Prompts from './MessageInput/PromptCommands.svelte';
 	import Suggestions from './MessageInput/Suggestions.svelte';
@@ -11,6 +11,7 @@
 	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
 	import Documents from './MessageInput/Documents.svelte';
 	import Models from './MessageInput/Models.svelte';
+	import { transcribeAudio } from '$lib/apis/audio';
 
 	export let submitPrompt: Function;
 	export let stopResponse: Function;
@@ -34,7 +35,6 @@
 
 	export let fileUploadEnabled = true;
 	export let speechRecognitionEnabled = true;
-	export let speechRecognitionListening = false;
 
 	export let prompt = '';
 	export let messages = [];
@@ -50,62 +50,170 @@
 		}
 	}
 
+	let mediaRecorder;
+	let audioChunks = [];
+	let isRecording = false;
+	const MIN_DECIBELS = -45;
+
+	const startRecording = async () => {
+		const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+		mediaRecorder = new MediaRecorder(stream);
+		mediaRecorder.onstart = () => {
+			isRecording = true;
+			console.log('Recording started');
+		};
+		mediaRecorder.ondataavailable = (event) => audioChunks.push(event.data);
+		mediaRecorder.onstop = async () => {
+			isRecording = false;
+			console.log('Recording stopped');
+
+			// Create a blob from the audio chunks
+			const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
+
+			const file = blobToFile(audioBlob, 'recording.wav');
+
+			const res = await transcribeAudio(localStorage.token, file).catch((error) => {
+				toast.error(error);
+				return null;
+			});
+
+			if (res) {
+				prompt = res.text;
+				await tick();
+
+				const inputElement = document.getElementById('chat-textarea');
+				inputElement?.focus();
+
+				if (prompt !== '' && $settings?.speechAutoSend === true) {
+					submitPrompt(prompt, user);
+				}
+			}
+
+			// saveRecording(audioBlob);
+			audioChunks = [];
+		};
+
+		// Start recording
+		mediaRecorder.start();
+
+		// Monitor silence
+		monitorSilence(stream);
+	};
+
+	const monitorSilence = (stream) => {
+		const audioContext = new AudioContext();
+		const audioStreamSource = audioContext.createMediaStreamSource(stream);
+		const analyser = audioContext.createAnalyser();
+		analyser.minDecibels = MIN_DECIBELS;
+		audioStreamSource.connect(analyser);
+
+		const bufferLength = analyser.frequencyBinCount;
+		const domainData = new Uint8Array(bufferLength);
+
+		let lastSoundTime = Date.now();
+
+		const detectSound = () => {
+			analyser.getByteFrequencyData(domainData);
+
+			if (domainData.some((value) => value > 0)) {
+				lastSoundTime = Date.now();
+			}
+
+			if (isRecording && Date.now() - lastSoundTime > 3000) {
+				mediaRecorder.stop();
+				audioContext.close();
+				return;
+			}
+
+			window.requestAnimationFrame(detectSound);
+		};
+
+		window.requestAnimationFrame(detectSound);
+	};
+
+	const saveRecording = (blob) => {
+		const url = URL.createObjectURL(blob);
+		const a = document.createElement('a');
+		document.body.appendChild(a);
+		a.style = 'display: none';
+		a.href = url;
+		a.download = 'recording.wav';
+		a.click();
+		window.URL.revokeObjectURL(url);
+	};
+
 	const speechRecognitionHandler = () => {
 		// Check if SpeechRecognition is supported
 
-		if (speechRecognitionListening) {
-			speechRecognition.stop();
-		} else {
-			if ('SpeechRecognition' in window || 'webkitSpeechRecognition' in window) {
-				// Create a SpeechRecognition object
-				speechRecognition = new (window.SpeechRecognition || window.webkitSpeechRecognition)();
-
-				// Set continuous to true for continuous recognition
-				speechRecognition.continuous = true;
-
-				// Set the timeout for turning off the recognition after inactivity (in milliseconds)
-				const inactivityTimeout = 3000; // 3 seconds
-
-				let timeoutId;
-				// Start recognition
-				speechRecognition.start();
-				speechRecognitionListening = true;
-
-				// Event triggered when speech is recognized
-				speechRecognition.onresult = function (event) {
-					// Clear the inactivity timeout
-					clearTimeout(timeoutId);
-
-					// Handle recognized speech
-					console.log(event);
-					const transcript = event.results[Object.keys(event.results).length - 1][0].transcript;
-					prompt = `${prompt}${transcript}`;
-
-					// Restart the inactivity timeout
-					timeoutId = setTimeout(() => {
-						console.log('Speech recognition turned off due to inactivity.');
-						speechRecognition.stop();
-					}, inactivityTimeout);
-				};
+		if (isRecording) {
+			if (speechRecognition) {
+				speechRecognition.stop();
+			}
 
-				// Event triggered when recognition is ended
-				speechRecognition.onend = function () {
-					// Restart recognition after it ends
-					console.log('recognition ended');
-					speechRecognitionListening = false;
-					if (prompt !== '' && $settings?.speechAutoSend === true) {
-						submitPrompt(prompt, user);
-					}
-				};
+			if (mediaRecorder) {
+				mediaRecorder.stop();
+			}
+		} else {
+			isRecording = true;
 
-				// Event triggered when an error occurs
-				speechRecognition.onerror = function (event) {
-					console.log(event);
-					toast.error(`Speech recognition error: ${event.error}`);
-					speechRecognitionListening = false;
-				};
+			if ($settings?.voice?.STTEngine ?? '' !== '') {
+				startRecording();
 			} else {
-				toast.error('SpeechRecognition API is not supported in this browser.');
+				if ('SpeechRecognition' in window || 'webkitSpeechRecognition' in window) {
+					// Create a SpeechRecognition object
+					speechRecognition = new (window.SpeechRecognition || window.webkitSpeechRecognition)();
+
+					// Set continuous to true for continuous recognition
+					speechRecognition.continuous = true;
+
+					// Set the timeout for turning off the recognition after inactivity (in milliseconds)
+					const inactivityTimeout = 3000; // 3 seconds
+
+					let timeoutId;
+					// Start recognition
+					speechRecognition.start();
+
+					// Event triggered when speech is recognized
+					speechRecognition.onresult = async (event) => {
+						// Clear the inactivity timeout
+						clearTimeout(timeoutId);
+
+						// Handle recognized speech
+						console.log(event);
+						const transcript = event.results[Object.keys(event.results).length - 1][0].transcript;
+
+						prompt = `${prompt}${transcript}`;
+
+						await tick();
+						const inputElement = document.getElementById('chat-textarea');
+						inputElement?.focus();
+
+						// Restart the inactivity timeout
+						timeoutId = setTimeout(() => {
+							console.log('Speech recognition turned off due to inactivity.');
+							speechRecognition.stop();
+						}, inactivityTimeout);
+					};
+
+					// Event triggered when recognition is ended
+					speechRecognition.onend = function () {
+						// Restart recognition after it ends
+						console.log('recognition ended');
+						isRecording = false;
+						if (prompt !== '' && $settings?.speechAutoSend === true) {
+							submitPrompt(prompt, user);
+						}
+					};
+
+					// Event triggered when an error occurs
+					speechRecognition.onerror = function (event) {
+						console.log(event);
+						toast.error(`Speech recognition error: ${event.error}`);
+						isRecording = false;
+					};
+				} else {
+					toast.error('SpeechRecognition API is not supported in this browser.');
+				}
 			}
 		}
 	};
@@ -123,6 +231,20 @@
 
 		try {
 			files = [...files, doc];
+
+			if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
+				const res = await transcribeAudio(localStorage.token, file).catch((error) => {
+					toast.error(error);
+					return null;
+				});
+
+				if (res) {
+					console.log(res);
+					const blob = new Blob([res.text], { type: 'text/plain' });
+					file = blobToFile(blob, `${file.name}.txt`);
+				}
+			}
+
 			const res = await uploadDocToVectorDB(localStorage.token, '', file);
 
 			if (res) {
@@ -535,7 +657,7 @@
 								: ' pl-4'} rounded-xl resize-none h-[48px]"
 							placeholder={chatInputPlaceholder !== ''
 								? chatInputPlaceholder
-								: speechRecognitionListening
+								: isRecording
 								? 'Listening...'
 								: 'Send a message'}
 							bind:value={prompt}
@@ -644,6 +766,10 @@
 								e.target.style.height = Math.min(e.target.scrollHeight, 200) + 'px';
 								user = null;
 							}}
+							on:focus={(e) => {
+								e.target.style.height = '';
+								e.target.style.height = Math.min(e.target.scrollHeight, 200) + 'px';
+							}}
 							on:paste={(e) => {
 								const clipboardData = e.clipboardData || window.clipboardData;
 
@@ -681,7 +807,7 @@
 											speechRecognitionHandler();
 										}}
 									>
-										{#if speechRecognitionListening}
+										{#if isRecording}
 											<svg
 												class=" w-5 h-5 translate-y-[0.5px]"
 												fill="currentColor"

+ 3 - 3
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -148,7 +148,7 @@
 		} else {
 			speaking = true;
 
-			if ($settings?.speech?.engine === 'openai') {
+			if ($settings?.audio?.TTSEngine === 'openai') {
 				loadingSpeech = true;
 
 				const sentences = extractSentences(message.content).reduce((mergedTexts, currentText) => {
@@ -179,7 +179,7 @@
 				for (const [idx, sentence] of sentences.entries()) {
 					const res = await synthesizeOpenAISpeech(
 						localStorage.token,
-						$settings?.speech?.speaker,
+						$settings?.audio?.speaker,
 						sentence
 					).catch((error) => {
 						toast.error(error);
@@ -204,7 +204,7 @@
 						clearInterval(getVoicesLoop);
 
 						const voice =
-							voices?.filter((v) => v.name === $settings?.speech?.speaker)?.at(0) ?? undefined;
+							voices?.filter((v) => v.name === $settings?.audio?.speaker)?.at(0) ?? undefined;
 
 						const speak = new SpeechSynthesisUtterance(message.content);
 

+ 53 - 21
src/lib/components/chat/Settings/Voice.svelte → src/lib/components/chat/Settings/Audio.svelte

@@ -1,17 +1,21 @@
 <script lang="ts">
 	import { createEventDispatcher, onMount } from 'svelte';
+	import toast from 'svelte-french-toast';
 	const dispatch = createEventDispatcher();
 
 	export let saveSettings: Function;
 
-	// Voice
+	// Audio
+
+	let STTEngines = ['', 'openai'];
+	let STTEngine = '';
 
 	let conversationMode = false;
 	let speechAutoSend = false;
 	let responseAutoPlayback = false;
 
-	let engines = ['', 'openai'];
-	let engine = '';
+	let TTSEngines = ['', 'openai'];
+	let TTSEngine = '';
 
 	let voices = [];
 	let speaker = '';
@@ -70,10 +74,11 @@
 		speechAutoSend = settings.speechAutoSend ?? false;
 		responseAutoPlayback = settings.responseAutoPlayback ?? false;
 
-		engine = settings?.speech?.engine ?? '';
-		speaker = settings?.speech?.speaker ?? '';
+		STTEngine = settings?.audio?.STTEngine ?? '';
+		TTSEngine = settings?.audio?.TTSEngine ?? '';
+		speaker = settings?.audio?.speaker ?? '';
 
-		if (engine === 'openai') {
+		if (TTSEngine === 'openai') {
 			getOpenAIVoices();
 		} else {
 			getWebAPIVoices();
@@ -85,37 +90,37 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={() => {
 		saveSettings({
-			speech: {
-				engine: engine !== '' ? engine : undefined,
+			audio: {
+				STTEngine: STTEngine !== '' ? STTEngine : undefined,
+				TTSEngine: TTSEngine !== '' ? TTSEngine : undefined,
 				speaker: speaker !== '' ? speaker : undefined
 			}
 		});
 		dispatch('save');
 	}}
 >
-	<div class=" space-y-3">
+	<div class=" space-y-3 pr-1.5 overflow-y-scroll max-h-80">
 		<div>
-			<div class=" mb-1 text-sm font-medium">TTS Settings</div>
+			<div class=" mb-1 text-sm font-medium">STT Settings</div>
 
 			<div class=" py-0.5 flex w-full justify-between">
-				<div class=" self-center text-xs font-medium">Speech Engine</div>
+				<div class=" self-center text-xs font-medium">Speech-to-Text Engine</div>
 				<div class="flex items-center relative">
 					<select
 						class="w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
-						bind:value={engine}
+						bind:value={STTEngine}
 						placeholder="Select a mode"
 						on:change={(e) => {
-							if (e.target.value === 'openai') {
-								getOpenAIVoices();
-								speaker = 'alloy';
-							} else {
-								getWebAPIVoices();
-								speaker = '';
+							if (e.target.value !== '') {
+								navigator.mediaDevices.getUserMedia({ audio: true }).catch(function (err) {
+									toast.error(`Permission denied when accessing microphone: ${err}`);
+									STTEngine = '';
+								});
 							}
 						}}
 					>
 						<option value="">Default (Web API)</option>
-						<option value="openai">Open AI</option>
+						<option value="whisper-local">Whisper (Local)</option>
 					</select>
 				</div>
 			</div>
@@ -155,6 +160,33 @@
 					{/if}
 				</button>
 			</div>
+		</div>
+
+		<div>
+			<div class=" mb-1 text-sm font-medium">TTS Settings</div>
+
+			<div class=" py-0.5 flex w-full justify-between">
+				<div class=" self-center text-xs font-medium">Text-to-Speech Engine</div>
+				<div class="flex items-center relative">
+					<select
+						class="w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+						bind:value={TTSEngine}
+						placeholder="Select a mode"
+						on:change={(e) => {
+							if (e.target.value === 'openai') {
+								getOpenAIVoices();
+								speaker = 'alloy';
+							} else {
+								getWebAPIVoices();
+								speaker = '';
+							}
+						}}
+					>
+						<option value="">Default (Web API)</option>
+						<option value="openai">Open AI</option>
+					</select>
+				</div>
+			</div>
 
 			<div class=" py-0.5 flex w-full justify-between">
 				<div class=" self-center text-xs font-medium">Auto-playback response</div>
@@ -177,7 +209,7 @@
 
 		<hr class=" dark:border-gray-700" />
 
-		{#if engine === ''}
+		{#if TTSEngine === ''}
 			<div>
 				<div class=" mb-2.5 text-sm font-medium">Set Voice</div>
 				<div class="flex w-full">
@@ -196,7 +228,7 @@
 					</div>
 				</div>
 			</div>
-		{:else if engine === 'openai'}
+		{:else if TTSEngine === 'openai'}
 			<div>
 				<div class=" mb-2.5 text-sm font-medium">Set Voice</div>
 				<div class="flex w-full">

+ 6 - 6
src/lib/components/chat/SettingsModal.svelte

@@ -13,7 +13,7 @@
 	import General from './Settings/General.svelte';
 	import External from './Settings/External.svelte';
 	import Interface from './Settings/Interface.svelte';
-	import Voice from './Settings/Voice.svelte';
+	import Audio from './Settings/Audio.svelte';
 	import Chats from './Settings/Chats.svelte';
 
 	export let show = false;
@@ -206,11 +206,11 @@
 
 				<button
 					class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
-					'voice'
+					'audio'
 						? 'bg-gray-200 dark:bg-gray-700'
 						: ' hover:bg-gray-300 dark:hover:bg-gray-800'}"
 					on:click={() => {
-						selectedTab = 'voice';
+						selectedTab = 'audio';
 					}}
 				>
 					<div class=" self-center mr-2">
@@ -228,7 +228,7 @@
 							/>
 						</svg>
 					</div>
-					<div class=" self-center">Voice</div>
+					<div class=" self-center">Audio</div>
 				</button>
 
 				<button
@@ -341,8 +341,8 @@
 							show = false;
 						}}
 					/>
-				{:else if selectedTab === 'voice'}
-					<Voice
+				{:else if selectedTab === 'audio'}
+					<Audio
 						{saveSettings}
 						on:save={() => {
 							show = false;

+ 4 - 1
src/lib/constants.ts

@@ -7,6 +7,7 @@ export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
 export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
 export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
 export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
+export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;
 
 export const WEB_UI_VERSION = 'v1.0.0-alpha-static';
 
@@ -23,7 +24,9 @@ export const SUPPORTED_FILE_TYPE = [
 	'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
 	'application/octet-stream',
 	'application/x-javascript',
-	'text/markdown'
+	'text/markdown',
+	'audio/mpeg',
+	'audio/wav'
 ];
 
 export const SUPPORTED_FILE_EXTENSIONS = [

+ 6 - 0
src/lib/utils/index.ts

@@ -341,3 +341,9 @@ export const extractSentences = (text) => {
 		.map((sentence) => removeEmojis(sentence.trim()))
 		.filter((sentence) => sentence !== '');
 };
+
+export const blobToFile = (blob, fileName) => {
+	// Create a new File object from the Blob
+	const file = new File([blob], fileName, { type: blob.type });
+	return file;
+};