Browse Source

feat: compress audio

Co-Authored-By: Beck Bekmyradov <47065940+bekmuradov@users.noreply.github.com>
Timothy J. Baek 7 months ago
parent
commit
7152af949b

+ 138 - 96
backend/open_webui/apps/audio/main.py

@@ -5,6 +5,8 @@ import os
 import uuid
 import uuid
 from functools import lru_cache
 from functools import lru_cache
 from pathlib import Path
 from pathlib import Path
+from pydub import AudioSegment
+from pydub.silence import split_on_silence
 
 
 import requests
 import requests
 from open_webui.config import (
 from open_webui.config import (
@@ -35,7 +37,12 @@ from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile,
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse
 from fastapi.responses import FileResponse
 from pydantic import BaseModel
 from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_current_user, get_verified_user
+from open_webui.utils.utils import get_admin_user, get_verified_user
+
+# Constants
+MAX_FILE_SIZE_MB = 25
+MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes
+
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -353,67 +360,77 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             )
             )
 
 
 
 
-@app.post("/transcriptions")
-def transcribe(
-    file: UploadFile = File(...),
-    user=Depends(get_current_user),
-):
-    log.info(f"file.content_type: {file.content_type}")
+def transcribe(file_path):
+    print("transcribe", file_path)
+    filename = os.path.basename(file_path)
+    file_dir = os.path.dirname(file_path)
+    id = filename.split(".")[0]
 
 
-    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+    if app.state.config.STT_ENGINE == "":
+        from faster_whisper import WhisperModel
+
+        whisper_kwargs = {
+            "model_size_or_path": WHISPER_MODEL,
+            "device": whisper_device_type,
+            "compute_type": "int8",
+            "download_root": WHISPER_MODEL_DIR,
+            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+        }
+
+        log.debug(f"whisper_kwargs: {whisper_kwargs}")
+
+        try:
+            model = WhisperModel(**whisper_kwargs)
+        except Exception:
+            log.warning(
+                "WhisperModel initialization failed, attempting download with local_files_only=False"
+            )
+            whisper_kwargs["local_files_only"] = False
+            model = WhisperModel(**whisper_kwargs)
+
+        segments, info = model.transcribe(file_path, beam_size=5)
+        log.info(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
         )
         )
 
 
-    try:
-        ext = file.filename.split(".")[-1]
+        transcript = "".join([segment.text for segment in list(segments)])
 
 
-        id = uuid.uuid4()
-        filename = f"{id}.{ext}"
+        data = {"text": transcript.strip()}
 
 
-        file_dir = f"{CACHE_DIR}/audio/transcriptions"
-        os.makedirs(file_dir, exist_ok=True)
-        file_path = f"{file_dir}/{filename}"
+        # save the transcript to a json file
+        transcript_file = f"{file_dir}/{id}.json"
+        with open(transcript_file, "w") as f:
+            json.dump(data, f)
 
 
-        print(filename)
+        print(data)
+        return data
+    elif app.state.config.STT_ENGINE == "openai":
+        if is_mp4_audio(file_path):
+            print("is_mp4_audio")
+            os.rename(file_path, file_path.replace(".wav", ".mp4"))
+            # Convert MP4 audio file to WAV format
+            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
 
 
-        contents = file.file.read()
-        with open(file_path, "wb") as f:
-            f.write(contents)
-            f.close()
-
-        if app.state.config.STT_ENGINE == "":
-            from faster_whisper import WhisperModel
-
-            whisper_kwargs = {
-                "model_size_or_path": WHISPER_MODEL,
-                "device": whisper_device_type,
-                "compute_type": "int8",
-                "download_root": WHISPER_MODEL_DIR,
-                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
-            }
-
-            log.debug(f"whisper_kwargs: {whisper_kwargs}")
-
-            try:
-                model = WhisperModel(**whisper_kwargs)
-            except Exception:
-                log.warning(
-                    "WhisperModel initialization failed, attempting download with local_files_only=False"
-                )
-                whisper_kwargs["local_files_only"] = False
-                model = WhisperModel(**whisper_kwargs)
+        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
 
 
-            segments, info = model.transcribe(file_path, beam_size=5)
-            log.info(
-                "Detected language '%s' with probability %f"
-                % (info.language, info.language_probability)
+        files = {"file": (filename, open(file_path, "rb"))}
+        data = {"model": app.state.config.STT_MODEL}
+
+        print(files, data)
+
+        r = None
+        try:
+            r = requests.post(
+                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                headers=headers,
+                files=files,
+                data=data,
             )
             )
 
 
-            transcript = "".join([segment.text for segment in list(segments)])
+            r.raise_for_status()
 
 
-            data = {"text": transcript.strip()}
+            data = r.json()
 
 
             # save the transcript to a json file
             # save the transcript to a json file
             transcript_file = f"{file_dir}/{id}.json"
             transcript_file = f"{file_dir}/{id}.json"
@@ -421,58 +438,83 @@ def transcribe(
                 json.dump(data, f)
                 json.dump(data, f)
 
 
             print(data)
             print(data)
-
             return data
             return data
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"External: {res['error']['message']}"
+                except Exception:
+                    error_detail = f"External: {e}"
 
 
-        elif app.state.config.STT_ENGINE == "openai":
-            if is_mp4_audio(file_path):
-                print("is_mp4_audio")
-                os.rename(file_path, file_path.replace(".wav", ".mp4"))
-                # Convert MP4 audio file to WAV format
-                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+            raise error_detail
 
 
-            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
 
 
-            files = {"file": (filename, open(file_path, "rb"))}
-            data = {"model": app.state.config.STT_MODEL}
+@app.post("/transcriptions")
+def transcription(
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+):
+    log.info(f"file.content_type: {file.content_type}")
 
 
-            print(files, data)
+    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+        )
 
 
-            r = None
-            try:
-                r = requests.post(
-                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
-                    headers=headers,
-                    files=files,
-                    data=data,
-                )
+    try:
+        ext = file.filename.split(".")[-1]
+        id = uuid.uuid4()
 
 
-                r.raise_for_status()
-
-                data = r.json()
-
-                # save the transcript to a json file
-                transcript_file = f"{file_dir}/{id}.json"
-                with open(transcript_file, "w") as f:
-                    json.dump(data, f)
-
-                print(data)
-                return data
-            except Exception as e:
-                log.exception(e)
-                error_detail = "Open WebUI: Server Connection Error"
-                if r is not None:
-                    try:
-                        res = r.json()
-                        if "error" in res:
-                            error_detail = f"External: {res['error']['message']}"
-                    except Exception:
-                        error_detail = f"External: {e}"
-
-                raise HTTPException(
-                    status_code=r.status_code if r != None else 500,
-                    detail=error_detail,
-                )
+        filename = f"{id}.{ext}"
+        contents = file.file.read()
+
+        file_dir = f"{CACHE_DIR}/audio/transcriptions"
+        os.makedirs(file_dir, exist_ok=True)
+        file_path = f"{file_dir}/{filename}"
+
+        with open(file_path, "wb") as f:
+            f.write(contents)
+
+        try:
+            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
+                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
+                audio = AudioSegment.from_file(file_path)
+                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
+                compressed_path = f"{file_dir}/{id}_compressed.opus"
+                audio.export(compressed_path, format="opus", bitrate="32k")
+                log.debug(f"Compressed audio to {compressed_path}")
+                file_path = compressed_path
+
+                if (
+                    os.path.getsize(file_path) > MAX_FILE_SIZE
+                ):  # Still larger than 25MB after compression
+                    chunks = split_on_silence(
+                        audio, min_silence_len=500, silence_thresh=-40
+                    )
+                    texts = []
+                    for i, chunk in enumerate(chunks):
+                        chunk_file_path = f"{file_dir}/{id}_chunk{i}.{ext}"
+                        chunk.export(chunk_file_path, format=ext)
+                        text = transcribe(chunk_file_path)
+                        texts.append(text)
+                    data = {"text": " ".join(texts)}
+                else:
+                    data = transcribe(file_path)
+            else:
+                data = transcribe(file_path)
+
+            return data
+        except Exception as e:
+            log.exception(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
 
 
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)

+ 1 - 1
src/lib/components/chat/Chat.svelte

@@ -700,7 +700,7 @@
 				childrenIds: [],
 				childrenIds: [],
 				role: 'user',
 				role: 'user',
 				content: userPrompt,
 				content: userPrompt,
-				files: chatFiles.length > 0 ? chatFiles : undefined,
+				files: _files.length > 0 ? _files : undefined,
 				timestamp: Math.floor(Date.now() / 1000), // Unix epoch
 				timestamp: Math.floor(Date.now() / 1000), // Unix epoch
 				models: selectedModels
 				models: selectedModels
 			};
 			};

+ 1 - 1
src/lib/components/common/FileItemModal.svelte

@@ -54,7 +54,7 @@
 			</div>
 			</div>
 
 
 			<div>
 			<div>
-				<div class="flex flex-col md:flex-row gap-1 justify-between w-full">
+				<div class="flex flex-col items-center md:flex-row gap-1 justify-between w-full">
 					<div class=" flex flex-wrap text-sm gap-1 text-gray-500">
 					<div class=" flex flex-wrap text-sm gap-1 text-gray-500">
 						{#if file.size}
 						{#if file.size}
 							<div class="capitalize shrink-0">{formatFileSize(file.size)}</div>
 							<div class="capitalize shrink-0">{formatFileSize(file.size)}</div>