浏览代码

feat: external stt

Timothy J. Baek 11 月之前
父节点
当前提交
e516374d54
共有 1 个文件被更改,包括 62 次插入28 次删除
  1. 62 28
      backend/apps/audio/main.py

+ 62 - 28
backend/apps/audio/main.py

@@ -240,39 +240,73 @@ def transcribe(
             f.write(contents)
             f.close()
 
-        whisper_kwargs = {
-            "model_size_or_path": WHISPER_MODEL,
-            "device": whisper_device_type,
-            "compute_type": "int8",
-            "download_root": WHISPER_MODEL_DIR,
-            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
-        }
-
-        log.debug(f"whisper_kwargs: {whisper_kwargs}")
-
-        try:
-            model = WhisperModel(**whisper_kwargs)
-        except:
-            log.warning(
-                "WhisperModel initialization failed, attempting download with local_files_only=False"
+        if app.state.config.STT_ENGINE == "":
+            whisper_kwargs = {
+                "model_size_or_path": WHISPER_MODEL,
+                "device": whisper_device_type,
+                "compute_type": "int8",
+                "download_root": WHISPER_MODEL_DIR,
+                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+            }
+
+            log.debug(f"whisper_kwargs: {whisper_kwargs}")
+
+            try:
+                model = WhisperModel(**whisper_kwargs)
+            except:
+                log.warning(
+                    "WhisperModel initialization failed, attempting download with local_files_only=False"
+                )
+                whisper_kwargs["local_files_only"] = False
+                model = WhisperModel(**whisper_kwargs)
+
+            segments, info = model.transcribe(file_path, beam_size=5)
+            log.info(
+                "Detected language '%s' with probability %f"
+                % (info.language, info.language_probability)
             )
-            whisper_kwargs["local_files_only"] = False
-            model = WhisperModel(**whisper_kwargs)
 
-        segments, info = model.transcribe(file_path, beam_size=5)
-        log.info(
-            "Detected language '%s' with probability %f"
-            % (info.language, info.language_probability)
-        )
+            transcript = "".join([segment.text for segment in list(segments)])
+
+            # save the transcript to a json file
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump({"transcript": transcript}, f)
 
-        transcript = "".join([segment.text for segment in list(segments)])
+            return {"text": transcript.strip()}
 
-        # save the transcript to a json file
-        transcript_file = f"{file_dir}/{id}.json"
-        with open(transcript_file, "w") as f:
-            json.dump({"transcript": transcript}, f)
+        elif app.state.config.STT_ENGINE == "openai":
+            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
 
-        return {"text": transcript.strip()}
+            files = {"file": (filename, open(file_path, "rb"))}
+            data = {"model": "whisper-1"}
+
+            r = None
+            try:
+                r = requests.post(
+                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                    headers=headers,
+                    files=files,
+                    data=data,
+                )
+
+                r.raise_for_status()
+                return r.json()
+            except Exception as e:
+                log.exception(e)
+                error_detail = "Open WebUI: Server Connection Error"
+                if r is not None:
+                    try:
+                        res = r.json()
+                        if "error" in res:
+                            error_detail = f"External: {res['error']['message']}"
+                    except:
+                        error_detail = f"External: {e}"
+
+                raise HTTPException(
+                    status_code=r.status_code if r != None else 500,
+                    detail=error_detail,
+                )
 
     except Exception as e:
         log.exception(e)