Browse Source

feat: whisper support

Timothy J. Baek 1 year ago
parent
commit
a5b9bbf10b

+ 80 - 0
backend/apps/audio/main.py

@@ -0,0 +1,80 @@
+from fastapi import (
+    FastAPI,
+    Request,
+    Depends,
+    HTTPException,
+    status,
+    UploadFile,
+    File,
+    Form,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from faster_whisper import WhisperModel
+
+from constants import ERROR_MESSAGES
+from utils.utils import (
+    decode_token,
+    get_current_user,
+    get_verified_user,
+    get_admin_user,
+)
+from utils.misc import calculate_sha256
+
+from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.post("/transcribe")
+def transcribe(
+    file: UploadFile = File(...),
+    user=Depends(get_current_user),
+):
+    print(file.content_type)
+
+    if file.content_type not in ["audio/mpeg", "audio/wav"]:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
+        )
+
+    try:
+        filename = file.filename
+        file_path = f"{UPLOAD_DIR}/{filename}"
+        contents = file.file.read()
+        with open(file_path, "wb") as f:
+            f.write(contents)
+            f.close()
+
+        model_name = WHISPER_MODEL_NAME
+        model = WhisperModel(
+            model_name,
+            device="cpu",
+            compute_type="int8",
+            download_root=f"{CACHE_DIR}/whisper/models",
+        )
+
+        segments, info = model.transcribe(file_path, beam_size=5)
+        print(
+            "Detected language '%s' with probability %f"
+            % (info.language, info.language_probability)
+        )
+
+        transcript = "".join([segment.text for segment in list(segments)])
+
+        return {"text": transcript}
+
+    except Exception as e:
+        print(e)
+
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )

+ 5 - 0
backend/config.py

@@ -132,3 +132,8 @@ CHROMA_CLIENT = chromadb.PersistentClient(
 )
 )
 CHUNK_SIZE = 1500
 CHUNK_SIZE = 1500
 CHUNK_OVERLAP = 100
 CHUNK_OVERLAP = 100
+
+####################################
+# Transcribe
+####################################
+WHISPER_MODEL_NAME = "tiny"

+ 4 - 0
backend/main.py

@@ -10,6 +10,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 
 
 from apps.ollama.main import app as ollama_app
 from apps.ollama.main import app as ollama_app
 from apps.openai.main import app as openai_app
 from apps.openai.main import app as openai_app
+from apps.audio.main import app as audio_app
+
 
 
 from apps.web.main import app as webui_app
 from apps.web.main import app as webui_app
 from apps.rag.main import app as rag_app
 from apps.rag.main import app as rag_app
@@ -55,6 +57,8 @@ app.mount("/api/v1", webui_app)
 
 
 app.mount("/ollama/api", ollama_app)
 app.mount("/ollama/api", ollama_app)
 app.mount("/openai/api", openai_app)
 app.mount("/openai/api", openai_app)
+
+app.mount("/audio/api/v1", audio_app)
 app.mount("/rag/api/v1", rag_app)
 app.mount("/rag/api/v1", rag_app)
 
 
 
 

+ 2 - 0
backend/requirements.txt

@@ -30,6 +30,8 @@ openpyxl
 pyxlsb
 pyxlsb
 xlrd
 xlrd
 
 
+faster-whisper
+
 PyJWT
 PyJWT
 pyjwt[crypto]
 pyjwt[crypto]
 
 

+ 31 - 0
src/lib/apis/audio/index.ts

@@ -0,0 +1,31 @@
+import { AUDIO_API_BASE_URL } from '$lib/constants';
+
+export const transcribeAudio = async (token: string, file: File) => {
+	const data = new FormData();
+	data.append('file', file);
+
+	let error = null;
+	const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: data
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 4 - 0
src/lib/components/chat/MessageInput.svelte

@@ -11,6 +11,7 @@
 	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
 	import { SUPPORTED_FILE_TYPE, SUPPORTED_FILE_EXTENSIONS } from '$lib/constants';
 	import Documents from './MessageInput/Documents.svelte';
 	import Documents from './MessageInput/Documents.svelte';
 	import Models from './MessageInput/Models.svelte';
 	import Models from './MessageInput/Models.svelte';
+	import { transcribeAudio } from '$lib/apis/audio';
 
 
 	export let submitPrompt: Function;
 	export let submitPrompt: Function;
 	export let stopResponse: Function;
 	export let stopResponse: Function;
@@ -201,6 +202,9 @@
 					console.log(file, file.name.split('.').at(-1));
 					console.log(file, file.name.split('.').at(-1));
 					if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
 					if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) {
 						reader.readAsDataURL(file);
 						reader.readAsDataURL(file);
+					} else if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
+						const res = await transcribeAudio(localStorage.token, file);
+						console.log(res);
 					} else if (
 					} else if (
 						SUPPORTED_FILE_TYPE.includes(file['type']) ||
 						SUPPORTED_FILE_TYPE.includes(file['type']) ||
 						SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))
 						SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))

+ 1 - 0
src/lib/constants.ts

@@ -7,6 +7,7 @@ export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
 export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
 export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
 export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
 export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
 export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
 export const RAG_API_BASE_URL = `${WEBUI_BASE_URL}/rag/api/v1`;
+export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;
 
 
 export const WEB_UI_VERSION = 'v1.0.0-alpha-static';
 export const WEB_UI_VERSION = 'v1.0.0-alpha-static';