Timothy Jaeryang Baek 4 mēneši atpakaļ
vecāks
revīzija
70de5cf7b8
1 mainītis faili ar 18 papildinājumiem un 8 dzēšanām
  1. 18 8
      backend/open_webui/routers/audio.py

+ 18 - 8
backend/open_webui/routers/audio.py

@@ -218,7 +218,7 @@ async def update_audio_config(
     }
     }
 
 
 
 
-def load_speech_pipeline():
+def load_speech_pipeline(request):
     from transformers import pipeline
     from transformers import pipeline
     from datasets import load_dataset
     from datasets import load_dataset
 
 
@@ -236,7 +236,11 @@ def load_speech_pipeline():
 @router.post("/speech")
 @router.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
 async def speech(request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     body = await request.body()
-    name = hashlib.sha256(body).hexdigest()
+    name = hashlib.sha256(
+        body
+        + str(request.app.state.config.TTS_ENGINE).encode("utf-8")
+        + str(request.app.state.config.TTS_MODEL).encode("utf-8")
+    ).hexdigest()
 
 
     file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
     file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
     file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
     file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
@@ -256,10 +260,11 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         payload["model"] = request.app.state.config.TTS_MODEL
         payload["model"] = request.app.state.config.TTS_MODEL
 
 
         try:
         try:
+            # print(payload)
             async with aiohttp.ClientSession() as session:
             async with aiohttp.ClientSession() as session:
                 async with session.post(
                 async with session.post(
                     url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                     url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
-                    data=payload,
+                    json=payload,
                     headers={
                     headers={
                         "Content-Type": "application/json",
                         "Content-Type": "application/json",
                         "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
                         "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
@@ -281,7 +286,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                         await f.write(await r.read())
                         await f.write(await r.read())
 
 
                     async with aiofiles.open(file_body_path, "w") as f:
                     async with aiofiles.open(file_body_path, "w") as f:
-                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
+                        await f.write(json.dumps(payload))
 
 
             return FileResponse(file_path)
             return FileResponse(file_path)
 
 
@@ -292,6 +297,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             try:
             try:
                 if r.status != 200:
                 if r.status != 200:
                     res = await r.json()
                     res = await r.json()
+
                     if "error" in res:
                     if "error" in res:
                         detail = f"External: {res['error'].get('message', '')}"
                         detail = f"External: {res['error'].get('message', '')}"
             except Exception:
             except Exception:
@@ -332,7 +338,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                         await f.write(await r.read())
                         await f.write(await r.read())
 
 
                     async with aiofiles.open(file_body_path, "w") as f:
                     async with aiofiles.open(file_body_path, "w") as f:
-                        await f.write(json.dumps(json.loads(body.decode("utf-8"))))
+                        await f.write(json.dumps(payload))
 
 
             return FileResponse(file_path)
             return FileResponse(file_path)
 
 
@@ -384,6 +390,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                     async with aiofiles.open(file_path, "wb") as f:
                     async with aiofiles.open(file_path, "wb") as f:
                         await f.write(await r.read())
                         await f.write(await r.read())
 
 
+                    async with aiofiles.open(file_body_path, "w") as f:
+                        await f.write(json.dumps(payload))
+
                     return FileResponse(file_path)
                     return FileResponse(file_path)
 
 
         except Exception as e:
         except Exception as e:
@@ -414,7 +423,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         import torch
         import torch
         import soundfile as sf
         import soundfile as sf
 
 
-        load_speech_pipeline()
+        load_speech_pipeline(request)
 
 
         embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
         embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset
 
 
@@ -436,8 +445,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         )
         )
 
 
         sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
         sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
-        with open(file_body_path, "w") as f:
-            json.dump(json.loads(body.decode("utf-8")), f)
+
+        async with aiofiles.open(file_body_path, "w") as f:
+            await f.write(json.dumps(payload))
 
 
         return FileResponse(file_path)
         return FileResponse(file_path)