瀏覽代碼

refac: audio

Timothy J. Baek 1 年之前
父節點
當前提交
710850e442

+ 78 - 1
backend/apps/audio/main.py

@@ -10,9 +10,18 @@ from fastapi import (
     File,
     Form,
 )
+
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+
 from fastapi.middleware.cors import CORSMiddleware
 from faster_whisper import WhisperModel
 
+import requests
+import hashlib
+from pathlib import Path
+import json
+
+
 from constants import ERROR_MESSAGES
 from utils.utils import (
     decode_token,
@@ -30,6 +39,8 @@ from config import (
     WHISPER_MODEL_DIR,
     WHISPER_MODEL_AUTO_UPDATE,
     DEVICE_TYPE,
+    OPENAI_API_BASE_URL,
+    OPENAI_API_KEY,
 )
 
 log = logging.getLogger(__name__)
@@ -44,12 +55,78 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
+app.state.OPENAI_API_KEY = OPENAI_API_KEY
+
 # setting device type for whisper model
 whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
 log.info(f"whisper_device_type: {whisper_device_type}")
 
+SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+
+@app.post("/speech")
+async def speech(request: Request, user=Depends(get_verified_user)):
+    idx = None
+    try:
+        body = await request.body()
+        name = hashlib.sha256(body).hexdigest()
+
+        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
+        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
+
+        # Check if the file already exists in the cache
+        if file_path.is_file():
+            return FileResponse(file_path)
+
+        headers = {}
+        headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+        headers["Content-Type"] = "application/json"
+
+        r = None
+        try:
+            r = requests.post(
+                url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
+                data=body,
+                headers=headers,
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            # Save the streaming content to a file
+            with open(file_path, "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+            with open(file_body_path, "w") as f:
+                json.dump(json.loads(body.decode("utf-8")), f)
+
+            # Return the saved file
+            return FileResponse(file_path)
+
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"External: {res['error']}"
+                except:
+                    error_detail = f"External: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r else 500, detail=error_detail
+            )
+
+    except ValueError:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
+
 
-@app.post("/transcribe")
+@app.post("/transcriptions")
 def transcribe(
     file: UploadFile = File(...),
     user=Depends(get_current_user),

+ 6 - 2
backend/apps/images/main.py

@@ -35,6 +35,8 @@ from config import (
     ENABLE_IMAGE_GENERATION,
     AUTOMATIC1111_BASE_URL,
     COMFYUI_BASE_URL,
+    OPENAI_API_BASE_URL,
+    OPENAI_API_KEY,
 )
 
 
@@ -56,7 +58,9 @@ app.add_middleware(
 app.state.ENGINE = ""
 app.state.ENABLED = ENABLE_IMAGE_GENERATION
 
-app.state.OPENAI_API_KEY = ""
+app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
+app.state.OPENAI_API_KEY = OPENAI_API_KEY
+
 app.state.MODEL = ""
 
 
@@ -360,7 +364,7 @@ def generate_image(
             }
 
             r = requests.post(
-                url=f"https://api.openai.com/v1/images/generations",
+                url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
                 json=data,
                 headers=headers,
             )

+ 4 - 2
backend/apps/rag/main.py

@@ -70,6 +70,8 @@ from config import (
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+    RAG_OPENAI_API_BASE_URL,
+    RAG_OPENAI_API_KEY,
     DEVICE_TYPE,
     CHROMA_CLIENT,
     CHUNK_SIZE,
@@ -94,8 +96,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 
-app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
-app.state.RAG_OPENAI_API_KEY = ""
+app.state.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
+app.state.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 

+ 2 - 2
backend/apps/rag/utils.py

@@ -324,11 +324,11 @@ def get_embedding_model_path(
 
 
 def generate_openai_embeddings(
-    model: str, text: str, key: str, url: str = "https://api.openai.com"
+    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
     try:
         r = requests.post(
-            f"{url}/v1/embeddings",
+            f"{url}/embeddings",
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",

+ 10 - 0
backend/config.py

@@ -321,6 +321,13 @@ OPENAI_API_BASE_URLS = [
     for url in OPENAI_API_BASE_URLS.split(";")
 ]
 
+OPENAI_API_KEY = ""
+OPENAI_API_KEY = OPENAI_API_KEYS[
+    OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
+]
+OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
+
 ####################################
 # WEBUI
 ####################################
@@ -447,6 +454,9 @@ And answer according to the language of the user's question.
 Given the context information, answer the query.
 Query: [query]"""
 
+RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
+RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
+
 ####################################
 # Transcribe
 ####################################

+ 32 - 1
src/lib/apis/audio/index.ts

@@ -5,7 +5,7 @@ export const transcribeAudio = async (token: string, file: File) => {
 	data.append('file', file);
 
 	let error = null;
-	const res = await fetch(`${AUDIO_API_BASE_URL}/transcribe`, {
+	const res = await fetch(`${AUDIO_API_BASE_URL}/transcriptions`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -29,3 +29,34 @@ export const transcribeAudio = async (token: string, file: File) => {
 
 	return res;
 };
+
+export const synthesizeOpenAISpeech = async (
+	token: string = '',
+	speaker: string = 'alloy',
+	text: string = ''
+) => {
+	let error = null;
+
+	const res = await fetch(`${AUDIO_API_BASE_URL}/speech`, {
+		method: 'POST',
+		headers: {
+			Authorization: `Bearer ${token}`,
+			'Content-Type': 'application/json'
+		},
+		body: JSON.stringify({
+			model: 'tts-1',
+			input: text,
+			voice: speaker
+		})
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 1 - 1
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -15,7 +15,7 @@
 	const dispatch = createEventDispatcher();
 
 	import { config, settings } from '$lib/stores';
-	import { synthesizeOpenAISpeech } from '$lib/apis/openai';
+	import { synthesizeOpenAISpeech } from '$lib/apis/audio';
 	import { imageGenerations } from '$lib/apis/images';
 	import {
 		approximateToHumanReadable,