main.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. import os
  2. import logging
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from pydantic import BaseModel
  16. from typing import List
  17. import uuid
  18. import requests
  19. import hashlib
  20. from pathlib import Path
  21. import json
  22. from constants import ERROR_MESSAGES
  23. from utils.utils import (
  24. decode_token,
  25. get_current_user,
  26. get_verified_user,
  27. get_admin_user,
  28. )
  29. from utils.misc import calculate_sha256
  30. from config import (
  31. SRC_LOG_LEVELS,
  32. CACHE_DIR,
  33. UPLOAD_DIR,
  34. WHISPER_MODEL,
  35. WHISPER_MODEL_DIR,
  36. WHISPER_MODEL_AUTO_UPDATE,
  37. DEVICE_TYPE,
  38. AUDIO_STT_OPENAI_API_BASE_URL,
  39. AUDIO_STT_OPENAI_API_KEY,
  40. AUDIO_TTS_OPENAI_API_BASE_URL,
  41. AUDIO_TTS_OPENAI_API_KEY,
  42. AUDIO_TTS_API_KEY,
  43. AUDIO_STT_ENGINE,
  44. AUDIO_STT_MODEL,
  45. AUDIO_TTS_ENGINE,
  46. AUDIO_TTS_MODEL,
  47. AUDIO_TTS_VOICE,
  48. AppConfig,
  49. )
  50. log = logging.getLogger(__name__)
  51. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  52. app = FastAPI()
  53. app.add_middleware(
  54. CORSMiddleware,
  55. allow_origins=["*"],
  56. allow_credentials=True,
  57. allow_methods=["*"],
  58. allow_headers=["*"],
  59. )
  60. app.state.config = AppConfig()
  61. app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
  62. app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
  63. app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
  64. app.state.config.STT_MODEL = AUDIO_STT_MODEL
  65. app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
  66. app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
  67. app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
  68. app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
  69. app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
  70. app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
  71. # setting device type for whisper model
  72. whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
  73. log.info(f"whisper_device_type: {whisper_device_type}")
  74. SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
  75. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  76. class TTSConfigForm(BaseModel):
  77. OPENAI_API_BASE_URL: str
  78. OPENAI_API_KEY: str
  79. API_KEY: str
  80. ENGINE: str
  81. MODEL: str
  82. VOICE: str
  83. class STTConfigForm(BaseModel):
  84. OPENAI_API_BASE_URL: str
  85. OPENAI_API_KEY: str
  86. ENGINE: str
  87. MODEL: str
  88. class AudioConfigUpdateForm(BaseModel):
  89. tts: TTSConfigForm
  90. stt: STTConfigForm
  91. from pydub import AudioSegment
  92. from pydub.utils import mediainfo
  93. def is_mp4_audio(file_path):
  94. """Check if the given file is an MP4 audio file."""
  95. if not os.path.isfile(file_path):
  96. print(f"File not found: {file_path}")
  97. return False
  98. info = mediainfo(file_path)
  99. if (
  100. info.get("codec_name") == "aac"
  101. and info.get("codec_type") == "audio"
  102. and info.get("codec_tag_string") == "mp4a"
  103. ):
  104. return True
  105. return False
  106. def convert_mp4_to_wav(file_path, output_path):
  107. """Convert MP4 audio file to WAV format."""
  108. audio = AudioSegment.from_file(file_path, format="mp4")
  109. audio.export(output_path, format="wav")
  110. print(f"Converted {file_path} to {output_path}")
  111. @app.get("/config")
  112. async def get_audio_config(user=Depends(get_admin_user)):
  113. return {
  114. "tts": {
  115. "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
  116. "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
  117. "API_KEY": app.state.config.TTS_API_KEY,
  118. "ENGINE": app.state.config.TTS_ENGINE,
  119. "MODEL": app.state.config.TTS_MODEL,
  120. "VOICE": app.state.config.TTS_VOICE,
  121. },
  122. "stt": {
  123. "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
  124. "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
  125. "ENGINE": app.state.config.STT_ENGINE,
  126. "MODEL": app.state.config.STT_MODEL,
  127. },
  128. }
  129. @app.post("/config/update")
  130. async def update_audio_config(
  131. form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  132. ):
  133. app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
  134. app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
  135. app.state.config.TTS_API_KEY = form_data.tts.API_KEY
  136. app.state.config.TTS_ENGINE = form_data.tts.ENGINE
  137. app.state.config.TTS_MODEL = form_data.tts.MODEL
  138. app.state.config.TTS_VOICE = form_data.tts.VOICE
  139. app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
  140. app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
  141. app.state.config.STT_ENGINE = form_data.stt.ENGINE
  142. app.state.config.STT_MODEL = form_data.stt.MODEL
  143. return {
  144. "tts": {
  145. "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
  146. "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
  147. "API_KEY": app.state.config.TTS_API_KEY,
  148. "ENGINE": app.state.config.TTS_ENGINE,
  149. "MODEL": app.state.config.TTS_MODEL,
  150. "VOICE": app.state.config.TTS_VOICE,
  151. },
  152. "stt": {
  153. "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
  154. "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
  155. "ENGINE": app.state.config.STT_ENGINE,
  156. "MODEL": app.state.config.STT_MODEL,
  157. },
  158. }
  159. @app.post("/speech")
  160. async def speech(request: Request, user=Depends(get_verified_user)):
  161. body = await request.body()
  162. name = hashlib.sha256(body).hexdigest()
  163. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  164. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  165. # Check if the file already exists in the cache
  166. if file_path.is_file():
  167. return FileResponse(file_path)
  168. if app.state.config.TTS_ENGINE == "openai":
  169. headers = {}
  170. headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
  171. headers["Content-Type"] = "application/json"
  172. try:
  173. body = body.decode("utf-8")
  174. body = json.loads(body)
  175. body["model"] = app.state.config.TTS_MODEL
  176. body = json.dumps(body).encode("utf-8")
  177. except Exception as e:
  178. pass
  179. r = None
  180. try:
  181. r = requests.post(
  182. url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
  183. data=body,
  184. headers=headers,
  185. stream=True,
  186. )
  187. r.raise_for_status()
  188. # Save the streaming content to a file
  189. with open(file_path, "wb") as f:
  190. for chunk in r.iter_content(chunk_size=8192):
  191. f.write(chunk)
  192. with open(file_body_path, "w") as f:
  193. json.dump(json.loads(body.decode("utf-8")), f)
  194. # Return the saved file
  195. return FileResponse(file_path)
  196. except Exception as e:
  197. log.exception(e)
  198. error_detail = "Open WebUI: Server Connection Error"
  199. if r is not None:
  200. try:
  201. res = r.json()
  202. if "error" in res:
  203. error_detail = f"External: {res['error']['message']}"
  204. except:
  205. error_detail = f"External: {e}"
  206. raise HTTPException(
  207. status_code=r.status_code if r != None else 500,
  208. detail=error_detail,
  209. )
  210. elif app.state.config.TTS_ENGINE == "elevenlabs":
  211. payload = None
  212. try:
  213. payload = json.loads(body.decode("utf-8"))
  214. except Exception as e:
  215. log.exception(e)
  216. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  217. voice_id = payload.get("voice", "")
  218. url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
  219. headers = {
  220. "Accept": "audio/mpeg",
  221. "Content-Type": "application/json",
  222. "xi-api-key": app.state.config.TTS_API_KEY,
  223. }
  224. data = {
  225. "text": payload["input"],
  226. "model_id": app.state.config.TTS_MODEL,
  227. "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
  228. }
  229. try:
  230. r = requests.post(url, json=data, headers=headers)
  231. r.raise_for_status()
  232. # Save the streaming content to a file
  233. with open(file_path, "wb") as f:
  234. for chunk in r.iter_content(chunk_size=8192):
  235. f.write(chunk)
  236. with open(file_body_path, "w") as f:
  237. json.dump(json.loads(body.decode("utf-8")), f)
  238. # Return the saved file
  239. return FileResponse(file_path)
  240. except Exception as e:
  241. log.exception(e)
  242. error_detail = "Open WebUI: Server Connection Error"
  243. if r is not None:
  244. try:
  245. res = r.json()
  246. if "error" in res:
  247. error_detail = f"External: {res['error']['message']}"
  248. except:
  249. error_detail = f"External: {e}"
  250. raise HTTPException(
  251. status_code=r.status_code if r != None else 500,
  252. detail=error_detail,
  253. )
  254. @app.post("/transcriptions")
  255. def transcribe(
  256. file: UploadFile = File(...),
  257. user=Depends(get_current_user),
  258. ):
  259. log.info(f"file.content_type: {file.content_type}")
  260. if file.content_type not in ["audio/mpeg", "audio/wav"]:
  261. raise HTTPException(
  262. status_code=status.HTTP_400_BAD_REQUEST,
  263. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  264. )
  265. try:
  266. ext = file.filename.split(".")[-1]
  267. id = uuid.uuid4()
  268. filename = f"{id}.{ext}"
  269. file_dir = f"{CACHE_DIR}/audio/transcriptions"
  270. os.makedirs(file_dir, exist_ok=True)
  271. file_path = f"{file_dir}/{filename}"
  272. print(filename)
  273. contents = file.file.read()
  274. with open(file_path, "wb") as f:
  275. f.write(contents)
  276. f.close()
  277. if app.state.config.STT_ENGINE == "":
  278. from faster_whisper import WhisperModel
  279. whisper_kwargs = {
  280. "model_size_or_path": WHISPER_MODEL,
  281. "device": whisper_device_type,
  282. "compute_type": "int8",
  283. "download_root": WHISPER_MODEL_DIR,
  284. "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
  285. }
  286. log.debug(f"whisper_kwargs: {whisper_kwargs}")
  287. try:
  288. model = WhisperModel(**whisper_kwargs)
  289. except:
  290. log.warning(
  291. "WhisperModel initialization failed, attempting download with local_files_only=False"
  292. )
  293. whisper_kwargs["local_files_only"] = False
  294. model = WhisperModel(**whisper_kwargs)
  295. segments, info = model.transcribe(file_path, beam_size=5)
  296. log.info(
  297. "Detected language '%s' with probability %f"
  298. % (info.language, info.language_probability)
  299. )
  300. transcript = "".join([segment.text for segment in list(segments)])
  301. data = {"text": transcript.strip()}
  302. # save the transcript to a json file
  303. transcript_file = f"{file_dir}/{id}.json"
  304. with open(transcript_file, "w") as f:
  305. json.dump(data, f)
  306. print(data)
  307. return data
  308. elif app.state.config.STT_ENGINE == "openai":
  309. if is_mp4_audio(file_path):
  310. print("is_mp4_audio")
  311. os.rename(file_path, file_path.replace(".wav", ".mp4"))
  312. # Convert MP4 audio file to WAV format
  313. convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
  314. headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
  315. files = {"file": (filename, open(file_path, "rb"))}
  316. data = {"model": app.state.config.STT_MODEL}
  317. print(files, data)
  318. r = None
  319. try:
  320. r = requests.post(
  321. url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
  322. headers=headers,
  323. files=files,
  324. data=data,
  325. )
  326. r.raise_for_status()
  327. data = r.json()
  328. # save the transcript to a json file
  329. transcript_file = f"{file_dir}/{id}.json"
  330. with open(transcript_file, "w") as f:
  331. json.dump(data, f)
  332. print(data)
  333. return data
  334. except Exception as e:
  335. log.exception(e)
  336. error_detail = "Open WebUI: Server Connection Error"
  337. if r is not None:
  338. try:
  339. res = r.json()
  340. if "error" in res:
  341. error_detail = f"External: {res['error']['message']}"
  342. except:
  343. error_detail = f"External: {e}"
  344. raise HTTPException(
  345. status_code=r.status_code if r != None else 500,
  346. detail=error_detail,
  347. )
  348. except Exception as e:
  349. log.exception(e)
  350. raise HTTPException(
  351. status_code=status.HTTP_400_BAD_REQUEST,
  352. detail=ERROR_MESSAGES.DEFAULT(e),
  353. )
  354. def get_available_models() -> List[dict]:
  355. if app.state.config.TTS_ENGINE == "openai":
  356. return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
  357. elif app.state.config.TTS_ENGINE == "elevenlabs":
  358. headers = {
  359. "xi-api-key": app.state.config.TTS_API_KEY,
  360. "Content-Type": "application/json",
  361. }
  362. try:
  363. response = requests.get(
  364. "https://api.elevenlabs.io/v1/models", headers=headers
  365. )
  366. response.raise_for_status()
  367. models = response.json()
  368. return [
  369. {"name": model["name"], "id": model["model_id"]} for model in models
  370. ]
  371. except requests.RequestException as e:
  372. log.error(f"Error fetching voices: {str(e)}")
  373. return []
  374. @app.get("/models")
  375. async def get_models(user=Depends(get_verified_user)):
  376. return {"models": get_available_models()}
  377. def get_available_voices() -> List[dict]:
  378. if app.state.config.TTS_ENGINE == "openai":
  379. return [
  380. {"name": "alloy", "id": "alloy"},
  381. {"name": "echo", "id": "echo"},
  382. {"name": "fable", "id": "fable"},
  383. {"name": "onyx", "id": "onyx"},
  384. {"name": "nova", "id": "nova"},
  385. {"name": "shimmer", "id": "shimmer"},
  386. ]
  387. elif app.state.config.TTS_ENGINE == "elevenlabs":
  388. headers = {
  389. "xi-api-key": app.state.config.TTS_API_KEY,
  390. "Content-Type": "application/json",
  391. }
  392. try:
  393. response = requests.get(
  394. "https://api.elevenlabs.io/v1/voices", headers=headers
  395. )
  396. response.raise_for_status()
  397. voices_data = response.json()
  398. voices = []
  399. for voice in voices_data.get("voices", []):
  400. voices.append({"name": voice["name"], "id": voice["voice_id"]})
  401. return voices
  402. except requests.RequestException as e:
  403. log.error(f"Error fetching voices: {str(e)}")
  404. return []
  405. @app.get("/voices")
  406. async def get_voices(user=Depends(get_verified_user)):
  407. return {"voices": get_available_voices()}