ソースを参照

feat: compress audio

Co-Authored-By: Beck Bekmyradov <47065940+bekmuradov@users.noreply.github.com>
Timothy J. Baek 7 ヶ月 前
コミット
7152af949b

+ 138 - 96
backend/open_webui/apps/audio/main.py

@@ -5,6 +5,8 @@ import os
 import uuid
 from functools import lru_cache
 from pathlib import Path
+from pydub import AudioSegment
+from pydub.silence import split_on_silence
 
 import requests
 from open_webui.config import (
@@ -35,7 +37,12 @@ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile,
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse
 from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_current_user, get_verified_user
+from open_webui.utils.utils import get_admin_user, get_verified_user
+
+# Constants
+MAX_FILE_SIZE_MB = 25
+MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
+
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -353,67 +360,77 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             )
 
 
-@app.post("/transcriptions")
-def transcribe(
-    file: UploadFile = File(...),
-    user=Depends(get_current_user),
-):
-    log.info(f"file.content_type: {file.content_type}")
+def transcribe(file_path):
+    print("transcribe", file_path)
+    filename = os.path.basename(file_path)
+    file_dir = os.path.dirname(file_path)
+    id = filename.split(".")[0]
 
-    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+    if app.state.config.STT_ENGINE == "":
+        from faster_whisper import WhisperModel
+
+        whisper_kwargs = {
+            "model_size_or_path": WHISPER_MODEL,
+            "device": whisper_device_type,
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+        }
+
+        log.debug(f"whisper_kwargs: {whisper_kwargs}")
+
+        try:
+            model = WhisperModel(**whisper_kwargs)
+        except Exception:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            whisper_kwargs["local_files_only"] = False
+            model = WhisperModel(**whisper_kwargs)
+
+        segments, info = model.transcribe(file_path, beam_size=5)
+        log.info(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
         )
 
-    try:
-        ext = file.filename.split(".")[-1]
+        transcript = "".join([segment.text for segment in list(segments)])
 
-        id = uuid.uuid4()
-        filename = f"{id}.{ext}"
+        data = {"text": transcript.strip()}
 
-        file_dir = f"{CACHE_DIR}/audio/transcriptions"
-        os.makedirs(file_dir, exist_ok=True)
-        file_path = f"{file_dir}/{filename}"
+        # save the transcript to a json file
+        transcript_file = f"{file_dir}/{id}.json"
+        with open(transcript_file, "w") as f:
+            json.dump(data, f)
 
-        print(filename)
+        print(data)
+        return data
+    elif app.state.config.STT_ENGINE == "openai":
+        if is_mp4_audio(file_path):
+            print("is_mp4_audio")
+            os.rename(file_path, file_path.replace(".wav", ".mp4"))
+            # Convert MP4 audio file to WAV format
+            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
 
-        contents = file.file.read()
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
-
-        if app.state.config.STT_ENGINE == "":
-            from faster_whisper import WhisperModel
-
-            whisper_kwargs = {
-                "model_size_or_path": WHISPER_MODEL,
-                "device": whisper_device_type,
-                "compute_type": "int8",
-                "download_root": WHISPER_MODEL_DIR,
-                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
-            }
-
-            log.debug(f"whisper_kwargs: {whisper_kwargs}")
-
-            try:
-                model = WhisperModel(**whisper_kwargs)
-            except Exception:
-                log.warning(
-                    "WhisperModel initialization failed, attempting download with local_files_only=False"
-                )
-                whisper_kwargs["local_files_only"] = False
-                model = WhisperModel(**whisper_kwargs)
+        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
 
-            segments, info = model.transcribe(file_path, beam_size=5)
-            log.info(
-                "Detected language '%s' with probability %f"
-                % (info.language, info.language_probability)
+        files = {"file": (filename, open(file_path, "rb"))}
+        data = {"model": app.state.config.STT_MODEL}
+
+        print(files, data)
+
+        r = None
+        try:
+            r = requests.post(
+                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers=headers,
+                files=files,
+                data=data,
             )
 
-            transcript = "".join([segment.text for segment in list(segments)])
+            r.raise_for_status()
 
-            data = {"text": transcript.strip()}
+            data = r.json()
 
             # save the transcript to a json file
             transcript_file = f"{file_dir}/{id}.json"
@@ -421,58 +438,83 @@ def transcribe(
                 json.dump(data, f)
 
             print(data)
-
             return data
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"External: {res['error']['message']}"
+                except Exception:
+                    error_detail = f"External: {e}"
 
-        elif app.state.config.STT_ENGINE == "openai":
-            if is_mp4_audio(file_path):
-                print("is_mp4_audio")
-                os.rename(file_path, file_path.replace(".wav", ".mp4"))
-                # Convert MP4 audio file to WAV format
-                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+            raise error_detail
 
-            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
 
-            files = {"file": (filename, open(file_path, "rb"))}
-            data = {"model": app.state.config.STT_MODEL}
+@app.post("/transcriptions")
+def transcription(
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+):
+    log.info(f"file.content_type: {file.content_type}")
 
-            print(files, data)
+    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+        )
 
-            r = None
-            try:
-                r = requests.post(
-                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                    headers=headers,
-                    files=files,
-                    data=data,
-                )
+    try:
+        ext = file.filename.split(".")[-1]
+        id = uuid.uuid4()
 
-                r.raise_for_status()
-
-                data = r.json()
-
-                # save the transcript to a json file
-                transcript_file = f"{file_dir}/{id}.json"
-                with open(transcript_file, "w") as f:
-                    json.dump(data, f)
-
-                print(data)
-                return data
-            except Exception as e:
-                log.exception(e)
-                error_detail = "Open WebUI: Server Connection Error"
-                if r is not None:
-                    try:
-                        res = r.json()
-                        if "error" in res:
-                            error_detail = f"External: {res['error']['message']}"
-                    except Exception:
-                        error_detail = f"External: {e}"
-
-                raise HTTPException(
-                    status_code=r.status_code if r != None else 500,
-                    detail=error_detail,
-                )
+        filename = f"{id}.{ext}"
+        contents = file.file.read()
+
+        file_dir = f"{CACHE_DIR}/audio/transcriptions"
+        os.makedirs(file_dir, exist_ok=True)
+        file_path = f"{file_dir}/{filename}"
+
+        with open(file_path, "wb") as f:
+            f.write(contents)
+
+        try:
+            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
+                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
+                audio = AudioSegment.from_file(file_path)
+                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+                compressed_path = f"{file_dir}/{id}_compressed.opus"
+                audio.export(compressed_path, format="opus", bitrate="32k")
+                log.debug(f"Compressed audio to {compressed_path}")
+                file_path = compressed_path
+
+                if (
+                    os.path.getsize(file_path) > MAX_FILE_SIZE
+                ):  # Still larger than 25MB after compression
+                    chunks = split_on_silence(
+                        audio, min_silence_len=500, silence_thresh=-40
+                    )
+                    texts = []
+                    for i, chunk in enumerate(chunks):
+                        chunk_file_path = f"{file_dir}/{id}_chunk{i}.{ext}"
+                        chunk.export(chunk_file_path, format=ext)
+                        text = transcribe(chunk_file_path)
+                        texts.append(text)
+                    data = {"text": " ".join(texts)}
+                else:
+                    data = transcribe(file_path)
+            else:
+                data = transcribe(file_path)
+
+            return data
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
 
     except Exception as e:
         log.exception(e)

+ 1 - 1
src/lib/components/chat/Chat.svelte

@@ -700,7 +700,7 @@
 				childrenIds: [],
 				role: 'user',
 				content: userPrompt,
-				files: chatFiles.length > 0 ? chatFiles : undefined,
+				files: _files.length > 0 ? _files : undefined,
 				timestamp: Math.floor(Date.now() / 1000), // Unix epoch
 				models: selectedModels
 			};

+ 1 - 1
src/lib/components/common/FileItemModal.svelte

@@ -54,7 +54,7 @@
 			</div>
 
 			<div>
-				<div class="flex flex-col md:flex-row gap-1 justify-between w-full">
+				<div class="flex flex-col items-center md:flex-row gap-1 justify-between w-full">
 					<div class=" flex flex-wrap text-sm gap-1 text-gray-500">
 						{#if file.size}
 							<div class="capitalize shrink-0">{formatFileSize(file.size)}</div>