瀏覽代碼

enh: faster whisper custom model support

Timothy J. Baek 6 月之前
父節點
當前提交
d5c1c2f0a7

+ 38 - 23
backend/open_webui/apps/audio/main.py

@@ -63,6 +63,9 @@ app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
 app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
 app.state.config.STT_MODEL = AUDIO_STT_MODEL
 
+app.state.config.WHISPER_MODEL = WHISPER_MODEL
+app.state.faster_whisper_model = None
+
 app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
 app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
 app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
@@ -82,6 +85,31 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
+def set_faster_whisper_model(model: str, auto_update: bool = False):
+    if model and app.state.config.STT_ENGINE == "":
+        from faster_whisper import WhisperModel
+
+        faster_whisper_kwargs = {
+            "model_size_or_path": model,
+            "device": whisper_device_type,
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not auto_update,
+        }
+
+        try:
+            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
+        except Exception:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            faster_whisper_kwargs["local_files_only"] = False
+            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
+
+    else:
+        app.state.faster_whisper_model = None
+
+
 class TTSConfigForm(BaseModel):
     OPENAI_API_BASE_URL: str
     OPENAI_API_KEY: str
@@ -99,6 +127,7 @@ class STTConfigForm(BaseModel):
     OPENAI_API_KEY: str
     ENGINE: str
     MODEL: str
+    WHISPER_MODEL: str
 
 
 class AudioConfigUpdateForm(BaseModel):
@@ -152,6 +181,7 @@ async def get_audio_config(user=Depends(get_admin_user)):
             "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
             "ENGINE": app.state.config.STT_ENGINE,
             "MODEL": app.state.config.STT_MODEL,
+            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
         },
     }
 
@@ -176,6 +206,8 @@ async def update_audio_config(
     app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
     app.state.config.STT_ENGINE = form_data.stt.ENGINE
     app.state.config.STT_MODEL = form_data.stt.MODEL
+    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
+    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)
 
     return {
         "tts": {
@@ -194,6 +226,7 @@ async def update_audio_config(
             "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
             "ENGINE": app.state.config.STT_ENGINE,
             "MODEL": app.state.config.STT_MODEL,
+            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
         },
     }
 
@@ -367,27 +400,10 @@ def transcribe(file_path):
     id = filename.split(".")[0]
 
     if app.state.config.STT_ENGINE == "":
-        from faster_whisper import WhisperModel
-
-        whisper_kwargs = {
-            "model_size_or_path": WHISPER_MODEL,
-            "device": whisper_device_type,
-            "compute_type": "int8",
-            "download_root": WHISPER_MODEL_DIR,
-            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
-        }
-
-        log.debug(f"whisper_kwargs: {whisper_kwargs}")
-
-        try:
-            model = WhisperModel(**whisper_kwargs)
-        except Exception:
-            log.warning(
-                "WhisperModel initialization failed, attempting download with local_files_only=False"
-            )
-            whisper_kwargs["local_files_only"] = False
-            model = WhisperModel(**whisper_kwargs)
+        if app.state.faster_whisper_model is None:
+            set_faster_whisper_model(app.state.config.WHISPER_MODEL)
 
+        model = app.state.faster_whisper_model
         segments, info = model.transcribe(file_path, beam_size=5)
         log.info(
             "Detected language '%s' with probability %f"
@@ -395,7 +411,6 @@ def transcribe(file_path):
         )
 
         transcript = "".join([segment.text for segment in list(segments)])
-
         data = {"text": transcript.strip()}
 
         # save the transcript to a json file
@@ -403,7 +418,7 @@ def transcribe(file_path):
         with open(transcript_file, "w") as f:
             json.dump(data, f)
 
-        print(data)
+        log.debug(data)
         return data
     elif app.state.config.STT_ENGINE == "openai":
         if is_mp4_audio(file_path):
@@ -417,7 +432,7 @@ def transcribe(file_path):
         files = {"file": (filename, open(file_path, "rb"))}
         data = {"model": app.state.config.STT_MODEL}
 
-        print(files, data)
+        log.debug(files, data)
 
         r = None
         try:

+ 1 - 1
backend/open_webui/apps/webui/routers/functions.py

@@ -9,7 +9,7 @@ from open_webui.apps.webui.models.functions import (
     Functions,
 )
 from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
-from open_webui.config import CACHE_DIR, FUNCTIONS_DIR
+from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.utils import get_admin_user, get_verified_user

+ 0 - 3
backend/open_webui/apps/webui/routers/tools.py

@@ -10,9 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
 from open_webui.utils.tools import get_tools_specs
 from open_webui.utils.utils import get_admin_user, get_verified_user
 
-TOOLS_DIR = f"{DATA_DIR}/tools"
-os.makedirs(TOOLS_DIR, exist_ok=True)
-
 
 router = APIRouter()
 

+ 0 - 1
backend/open_webui/apps/webui/utils.py

@@ -8,7 +8,6 @@ import tempfile
 
 from open_webui.apps.webui.models.functions import Functions
 from open_webui.apps.webui.models.tools import Tools
-from open_webui.config import FUNCTIONS_DIR, TOOLS_DIR
 
 
 def extract_frontmatter(content):

+ 14 - 28
backend/open_webui/config.py

@@ -548,26 +548,10 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
 CACHE_DIR = f"{DATA_DIR}/cache"
 Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
 
-####################################
-# Tools DIR
-####################################
-
-TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
-Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
-
-
-####################################
-# Functions DIR
-####################################
-
-FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
-Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
-
 ####################################
 # OLLAMA_BASE_URL
 ####################################
 
-
 ENABLE_OLLAMA_API = PersistentConfig(
     "ENABLE_OLLAMA_API",
     "ollama.enable",
@@ -1223,17 +1207,6 @@ RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
 )
 
 
-####################################
-# Transcribe
-####################################
-
-WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
-WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
-WHISPER_MODEL_AUTO_UPDATE = (
-    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
-)
-
-
 ####################################
 # Images
 ####################################
@@ -1449,6 +1422,19 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
 # Audio
 ####################################
 
+# Transcription
+WHISPER_MODEL = PersistentConfig(
+    "WHISPER_MODEL",
+    "audio.stt.whisper_model",
+    os.getenv("WHISPER_MODEL", "base"),
+)
+
+WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
+WHISPER_MODEL_AUTO_UPDATE = (
+    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
+
 AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
     "AUDIO_STT_OPENAI_API_BASE_URL",
     "audio.stt.openai.api_base_url",
@@ -1470,7 +1456,7 @@ AUDIO_STT_ENGINE = PersistentConfig(
 AUDIO_STT_MODEL = PersistentConfig(
     "AUDIO_STT_MODEL",
     "audio.stt.model",
-    os.getenv("AUDIO_STT_MODEL", "whisper-1"),
+    os.getenv("AUDIO_STT_MODEL", ""),
 )
 
 AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(

+ 87 - 4
src/lib/components/admin/Settings/Audio.svelte

@@ -38,6 +38,9 @@
 	let STT_OPENAI_API_KEY = '';
 	let STT_ENGINE = '';
 	let STT_MODEL = '';
+	let STT_WHISPER_MODEL = '';
+
+	let STT_WHISPER_MODEL_LOADING = false;
 
 	// eslint-disable-next-line no-undef
 	let voices: SpeechSynthesisVoice[] = [];
@@ -99,18 +102,23 @@
 				OPENAI_API_BASE_URL: STT_OPENAI_API_BASE_URL,
 				OPENAI_API_KEY: STT_OPENAI_API_KEY,
 				ENGINE: STT_ENGINE,
-				MODEL: STT_MODEL
+				MODEL: STT_MODEL,
+				WHISPER_MODEL: STT_WHISPER_MODEL
 			}
 		});
 
 		if (res) {
 			saveHandler();
-			getBackendConfig()
-				.then(config.set)
-				.catch(() => {});
+			config.set(await getBackendConfig());
 		}
 	};
 
+	const sttModelUpdateHandler = async () => {
+		STT_WHISPER_MODEL_LOADING = true;
+		await updateConfigHandler();
+		STT_WHISPER_MODEL_LOADING = false;
+	};
+
 	onMount(async () => {
 		const res = await getAudioConfig(localStorage.token);
 
@@ -134,6 +142,7 @@
 
 			STT_ENGINE = res.stt.ENGINE;
 			STT_MODEL = res.stt.MODEL;
+			STT_WHISPER_MODEL = res.stt.WHISPER_MODEL;
 		}
 
 		await getVoices();
@@ -201,6 +210,80 @@
 							</div>
 						</div>
 					</div>
+				{:else if STT_ENGINE === ''}
+					<div>
+						<div class=" mb-1.5 text-sm font-medium">{$i18n.t('STT Model')}</div>
+
+						<div class="flex w-full">
+							<div class="flex-1 mr-2">
+								<input
+									class="w-full rounded-lg py-2 px-4 text-sm bg-gray-50 dark:text-gray-300 dark:bg-gray-850 outline-none"
+									placeholder={$i18n.t('Set whisper model')}
+									bind:value={STT_WHISPER_MODEL}
+								/>
+							</div>
+
+							<button
+								class="px-2.5 bg-gray-50 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+								on:click={() => {
+									sttModelUpdateHandler();
+								}}
+								disabled={STT_WHISPER_MODEL_LOADING}
+							>
+								{#if STT_WHISPER_MODEL_LOADING}
+									<div class="self-center">
+										<svg
+											class=" w-4 h-4"
+											viewBox="0 0 24 24"
+											fill="currentColor"
+											xmlns="http://www.w3.org/2000/svg"
+										>
+											<style>
+												.spinner_ajPY {
+													transform-origin: center;
+													animation: spinner_AtaB 0.75s infinite linear;
+												}
+
+												@keyframes spinner_AtaB {
+													100% {
+														transform: rotate(360deg);
+													}
+												}
+											</style>
+											<path
+												d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+												opacity=".25"
+											/>
+											<path
+												d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+												class="spinner_ajPY"
+											/>
+										</svg>
+									</div>
+								{:else}
+									<svg
+										xmlns="http://www.w3.org/2000/svg"
+										viewBox="0 0 16 16"
+										fill="currentColor"
+										class="w-4 h-4"
+									>
+										<path
+											d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
+										/>
+										<path
+											d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
+										/>
+									</svg>
+								{/if}
+							</button>
+						</div>
+
+						<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+							{$i18n.t(
+								'If you want to use a custom model, please enter the model name and click the refresh button.'
+							)}
+						</div>
+					</div>
 				{/if}
 			</div>