Procházet zdrojové kódy

feat: openai tts support

Timothy J. Baek před 1 rokem
rodič
revize
0b8df52c97

+ 73 - 4
backend/apps/openai/main.py

@@ -1,15 +1,19 @@
 from fastapi import FastAPI, Request, Response, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 
 import requests
 import json
 from pydantic import BaseModel
 
+
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user
-from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
+from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
+
+import hashlib
+from pathlib import Path
 
 app = FastAPI()
 app.add_middleware(
@@ -66,6 +70,73 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
+@app.post("/audio/speech")
+async def speech(request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
+
+    if user.role not in ["user", "admin"]:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    if app.state.OPENAI_API_KEY == "":
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    body = await request.body()
+
+    filename = hashlib.sha256(body).hexdigest() + ".mp3"
+    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
+    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+    file_path = SPEECH_CACHE_DIR.joinpath(filename)
+
+    print(file_path)
+
+    # Check if the file already exists in the cache
+    if file_path.is_file():
+        return FileResponse(file_path)
+
+    headers = {}
+    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+    headers["Content-Type"] = "application/json"
+
+    try:
+        print("openai")
+        r = requests.post(
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        print(r)
+
+        r.raise_for_status()
+
+        # Save the streaming content to a file
+        with open(file_path, "wb") as f:
+            for chunk in r.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        # Return the saved file
+        return FileResponse(file_path)
+
+        # return StreamingResponse(
+        #     r.iter_content(chunk_size=8192),
+        #     status_code=r.status_code,
+        #     headers=dict(r.headers),
+        # )
+
+    except Exception as e:
+        print(e)
+        error_detail = "Ollama WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(status_code=r.status_code, detail=error_detail)
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
@@ -129,8 +200,6 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
 
             response_data = r.json()
 
-            print(type(response_data))
-
             if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
                 response_data["data"] = list(
                     filter(lambda model: "gpt" in model["id"], response_data["data"])

+ 8 - 0
backend/config.py

@@ -35,6 +35,14 @@ FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
 UPLOAD_DIR = f"{DATA_DIR}/uploads"
 Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
 
+
+####################################
+# Cache DIR
+####################################
+
+CACHE_DIR = f"{DATA_DIR}/cache"
+Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+
 ####################################
 # OLLAMA_API_BASE_URL
 ####################################

+ 31 - 0
src/lib/apis/openai/index.ts

@@ -229,3 +229,34 @@ export const generateOpenAIChatCompletion = async (token: string = '', body: obj
 
 	return res;
 };
+
+export const synthesizeOpenAISpeech = async (
+	token: string = '',
+	speaker: string = 'alloy',
+	text: string = ''
+) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/audio/speech`, {
+		method: 'POST',
+		headers: {
+			Authorization: `Bearer ${token}`,
+			'Content-Type': 'application/json'
+		},
+		body: JSON.stringify({
+			model: 'tts-1',
+			input: text,
+			voice: speaker
+		})
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 50 - 12
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -1,4 +1,5 @@
 <script lang="ts">
+	import toast from 'svelte-french-toast';
 	import dayjs from 'dayjs';
 	import { marked } from 'marked';
 	import { settings } from '$lib/stores';
@@ -13,6 +14,8 @@
 	import Skeleton from './Skeleton.svelte';
 	import CodeBlock from './CodeBlock.svelte';
 
+	import { synthesizeOpenAISpeech } from '$lib/apis/openai';
+
 	export let modelfiles = [];
 	export let message;
 	export let siblings;
@@ -27,6 +30,8 @@
 	export let copyToClipboard: Function;
 	export let regenerateResponse: Function;
 
+	let audioMap = {};
+
 	let edit = false;
 	let editedContent = '';
 
@@ -114,22 +119,55 @@
 		if (speaking) {
 			speechSynthesis.cancel();
 			speaking = null;
+
+			audioMap[message.id].pause();
+			audioMap[message.id].currentTime = 0;
 		} else {
 			speaking = true;
 
-			let voices = [];
-			const getVoicesLoop = setInterval(async () => {
-				voices = await speechSynthesis.getVoices();
-				if (voices.length > 0) {
-					clearInterval(getVoicesLoop);
-
-					const voice = voices?.filter((v) => v.name === $settings?.speaker)?.at(0) ?? undefined;
-
-					const speak = new SpeechSynthesisUtterance(message.content);
-					speak.voice = voice;
-					speechSynthesis.speak(speak);
+			if ($settings?.speech?.engine === 'openai') {
+				const res = await synthesizeOpenAISpeech(
+					localStorage.token,
+					$settings?.speech?.speaker,
+					message.content
+				).catch((error) => {
+					toast.error(error);
+					return null;
+				});
+
+				if (res) {
+					const blob = await res.blob();
+					const blobUrl = URL.createObjectURL(blob);
+					console.log(blobUrl);
+
+					const audio = new Audio(blobUrl);
+					audioMap[message.id] = audio;
+
+					audio.onended = () => {
+						speaking = null;
+					};
+					audio.play().catch((e) => console.error('Error playing audio:', e));
 				}
-			}, 100);
+			} else {
+				let voices = [];
+				const getVoicesLoop = setInterval(async () => {
+					voices = await speechSynthesis.getVoices();
+					if (voices.length > 0) {
+						clearInterval(getVoicesLoop);
+
+						const voice =
+							voices?.filter((v) => v.name === $settings?.speech?.speaker)?.at(0) ?? undefined;
+
+						const speak = new SpeechSynthesisUtterance(message.content);
+
+						speak.onend = () => {
+							speaking = null;
+						};
+						speak.voice = voice;
+						speechSynthesis.speak(speak);
+					}
+				}, 100);
+			}
 		}
 	};
 

+ 56 - 9
src/lib/components/chat/Settings/Voice.svelte

@@ -6,16 +6,23 @@
 
 	// Voice
 	let engines = ['', 'openai'];
-	let selectedEngine = '';
+	let engine = '';
 
 	let voices = [];
 	let speaker = '';
 
-	onMount(async () => {
-		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
-
-		speaker = settings.speaker ?? '';
+	const getOpenAIVoices = () => {
+		voices = [
+			{ name: 'alloy' },
+			{ name: 'echo' },
+			{ name: 'fable' },
+			{ name: 'onyx' },
+			{ name: 'nova' },
+			{ name: 'shimmer' }
+		];
+	};
 
+	const getWebAPIVoices = () => {
 		const getVoicesLoop = setInterval(async () => {
 			voices = await speechSynthesis.getVoices();
 
@@ -24,6 +31,19 @@
 				clearInterval(getVoicesLoop);
 			}
 		}, 100);
+	};
+
+	onMount(async () => {
+		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
+
+		engine = settings?.speech?.engine ?? '';
+		speaker = settings?.speech?.speaker ?? '';
+
+		if (engine === 'openai') {
+			getOpenAIVoices();
+		} else {
+			getWebAPIVoices();
+		}
 	});
 </script>
 
@@ -31,7 +51,10 @@
 	class="flex flex-col h-full justify-between space-y-3 text-sm"
 	on:submit|preventDefault={() => {
 		saveSettings({
-			speaker: speaker !== '' ? speaker : undefined
+			speech: {
+				engine: engine !== '' ? engine : undefined,
+				speaker: speaker !== '' ? speaker : undefined
+			}
 		});
 		dispatch('save');
 	}}
@@ -42,10 +65,16 @@
 			<div class="flex items-center relative">
 				<select
 					class="w-fit pr-8 rounded py-2 px-2 text-xs bg-transparent outline-none text-right"
-					bind:value={selectedEngine}
+					bind:value={engine}
 					placeholder="Select a mode"
 					on:change={(e) => {
-						console.log(e);
+						if (e.target.value === 'openai') {
+							getOpenAIVoices();
+							speaker = 'alloy';
+						} else {
+							getWebAPIVoices();
+							speaker = '';
+						}
 					}}
 				>
 					<option value="">Default (Web API)</option>
@@ -56,7 +85,7 @@
 
 		<hr class=" dark:border-gray-700" />
 
-		{#if selectedEngine === ''}
+		{#if engine === ''}
 			<div>
 				<div class=" mb-2.5 text-sm font-medium">Set Voice</div>
 				<div class="flex w-full">
@@ -75,6 +104,24 @@
 					</div>
 				</div>
 			</div>
+		{:else if engine === 'openai'}
+			<div>
+				<div class=" mb-2.5 text-sm font-medium">Set Voice</div>
+				<div class="flex w-full">
+					<div class="flex-1">
+						<select
+							class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
+							bind:value={speaker}
+							placeholder="Select a voice"
+						>
+							{#each voices as voice}
+								<option value={voice.name} class="bg-gray-100 dark:bg-gray-700">{voice.name}</option
+								>
+							{/each}
+						</select>
+					</div>
+				</div>
+			</div>
 		{/if}
 	</div>