main.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. import os
  2. import logging
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from pydantic import BaseModel
  16. import uuid
  17. import requests
  18. import hashlib
  19. from pathlib import Path
  20. import json
  21. from constants import ERROR_MESSAGES
  22. from utils.utils import (
  23. decode_token,
  24. get_current_user,
  25. get_verified_user,
  26. get_admin_user,
  27. )
  28. from utils.misc import calculate_sha256
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. UPLOAD_DIR,
  33. WHISPER_MODEL,
  34. WHISPER_MODEL_DIR,
  35. WHISPER_MODEL_AUTO_UPDATE,
  36. DEVICE_TYPE,
  37. AUDIO_STT_OPENAI_API_BASE_URL,
  38. AUDIO_STT_OPENAI_API_KEY,
  39. AUDIO_TTS_OPENAI_API_BASE_URL,
  40. AUDIO_TTS_OPENAI_API_KEY,
  41. AUDIO_TTS_API_KEY,
  42. AUDIO_STT_ENGINE,
  43. AUDIO_STT_MODEL,
  44. AUDIO_TTS_ENGINE,
  45. AUDIO_TTS_MODEL,
  46. AUDIO_TTS_VOICE,
  47. AppConfig,
  48. )
  49. log = logging.getLogger(__name__)
  50. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  51. app = FastAPI()
  52. app.add_middleware(
  53. CORSMiddleware,
  54. allow_origins=["*"],
  55. allow_credentials=True,
  56. allow_methods=["*"],
  57. allow_headers=["*"],
  58. )
  59. app.state.config = AppConfig()
  60. app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
  61. app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
  62. app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
  63. app.state.config.STT_MODEL = AUDIO_STT_MODEL
  64. app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
  65. app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
  66. app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
  67. app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
  68. app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
  69. app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
  70. # setting device type for whisper model
  71. whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
  72. log.info(f"whisper_device_type: {whisper_device_type}")
  73. SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
  74. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Text-to-speech settings accepted by POST /config/update; each field is
# copied onto the matching TTS_* attribute of app.state.config.
class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str  # base URL for the "openai" engine's /audio/speech call
    OPENAI_API_KEY: str  # bearer token for the "openai" engine
    API_KEY: str  # sent as the xi-api-key header to ElevenLabs
    ENGINE: str  # /speech handles "openai" and "elevenlabs"
    MODEL: str  # model name forwarded to the selected engine
    VOICE: str  # stored as TTS_VOICE and echoed back by /config
# Speech-to-text settings accepted by POST /config/update; each field is
# copied onto the matching STT_* attribute of app.state.config.
class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str  # base URL for the "openai" engine's /audio/transcriptions call
    OPENAI_API_KEY: str  # bearer token for the "openai" engine
    ENGINE: str  # /transcriptions: "" = local faster-whisper, "openai" = remote endpoint
    MODEL: str  # model name sent to the "openai" engine
# Request body for POST /config/update. Both halves are required fields and
# every value in them is applied to app.state.config.
class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm  # text-to-speech settings
    stt: STTConfigForm  # speech-to-text settings
  90. from pydub import AudioSegment
  91. from pydub.utils import mediainfo
  92. def is_mp4_audio(file_path):
  93. """Check if the given file is an MP4 audio file."""
  94. if not os.path.isfile(file_path):
  95. print(f"File not found: {file_path}")
  96. return False
  97. info = mediainfo(file_path)
  98. if (
  99. info.get("codec_name") == "aac"
  100. and info.get("codec_type") == "audio"
  101. and info.get("codec_tag_string") == "mp4a"
  102. ):
  103. return True
  104. return False
  105. def convert_mp4_to_wav(file_path, output_path):
  106. """Convert MP4 audio file to WAV format."""
  107. audio = AudioSegment.from_file(file_path, format="mp4")
  108. audio.export(output_path, format="wav")
  109. print(f"Converted {file_path} to {output_path}")
  110. async def fetch_available_voices():
  111. if app.state.config.TTS_ENGINE != "elevenlabs":
  112. return {}
  113. base_url = "https://api.elevenlabs.io/v1"
  114. headers = {
  115. "xi-api-key": app.state.config.TTS_API_KEY,
  116. "Content-Type": "application/json",
  117. }
  118. voices_url = f"{base_url}/voices"
  119. try:
  120. response = requests.get(voices_url, headers=headers)
  121. response.raise_for_status()
  122. voices_data = response.json()
  123. voice_options = {}
  124. for voice in voices_data.get("voices", []):
  125. voice_name = voice["name"]
  126. voice_id = voice["voice_id"]
  127. voice_options[voice_name] = voice_id
  128. return voice_options
  129. except requests.RequestException as e:
  130. log.error(f"Error fetching voices: {str(e)}")
  131. return {}
  132. @app.get("/config")
  133. async def get_audio_config(user=Depends(get_admin_user)):
  134. return {
  135. "tts": {
  136. "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
  137. "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
  138. "API_KEY": app.state.config.TTS_API_KEY,
  139. "ENGINE": app.state.config.TTS_ENGINE,
  140. "MODEL": app.state.config.TTS_MODEL,
  141. "VOICE": app.state.config.TTS_VOICE,
  142. },
  143. "stt": {
  144. "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
  145. "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
  146. "ENGINE": app.state.config.STT_ENGINE,
  147. "MODEL": app.state.config.STT_MODEL,
  148. },
  149. }
  150. @app.post("/config/update")
  151. async def update_audio_config(
  152. form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  153. ):
  154. app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
  155. app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
  156. app.state.config.TTS_API_KEY = form_data.tts.API_KEY
  157. app.state.config.TTS_ENGINE = form_data.tts.ENGINE
  158. app.state.config.TTS_MODEL = form_data.tts.MODEL
  159. app.state.config.TTS_VOICE = form_data.tts.VOICE
  160. app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
  161. app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
  162. app.state.config.STT_ENGINE = form_data.stt.ENGINE
  163. app.state.config.STT_MODEL = form_data.stt.MODEL
  164. return {
  165. "tts": {
  166. "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
  167. "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
  168. "API_KEY": app.state.config.TTS_API_KEY,
  169. "ENGINE": app.state.config.TTS_ENGINE,
  170. "MODEL": app.state.config.TTS_MODEL,
  171. "VOICE": app.state.config.TTS_VOICE,
  172. },
  173. "stt": {
  174. "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
  175. "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
  176. "ENGINE": app.state.config.STT_ENGINE,
  177. "MODEL": app.state.config.STT_MODEL,
  178. },
  179. }
  180. @app.post("/speech")
  181. async def speech(request: Request, user=Depends(get_verified_user)):
  182. body = await request.body()
  183. name = hashlib.sha256(body).hexdigest()
  184. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  185. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  186. # Check if the file already exists in the cache
  187. if file_path.is_file():
  188. return FileResponse(file_path)
  189. if app.state.config.TTS_ENGINE == "openai":
  190. headers = {}
  191. headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
  192. headers["Content-Type"] = "application/json"
  193. try:
  194. body = body.decode("utf-8")
  195. body = json.loads(body)
  196. body["model"] = app.state.config.TTS_MODEL
  197. body = json.dumps(body).encode("utf-8")
  198. except Exception as e:
  199. pass
  200. r = None
  201. try:
  202. r = requests.post(
  203. url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
  204. data=body,
  205. headers=headers,
  206. stream=True,
  207. )
  208. r.raise_for_status()
  209. # Save the streaming content to a file
  210. with open(file_path, "wb") as f:
  211. for chunk in r.iter_content(chunk_size=8192):
  212. f.write(chunk)
  213. with open(file_body_path, "w") as f:
  214. json.dump(json.loads(body.decode("utf-8")), f)
  215. # Return the saved file
  216. return FileResponse(file_path)
  217. except Exception as e:
  218. log.exception(e)
  219. error_detail = "Open WebUI: Server Connection Error"
  220. if r is not None:
  221. try:
  222. res = r.json()
  223. if "error" in res:
  224. error_detail = f"External: {res['error']['message']}"
  225. except:
  226. error_detail = f"External: {e}"
  227. raise HTTPException(
  228. status_code=r.status_code if r != None else 500,
  229. detail=error_detail,
  230. )
  231. elif app.state.config.TTS_ENGINE == "elevenlabs":
  232. payload = None
  233. try:
  234. payload = json.loads(body.decode("utf-8"))
  235. except Exception as e:
  236. log.exception(e)
  237. raise HTTPException(status_code=400, detail="Invalid JSON payload")
  238. voice_options = await fetch_available_voices()
  239. voice_id = voice_options.get(payload['voice'])
  240. if not voice_id:
  241. raise HTTPException(status_code=400, detail="Invalid voice name")
  242. url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
  243. headers = {
  244. "Accept": "audio/mpeg",
  245. "Content-Type": "application/json",
  246. "xi-api-key": app.state.config.TTS_API_KEY,
  247. }
  248. data = {
  249. "text": payload["input"],
  250. "model_id": app.state.config.TTS_MODEL,
  251. "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
  252. }
  253. try:
  254. r = requests.post(url, json=data, headers=headers)
  255. r.raise_for_status()
  256. # Save the streaming content to a file
  257. with open(file_path, "wb") as f:
  258. for chunk in r.iter_content(chunk_size=8192):
  259. f.write(chunk)
  260. with open(file_body_path, "w") as f:
  261. json.dump(json.loads(body.decode("utf-8")), f)
  262. # Return the saved file
  263. return FileResponse(file_path)
  264. except Exception as e:
  265. log.exception(e)
  266. error_detail = "Open WebUI: Server Connection Error"
  267. if r is not None:
  268. try:
  269. res = r.json()
  270. if "error" in res:
  271. error_detail = f"External: {res['error']['message']}"
  272. except:
  273. error_detail = f"External: {e}"
  274. raise HTTPException(
  275. status_code=r.status_code if r != None else 500,
  276. detail=error_detail,
  277. )
  278. @app.post("/transcriptions")
  279. def transcribe(
  280. file: UploadFile = File(...),
  281. user=Depends(get_current_user),
  282. ):
  283. log.info(f"file.content_type: {file.content_type}")
  284. if file.content_type not in ["audio/mpeg", "audio/wav"]:
  285. raise HTTPException(
  286. status_code=status.HTTP_400_BAD_REQUEST,
  287. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  288. )
  289. try:
  290. ext = file.filename.split(".")[-1]
  291. id = uuid.uuid4()
  292. filename = f"{id}.{ext}"
  293. file_dir = f"{CACHE_DIR}/audio/transcriptions"
  294. os.makedirs(file_dir, exist_ok=True)
  295. file_path = f"{file_dir}/{filename}"
  296. print(filename)
  297. contents = file.file.read()
  298. with open(file_path, "wb") as f:
  299. f.write(contents)
  300. f.close()
  301. if app.state.config.STT_ENGINE == "":
  302. from faster_whisper import WhisperModel
  303. whisper_kwargs = {
  304. "model_size_or_path": WHISPER_MODEL,
  305. "device": whisper_device_type,
  306. "compute_type": "int8",
  307. "download_root": WHISPER_MODEL_DIR,
  308. "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
  309. }
  310. log.debug(f"whisper_kwargs: {whisper_kwargs}")
  311. try:
  312. model = WhisperModel(**whisper_kwargs)
  313. except:
  314. log.warning(
  315. "WhisperModel initialization failed, attempting download with local_files_only=False"
  316. )
  317. whisper_kwargs["local_files_only"] = False
  318. model = WhisperModel(**whisper_kwargs)
  319. segments, info = model.transcribe(file_path, beam_size=5)
  320. log.info(
  321. "Detected language '%s' with probability %f"
  322. % (info.language, info.language_probability)
  323. )
  324. transcript = "".join([segment.text for segment in list(segments)])
  325. data = {"text": transcript.strip()}
  326. # save the transcript to a json file
  327. transcript_file = f"{file_dir}/{id}.json"
  328. with open(transcript_file, "w") as f:
  329. json.dump(data, f)
  330. print(data)
  331. return data
  332. elif app.state.config.STT_ENGINE == "openai":
  333. if is_mp4_audio(file_path):
  334. print("is_mp4_audio")
  335. os.rename(file_path, file_path.replace(".wav", ".mp4"))
  336. # Convert MP4 audio file to WAV format
  337. convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
  338. headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
  339. files = {"file": (filename, open(file_path, "rb"))}
  340. data = {"model": app.state.config.STT_MODEL}
  341. print(files, data)
  342. r = None
  343. try:
  344. r = requests.post(
  345. url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
  346. headers=headers,
  347. files=files,
  348. data=data,
  349. )
  350. r.raise_for_status()
  351. data = r.json()
  352. # save the transcript to a json file
  353. transcript_file = f"{file_dir}/{id}.json"
  354. with open(transcript_file, "w") as f:
  355. json.dump(data, f)
  356. print(data)
  357. return data
  358. except Exception as e:
  359. log.exception(e)
  360. error_detail = "Open WebUI: Server Connection Error"
  361. if r is not None:
  362. try:
  363. res = r.json()
  364. if "error" in res:
  365. error_detail = f"External: {res['error']['message']}"
  366. except:
  367. error_detail = f"External: {e}"
  368. raise HTTPException(
  369. status_code=r.status_code if r != None else 500,
  370. detail=error_detail,
  371. )
  372. except Exception as e:
  373. log.exception(e)
  374. raise HTTPException(
  375. status_code=status.HTTP_400_BAD_REQUEST,
  376. detail=ERROR_MESSAGES.DEFAULT(e),
  377. )
  378. @app.get("/voices")
  379. async def get_voices(user=Depends(get_verified_user)):
  380. voices = await fetch_available_voices()
  381. return {"voices": list(voices.keys())}