Jelajahi Sumber

Merge pull request #1499 from lainedfles/whisper_auto_update

feat: introduce Whisper model auto-update control.
Timothy Jaeryang Baek 1 tahun lalu
induk
melakukan
b5d882606a
2 mengubah file dengan 22 tambahan dan 6 penghapusan
  1. 19 6
      backend/apps/audio/main.py
  2. 3 0
      backend/config.py

+ 19 - 6
backend/apps/audio/main.py

@@ -28,6 +28,7 @@ from config import (
     UPLOAD_DIR,
     WHISPER_MODEL,
     WHISPER_MODEL_DIR,
+    WHISPER_MODEL_AUTO_UPDATE,
     DEVICE_TYPE,
 )
 
@@ -69,12 +70,24 @@ def transcribe(
             f.write(contents)
             f.close()
 
-        model = WhisperModel(
-            WHISPER_MODEL,
-            device=whisper_device_type,
-            compute_type="int8",
-            download_root=WHISPER_MODEL_DIR,
-        )
+        whisper_kwargs = {
+            "model_size_or_path": WHISPER_MODEL,
+            "device": whisper_device_type,
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+        }
+
+        log.debug(f"whisper_kwargs: {whisper_kwargs}")
+
+        try:
+            model = WhisperModel(**whisper_kwargs)
+        except:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            whisper_kwargs["local_files_only"] = False
+            model = WhisperModel(**whisper_kwargs)
 
         segments, info = model.transcribe(file_path, beam_size=5)
         log.info(

+ 3 - 0
backend/config.py

@@ -450,6 +450,9 @@ Query: [query]"""
 
 WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
+WHISPER_MODEL_AUTO_UPDATE = (
+    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
 
 
 ####################################