main.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. import hashlib
  2. import json
  3. import logging
  4. import os
  5. import uuid
  6. from functools import lru_cache
  7. from pathlib import Path
  8. import requests
  9. from config import (
  10. AUDIO_STT_ENGINE,
  11. AUDIO_STT_MODEL,
  12. AUDIO_STT_OPENAI_API_BASE_URL,
  13. AUDIO_STT_OPENAI_API_KEY,
  14. AUDIO_TTS_API_KEY,
  15. AUDIO_TTS_ENGINE,
  16. AUDIO_TTS_MODEL,
  17. AUDIO_TTS_OPENAI_API_BASE_URL,
  18. AUDIO_TTS_OPENAI_API_KEY,
  19. AUDIO_TTS_SPLIT_ON,
  20. AUDIO_TTS_VOICE,
  21. CACHE_DIR,
  22. CORS_ALLOW_ORIGIN,
  23. DEVICE_TYPE,
  24. WHISPER_MODEL,
  25. WHISPER_MODEL_AUTO_UPDATE,
  26. WHISPER_MODEL_DIR,
  27. AppConfig,
  28. )
  29. from constants import ERROR_MESSAGES
  30. from env import SRC_LOG_LEVELS
  31. from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
  32. from fastapi.middleware.cors import CORSMiddleware
  33. from fastapi.responses import FileResponse
  34. from pydantic import BaseModel
  35. from utils.utils import get_admin_user, get_current_user, get_verified_user
  log = logging.getLogger(__name__)
  log.setLevel(SRC_LOG_LEVELS["AUDIO"])

  # FastAPI sub-application that serves the audio (TTS / STT) endpoints.
  app = FastAPI()
  app.add_middleware(
      CORSMiddleware,
      allow_origins=CORS_ALLOW_ORIGIN,
      allow_credentials=True,
      allow_methods=["*"],
      allow_headers=["*"],
  )

  # Seed the runtime config from the values imported from `config`; the
  # /config/update endpoint overwrites these same attributes later.
  app.state.config = AppConfig()
  for _attr, _value in (
      ("STT_OPENAI_API_BASE_URL", AUDIO_STT_OPENAI_API_BASE_URL),
      ("STT_OPENAI_API_KEY", AUDIO_STT_OPENAI_API_KEY),
      ("STT_ENGINE", AUDIO_STT_ENGINE),
      ("STT_MODEL", AUDIO_STT_MODEL),
      ("TTS_OPENAI_API_BASE_URL", AUDIO_TTS_OPENAI_API_BASE_URL),
      ("TTS_OPENAI_API_KEY", AUDIO_TTS_OPENAI_API_KEY),
      ("TTS_ENGINE", AUDIO_TTS_ENGINE),
      ("TTS_MODEL", AUDIO_TTS_MODEL),
      ("TTS_VOICE", AUDIO_TTS_VOICE),
      ("TTS_API_KEY", AUDIO_TTS_API_KEY),
      ("TTS_SPLIT_ON", AUDIO_TTS_SPLIT_ON),
  ):
      setattr(app.state.config, _attr, _value)
  del _attr, _value

  # Local Whisper uses CUDA only when DEVICE_TYPE is exactly "cuda"; any other
  # value (including an unset/empty one) falls back to CPU.
  whisper_device_type = DEVICE_TYPE if DEVICE_TYPE == "cuda" else "cpu"
  log.info(f"whisper_device_type: {whisper_device_type}")

  # On-disk cache for synthesized speech (<sha256>.mp3 plus <sha256>.json).
  SPEECH_CACHE_DIR = Path(CACHE_DIR) / "audio" / "speech"
  SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  class TTSConfigForm(BaseModel):
      # Text-to-speech half of the admin config form. Every field is a
      # required string copied onto app.state.config.TTS_<FIELD> by
      # update_audio_config.
      OPENAI_API_BASE_URL: str  # /speech appends "/audio/speech" to this
      OPENAI_API_KEY: str  # bearer token for the "openai" engine
      API_KEY: str  # sent as the "xi-api-key" header to ElevenLabs
      ENGINE: str  # /speech handles "openai" and "elevenlabs"
      MODEL: str
      VOICE: str
      SPLIT_ON: str  # only stored/echoed by the config endpoints in this module view
  class STTConfigForm(BaseModel):
      # Speech-to-text half of the admin config form. Each field is copied
      # onto app.state.config.STT_<FIELD> by update_audio_config.
      OPENAI_API_BASE_URL: str  # /transcriptions appends "/audio/transcriptions"
      OPENAI_API_KEY: str  # bearer token for the "openai" engine
      ENGINE: str  # "" selects local faster-whisper, "openai" the remote API
      MODEL: str  # model name sent to the OpenAI-compatible API
  class AudioConfigUpdateForm(BaseModel):
      # Request body of POST /config/update; both sections are required.
      tts: TTSConfigForm
      stt: STTConfigForm
  79. from pydub import AudioSegment
  80. from pydub.utils import mediainfo
  def is_mp4_audio(file_path):
      """Return True if ``file_path`` probes as AAC audio tagged ``mp4a``.

      Uses pydub's ``mediainfo`` and requires codec_name == "aac",
      codec_type == "audio" and codec_tag_string == "mp4a".

      Args:
          file_path: Path of the file to inspect.

      Returns:
          False when the file does not exist or any of the three fields differ.
      """
      if not os.path.isfile(file_path):
          # Route through the module logger instead of stdout.
          log.warning(f"File not found: {file_path}")
          return False

      info = mediainfo(file_path)
      return (
          info.get("codec_name") == "aac"
          and info.get("codec_type") == "audio"
          and info.get("codec_tag_string") == "mp4a"
      )
  def convert_mp4_to_wav(file_path, output_path):
      """Decode the MP4 audio at ``file_path`` and write it to ``output_path`` as WAV.

      Args:
          file_path: Source MP4 audio file.
          output_path: Destination path; overwritten if it exists.
      """
      audio = AudioSegment.from_file(file_path, format="mp4")
      # AudioSegment.export opens output_path itself and returns that file
      # object; close it so the descriptor is not leaked.
      audio.export(output_path, format="wav").close()
      log.info(f"Converted {file_path} to {output_path}")
  def _audio_config_payload() -> dict:
      """Snapshot the current TTS/STT settings in the shape both config endpoints return."""
      config = app.state.config
      tts_section = {
          "OPENAI_API_BASE_URL": config.TTS_OPENAI_API_BASE_URL,
          "OPENAI_API_KEY": config.TTS_OPENAI_API_KEY,
          "API_KEY": config.TTS_API_KEY,
          "ENGINE": config.TTS_ENGINE,
          "MODEL": config.TTS_MODEL,
          "VOICE": config.TTS_VOICE,
          "SPLIT_ON": config.TTS_SPLIT_ON,
      }
      stt_section = {
          "OPENAI_API_BASE_URL": config.STT_OPENAI_API_BASE_URL,
          "OPENAI_API_KEY": config.STT_OPENAI_API_KEY,
          "ENGINE": config.STT_ENGINE,
          "MODEL": config.STT_MODEL,
      }
      return {"tts": tts_section, "stt": stt_section}


  @app.get("/config")
  async def get_audio_config(user=Depends(get_admin_user)):
      """Return the current audio configuration (admin only)."""
      return _audio_config_payload()


  @app.post("/config/update")
  async def update_audio_config(
      form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  ):
      """Overwrite every TTS/STT setting from ``form_data`` and return the new config (admin only)."""
      config = app.state.config
      tts, stt = form_data.tts, form_data.stt

      config.TTS_OPENAI_API_BASE_URL = tts.OPENAI_API_BASE_URL
      config.TTS_OPENAI_API_KEY = tts.OPENAI_API_KEY
      config.TTS_API_KEY = tts.API_KEY
      config.TTS_ENGINE = tts.ENGINE
      config.TTS_MODEL = tts.MODEL
      config.TTS_VOICE = tts.VOICE
      config.TTS_SPLIT_ON = tts.SPLIT_ON

      config.STT_OPENAI_API_BASE_URL = stt.OPENAI_API_BASE_URL
      config.STT_OPENAI_API_KEY = stt.OPENAI_API_KEY
      config.STT_ENGINE = stt.ENGINE
      config.STT_MODEL = stt.MODEL

      return _audio_config_payload()
  def _speech_upstream_error(r, e: Exception) -> HTTPException:
      """Log ``e`` and build the HTTPException reported for a failed TTS call.

      Args:
          r: The provider response, or None if the request never returned.
          e: The exception that aborted the call.

      Returns:
          An HTTPException whose detail is the provider's ``error.message``
          when its JSON body has one, otherwise a generic message.
      """
      log.exception(e)
      error_detail = "Open WebUI: Server Connection Error"
      if r is not None:
          try:
              res = r.json()
              if "error" in res:
                  error_detail = f"External: {res['error']['message']}"
          except Exception:
              error_detail = f"External: {e}"
      # Forward the provider's status only when it is an error status. A
      # failure after a 2xx response (e.g. writing the cache) would otherwise
      # be reported to the client as a "successful" error.
      if r is not None and r.status_code >= 400:
          status_code = r.status_code
      else:
          status_code = 500
      return HTTPException(status_code=status_code, detail=error_detail)


  def _stream_speech_to_cache(response, file_path: Path) -> None:
      """Write ``response``'s body to ``file_path`` without exposing partial audio.

      Chunks go to a uniquely named temp file next to ``file_path`` that is
      moved into place with ``os.replace`` only after the download finishes.
      On any failure the temp file is removed, so a truncated MP3 is never
      served from the cache.
      """
      tmp_path = file_path.with_name(f"{file_path.name}.{uuid.uuid4().hex}.part")
      try:
          with open(tmp_path, "wb") as f:
              for chunk in response.iter_content(chunk_size=8192):
                  f.write(chunk)
          os.replace(tmp_path, file_path)
      finally:
          tmp_path.unlink(missing_ok=True)


  @app.post("/speech")
  async def speech(request: Request, user=Depends(get_verified_user)):
      """Synthesize speech for the JSON request body, caching the result on disk.

      The cache key is the SHA-256 of the raw request body. The audio is stored
      as ``<hash>.mp3`` and the payload sent to the provider as ``<hash>.json``
      in SPEECH_CACHE_DIR.

      Returns:
          A FileResponse for the (cached or freshly generated) MP3.

      Raises:
          HTTPException: 400 for a non-JSON/non-object body, an unknown voice,
              a missing ``input`` field, or an unsupported TTS engine; the
              provider's error status (otherwise 500) when the TTS call fails.
      """
      # NOTE(review): the synchronous `requests` calls below run inside an
      # async endpoint and block the event loop while the provider responds.
      body = await request.body()
      name = hashlib.sha256(body).hexdigest()

      file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
      file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

      # Serve from cache when this exact body has been synthesized before.
      if file_path.is_file():
          return FileResponse(file_path)

      # Parse once, up front, so both engines reject malformed bodies the same way.
      try:
          payload = json.loads(body.decode("utf-8"))
      except ValueError as e:  # covers UnicodeDecodeError and JSONDecodeError
          log.exception(e)
          raise HTTPException(status_code=400, detail="Invalid JSON payload") from e
      if not isinstance(payload, dict):
          raise HTTPException(status_code=400, detail="Invalid JSON payload")

      engine = app.state.config.TTS_ENGINE

      if engine == "openai":
          # The server-side configured model always wins over the client's.
          payload["model"] = app.state.config.TTS_MODEL
          headers = {
              "Authorization": f"Bearer {app.state.config.TTS_OPENAI_API_KEY}",
              "Content-Type": "application/json",
          }
          r = None
          try:
              r = requests.post(
                  url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                  data=json.dumps(payload).encode("utf-8"),
                  headers=headers,
                  stream=True,
              )
              r.raise_for_status()
              _stream_speech_to_cache(r, file_path)
              with open(file_body_path, "w") as f:
                  json.dump(payload, f)
              return FileResponse(file_path)
          except Exception as e:
              raise _speech_upstream_error(r, e) from e
          finally:
              # stream=True keeps the connection open until closed.
              if r is not None:
                  r.close()

      if engine == "elevenlabs":
          voice_id = payload.get("voice", "")
          # Validating against the fetched voice list also keeps arbitrary
          # text out of the URL path below.
          if voice_id not in get_available_voices():
              raise HTTPException(status_code=400, detail="Invalid voice id")
          if "input" not in payload:
              raise HTTPException(status_code=400, detail="Missing 'input' field")

          url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
          headers = {
              "Accept": "audio/mpeg",
              "Content-Type": "application/json",
              "xi-api-key": app.state.config.TTS_API_KEY,
          }
          data = {
              "text": payload["input"],
              "model_id": app.state.config.TTS_MODEL,
              "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
          }
          # Initialised so the error path works when the request itself raises.
          r = None
          try:
              r = requests.post(url, json=data, headers=headers)
              r.raise_for_status()
              _stream_speech_to_cache(r, file_path)
              with open(file_body_path, "w") as f:
                  json.dump(payload, f)
              return FileResponse(file_path)
          except Exception as e:
              raise _speech_upstream_error(r, e) from e
          finally:
              if r is not None:
                  r.close()

      raise HTTPException(status_code=400, detail=f"Unsupported TTS engine: {engine!r}")
  def _transcribe_with_whisper(file_path: str) -> dict:
      """Transcribe ``file_path`` with a local faster-whisper model.

      Returns:
          ``{"text": <stripped transcript>}``.
      """
      # Imported here so faster_whisper is only loaded when the local engine is used.
      from faster_whisper import WhisperModel

      whisper_kwargs = {
          "model_size_or_path": WHISPER_MODEL,
          "device": whisper_device_type,
          "compute_type": "int8",
          "download_root": WHISPER_MODEL_DIR,
          "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
      }
      log.debug(f"whisper_kwargs: {whisper_kwargs}")

      # NOTE(review): the model is constructed on every request; caching it
      # would avoid the repeated load cost.
      try:
          model = WhisperModel(**whisper_kwargs)
      except Exception:
          log.warning(
              "WhisperModel initialization failed, attempting download with local_files_only=False"
          )
          whisper_kwargs["local_files_only"] = False
          model = WhisperModel(**whisper_kwargs)

      segments, info = model.transcribe(file_path, beam_size=5)
      log.info(
          "Detected language '%s' with probability %f",
          info.language,
          info.language_probability,
      )
      transcript = "".join(segment.text for segment in segments)
      return {"text": transcript.strip()}


  def _transcribe_with_openai(file_path: str, filename: str, file_dir: str, file_id) -> dict:
      """Send the saved upload to the OpenAI-compatible transcription endpoint.

      AAC/MP4 audio is first renamed to ``<file_id>.mp4`` and converted to
      ``<file_id>.wav``, so the name sent upstream matches the content.

      Returns:
          The provider's JSON response.

      Raises:
          HTTPException: the provider's error status (otherwise 500) when the
              request fails or its response cannot be decoded.
      """
      if is_mp4_audio(file_path):
          mp4_path = f"{file_dir}/{file_id}.mp4"
          os.rename(file_path, mp4_path)
          file_path = f"{file_dir}/{file_id}.wav"
          filename = f"{file_id}.wav"
          convert_mp4_to_wav(mp4_path, file_path)

      headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
      form_fields = {"model": app.state.config.STT_MODEL}

      # `with` closes the upload handle once the request is done.
      with open(file_path, "rb") as audio_file:
          r = None
          try:
              r = requests.post(
                  url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                  headers=headers,
                  files={"file": (filename, audio_file)},
                  data=form_fields,
              )
              r.raise_for_status()
              return r.json()
          except Exception as e:
              log.exception(e)
              error_detail = "Open WebUI: Server Connection Error"
              if r is not None:
                  try:
                      res = r.json()
                      if "error" in res:
                          error_detail = f"External: {res['error']['message']}"
                  except Exception:
                      error_detail = f"External: {e}"
              # Only forward real error statuses; a failure after a 2xx is a 500.
              if r is not None and r.status_code >= 400:
                  status_code = r.status_code
              else:
                  status_code = 500
              raise HTTPException(status_code=status_code, detail=error_detail) from e


  @app.post("/transcriptions")
  def transcribe(
      file: UploadFile = File(...),
      user=Depends(get_current_user),
  ):
      """Transcribe an uploaded MP3/WAV file with the configured STT engine.

      The upload is saved as ``CACHE_DIR/audio/transcriptions/<uuid>.<ext>``
      and the transcript is written next to it as ``<uuid>.json``.

      Returns:
          ``{"text": ...}`` from local Whisper (STT_ENGINE == ""), or the
          provider's JSON response for STT_ENGINE == "openai".

      Raises:
          HTTPException: 400 for an unsupported content type, an unsupported
              engine, or any local processing error; the provider's status
              (or 500) when the OpenAI-compatible API call fails.
      """
      log.info(f"file.content_type: {file.content_type}")

      if file.content_type not in ["audio/mpeg", "audio/wav"]:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
          )

      try:
          ext = file.filename.split(".")[-1]
          file_id = uuid.uuid4()
          filename = f"{file_id}.{ext}"

          file_dir = f"{CACHE_DIR}/audio/transcriptions"
          os.makedirs(file_dir, exist_ok=True)
          file_path = f"{file_dir}/{filename}"
          log.debug(f"Saving upload to {file_path}")

          with open(file_path, "wb") as f:
              f.write(file.file.read())

          engine = app.state.config.STT_ENGINE
          if engine == "":
              data = _transcribe_with_whisper(file_path)
          elif engine == "openai":
              data = _transcribe_with_openai(file_path, filename, file_dir, file_id)
          else:
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail=f"Unsupported STT engine: {engine!r}",
              )

          # Persist the transcript next to the uploaded audio.
          with open(f"{file_dir}/{file_id}.json", "w") as f:
              json.dump(data, f)
          return data
      except HTTPException:
          # Already carries the intended status/detail (e.g. an upstream STT
          # error); re-raise so the generic handler below does not turn it into a 400.
          raise
      except Exception as e:
          log.exception(e)
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=ERROR_MESSAGES.DEFAULT(e),
          )
  def get_available_models() -> list[dict]:
      """Return the TTS models selectable for the configured engine.

      Returns:
          ``[{"id": ...}]`` entries (a fixed list) for "openai";
          ``[{"name": ..., "id": ...}]`` fetched from the ElevenLabs API for
          "elevenlabs"; ``[]`` for any other engine or when the ElevenLabs
          lookup fails.
      """
      engine = app.state.config.TTS_ENGINE

      if engine == "openai":
          return [{"id": "tts-1"}, {"id": "tts-1-hd"}]

      if engine == "elevenlabs":
          headers = {
              "xi-api-key": app.state.config.TTS_API_KEY,
              "Content-Type": "application/json",
          }
          try:
              response = requests.get(
                  "https://api.elevenlabs.io/v1/models", headers=headers
              )
              response.raise_for_status()
              models = response.json()
              return [
                  {"name": model["name"], "id": model["model_id"]} for model in models
              ]
          except (requests.RequestException, ValueError, KeyError, TypeError) as e:
              # Best-effort listing: network errors, undecodable JSON and
              # unexpected entry shapes all degrade to an empty list.
              log.error(f"Error fetching models: {str(e)}")
              return []

      # Previously fell through to an implicit None, contradicting the annotation.
      return []
  @app.get("/models")
  async def get_models(user=Depends(get_verified_user)):
      """List the TTS models available for the configured engine (verified users)."""
      models = get_available_models()
      return {"models": models}
  def get_available_voices() -> dict:
      """Return the voices for the configured TTS engine as ``{voice_id: voice_name}``.

      OpenAI voices are a fixed list whose ids double as display names.
      ElevenLabs voices come from ``get_elevenlabs_voices()``. Any other engine,
      or a failed ElevenLabs lookup, yields an empty dict.
      """
      engine = app.state.config.TTS_ENGINE

      if engine == "openai":
          openai_voices = ("alloy", "echo", "fable", "onyx", "nova", "shimmer")
          return {voice: voice for voice in openai_voices}

      if engine == "elevenlabs":
          try:
              return get_elevenlabs_voices()
          except Exception:
              # Best effort: get_elevenlabs_voices logs its own fetch errors and
              # raises instead of caching the failure; report no voices here.
              return {}

      return {}
  @lru_cache(maxsize=8)
  def _fetch_elevenlabs_voices(api_key: str) -> dict:
      """Fetch ``{voice_id: voice_name}`` from ElevenLabs using ``api_key`` (cached per key).

      Raises:
          RuntimeError: when the HTTP request fails. lru_cache stores nothing
              for a call that raises, so a failed fetch is retried next time.
      """
      headers = {
          "xi-api-key": api_key,
          "Content-Type": "application/json",
      }
      try:
          # TODO: Add retries
          response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
          response.raise_for_status()
          voices_data = response.json()
      except requests.RequestException as e:
          log.error(f"Error fetching voices: {str(e)}")
          raise RuntimeError(f"Error fetching voices: {str(e)}") from e
      return {voice["voice_id"]: voice["name"] for voice in voices_data.get("voices", [])}


  def get_elevenlabs_voices() -> dict:
      """
      Return ``{voice_id: voice_name}`` for the currently configured ElevenLabs key.

      Note, set the following in your .env file to use Elevenlabs:
      AUDIO_TTS_ENGINE=elevenlabs
      AUDIO_TTS_API_KEY=sk_... # Your Elevenlabs API key
      AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL # From https://api.elevenlabs.io/v1/voices
      AUDIO_TTS_MODEL=eleven_multilingual_v2

      Raises:
          RuntimeError: when the voice list cannot be fetched.
      """
      # The cache is keyed on the API key, so a key changed via /config/update
      # is picked up instead of serving the list fetched with the old key.
      # Return a copy so callers cannot mutate the cached dict.
      return dict(_fetch_elevenlabs_voices(app.state.config.TTS_API_KEY))


  # Keep the `cache_clear` hook that the previously @lru_cache-decorated
  # function exposed.
  get_elevenlabs_voices.cache_clear = _fetch_elevenlabs_voices.cache_clear
  @app.get("/voices")
  async def get_voices(user=Depends(get_verified_user)):
      """List the configured engine's TTS voices as ``{"id", "name"}`` objects (verified users)."""
      voices = []
      for voice_id, voice_name in get_available_voices().items():
          voices.append({"id": voice_id, "name": voice_name})
      return {"voices": voices}