main.py

import os
import logging
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import uuid
import requests
import hashlib
from pathlib import Path
import json

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.misc import calculate_sha256

from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
    UPLOAD_DIR,
    WHISPER_MODEL,
    WHISPER_MODEL_DIR,
    WHISPER_MODEL_AUTO_UPDATE,
    DEVICE_TYPE,
    AUDIO_STT_OPENAI_API_BASE_URL,
    AUDIO_STT_OPENAI_API_KEY,
    AUDIO_TTS_OPENAI_API_BASE_URL,
    AUDIO_TTS_OPENAI_API_KEY,
    AUDIO_TTS_API_KEY,
    AUDIO_STT_ENGINE,
    AUDIO_STT_MODEL,
    AUDIO_TTS_ENGINE,
    AUDIO_TTS_MODEL,
    AUDIO_TTS_VOICE,
    AppConfig,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL

app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY

# setting device type for the local Whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")

SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm


from pydub import AudioSegment
from pydub.utils import mediainfo


def is_mp4_audio(file_path):
    """Check if the given file is an MP4 audio file."""
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return False

    info = mediainfo(file_path)
    if (
        info.get("codec_name") == "aac"
        and info.get("codec_type") == "audio"
        and info.get("codec_tag_string") == "mp4a"
    ):
        return True
    return False


def convert_mp4_to_wav(file_path, output_path):
    """Convert MP4 audio file to WAV format."""
    audio = AudioSegment.from_file(file_path, format="mp4")
    audio.export(output_path, format="wav")
    print(f"Converted {file_path} to {output_path}")
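
# NOTE: pydub's mediainfo() and AudioSegment.from_file() shell out to
# ffprobe/ffmpeg, so both helpers above assume an ffmpeg installation is
# available on PATH.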


@app.get("/config")
async def get_audio_config(user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": app.state.config.TTS_API_KEY,
            "ENGINE": app.state.config.TTS_ENGINE,
            "MODEL": app.state.config.TTS_MODEL,
            "VOICE": app.state.config.TTS_VOICE,
        },
        "stt": {
            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": app.state.config.STT_ENGINE,
            "MODEL": app.state.config.STT_MODEL,
        },
    }


@app.post("/config/update")
async def update_audio_config(
    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    app.state.config.TTS_MODEL = form_data.tts.MODEL
    app.state.config.TTS_VOICE = form_data.tts.VOICE

    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    app.state.config.STT_ENGINE = form_data.stt.ENGINE
    app.state.config.STT_MODEL = form_data.stt.MODEL

    return {
        "tts": {
            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": app.state.config.TTS_API_KEY,
            "ENGINE": app.state.config.TTS_ENGINE,
            "MODEL": app.state.config.TTS_MODEL,
            "VOICE": app.state.config.TTS_VOICE,
        },
        "stt": {
            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": app.state.config.STT_ENGINE,
            "MODEL": app.state.config.STT_MODEL,
        },
    }
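
# Example request body for POST /config/update, mirroring AudioConfigUpdateForm
# (values below are illustrative placeholders, not defaults shipped with the app):
#
#   {
#       "tts": {
#           "OPENAI_API_BASE_URL": "https://api.openai.com/v1",
#           "OPENAI_API_KEY": "sk-...",
#           "API_KEY": "",
#           "ENGINE": "openai",
#           "MODEL": "tts-1",
#           "VOICE": "alloy"
#       },
#       "stt": {
#           "OPENAI_API_BASE_URL": "https://api.openai.com/v1",
#           "OPENAI_API_KEY": "sk-...",
#           "ENGINE": "",
#           "MODEL": ""
#       }
#   }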


@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    if app.state.config.TTS_ENGINE == "openai":
        headers = {}
        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
        headers["Content-Type"] = "application/json"

        try:
            body = body.decode("utf-8")
            body = json.loads(body)
            body["model"] = app.state.config.TTS_MODEL
            body = json.dumps(body).encode("utf-8")
        except Exception:
            pass

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                data=body,
                headers=headers,
                stream=True,
            )

            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r is not None else 500,
                detail=error_detail,
            )

    elif app.state.config.TTS_ENGINE == "elevenlabs":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        voice_id = payload.get("voice", "")
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": app.state.config.TTS_API_KEY,
        }

        data = {
            "text": payload["input"],
            "model_id": app.state.config.TTS_MODEL,
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
        }

        r = None
        try:
            r = requests.post(url, json=data, headers=headers)

            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r is not None else 500,
                detail=error_detail,
            )
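
# Example request body for POST /speech (a minimal sketch; with the "openai"
# engine the JSON is forwarded to {TTS_OPENAI_API_BASE_URL}/audio/speech with
# "model" overridden by the configured TTS_MODEL, while the "elevenlabs" engine
# reads "input" as the text and "voice" as the ElevenLabs voice ID):
#
#   {"input": "Hello from Open WebUI", "voice": "alloy"}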


@app.post("/transcriptions")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    log.info(f"file.content_type: {file.content_type}")

    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]

        id = uuid.uuid4()
        filename = f"{id}.{ext}"

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        print(filename)

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)

        if app.state.config.STT_ENGINE == "":
            from faster_whisper import WhisperModel

            whisper_kwargs = {
                "model_size_or_path": WHISPER_MODEL,
                "device": whisper_device_type,
                "compute_type": "int8",
                "download_root": WHISPER_MODEL_DIR,
                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
            }

            log.debug(f"whisper_kwargs: {whisper_kwargs}")

            try:
                model = WhisperModel(**whisper_kwargs)
            except Exception:
                log.warning(
                    "WhisperModel initialization failed, attempting download with local_files_only=False"
                )
                whisper_kwargs["local_files_only"] = False
                model = WhisperModel(**whisper_kwargs)

            segments, info = model.transcribe(file_path, beam_size=5)
            log.info(
                "Detected language '%s' with probability %f"
                % (info.language, info.language_probability)
            )

            transcript = "".join([segment.text for segment in list(segments)])

            data = {"text": transcript.strip()}

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            print(data)
            return data

        elif app.state.config.STT_ENGINE == "openai":
            if is_mp4_audio(file_path):
                print("is_mp4_audio")
                os.rename(file_path, file_path.replace(".wav", ".mp4"))
                # Convert MP4 audio file to WAV format
                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)

            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}

            files = {"file": (filename, open(file_path, "rb"))}
            data = {"model": app.state.config.STT_MODEL}

            print(files, data)

            r = None
            try:
                r = requests.post(
                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                    headers=headers,
                    files=files,
                    data=data,
                )

                r.raise_for_status()

                data = r.json()

                # save the transcript to a json file
                transcript_file = f"{file_dir}/{id}.json"
                with open(transcript_file, "w") as f:
                    json.dump(data, f)

                print(data)
                return data
            except Exception as e:
                log.exception(e)
                error_detail = "Open WebUI: Server Connection Error"
                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            error_detail = f"External: {res['error']['message']}"
                    except Exception:
                        error_detail = f"External: {e}"

                raise HTTPException(
                    status_code=r.status_code if r is not None else 500,
                    detail=error_detail,
                )

    except Exception as e:
        log.exception(e)

        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
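
# Example client call for POST /transcriptions (a rough sketch; the base URL and
# the mount prefix of this sub-app are assumptions, and the upload must be
# audio/mpeg or audio/wav):
#
#   import requests
#
#   with open("sample.wav", "rb") as audio:
#       r = requests.post(
#           "http://localhost:8080/audio/api/v1/transcriptions",
#           headers={"Authorization": "Bearer <token>"},
#           files={"file": ("sample.wav", audio, "audio/wav")},
#       )
#   print(r.json())  # e.g. {"text": "..."}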


def get_available_models() -> list[dict]:
    if app.state.config.TTS_ENGINE == "openai":
        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif app.state.config.TTS_ENGINE == "elevenlabs":
        headers = {
            "xi-api-key": app.state.config.TTS_API_KEY,
            "Content-Type": "application/json",
        }

        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models", headers=headers
            )
            response.raise_for_status()
            models = response.json()
            return [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching models: {str(e)}")

    return []


@app.get("/models")
async def get_models(user=Depends(get_verified_user)):
    return {"models": get_available_models()}


def get_available_voices() -> list[dict]:
    if app.state.config.TTS_ENGINE == "openai":
        return [
            {"name": "alloy", "id": "alloy"},
            {"name": "echo", "id": "echo"},
            {"name": "fable", "id": "fable"},
            {"name": "onyx", "id": "onyx"},
            {"name": "nova", "id": "nova"},
            {"name": "shimmer", "id": "shimmer"},
        ]
    elif app.state.config.TTS_ENGINE == "elevenlabs":
        headers = {
            "xi-api-key": app.state.config.TTS_API_KEY,
            "Content-Type": "application/json",
        }

        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/voices", headers=headers
            )
            response.raise_for_status()
            voices_data = response.json()

            voices = []
            for voice in voices_data.get("voices", []):
                voices.append({"name": voice["name"], "id": voice["voice_id"]})
            return voices
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return []


@app.get("/voices")
async def get_voices(user=Depends(get_verified_user)):
    return {"voices": get_available_voices()}
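
# Example responses (shapes follow the helpers above; the entries shown are the
# hard-coded OpenAI options, while the ElevenLabs lists depend on the remote API):
#
#   GET /models -> {"models": [{"id": "tts-1"}, {"id": "tts-1-hd"}]}
#   GET /voices -> {"voices": [{"name": "alloy", "id": "alloy"}, ...]}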