main.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. import os
  2. import logging
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
  14. from fastapi.middleware.cors import CORSMiddleware
  15. from pydantic import BaseModel
  16. import uuid
  17. import requests
  18. import hashlib
  19. from pathlib import Path
  20. import json
  21. from constants import ERROR_MESSAGES
  22. from utils.utils import (
  23. decode_token,
  24. get_current_user,
  25. get_verified_user,
  26. get_admin_user,
  27. )
  28. from utils.misc import calculate_sha256
  29. from config import (
  30. SRC_LOG_LEVELS,
  31. CACHE_DIR,
  32. UPLOAD_DIR,
  33. WHISPER_MODEL,
  34. WHISPER_MODEL_DIR,
  35. WHISPER_MODEL_AUTO_UPDATE,
  36. DEVICE_TYPE,
  37. AUDIO_STT_OPENAI_API_BASE_URL,
  38. AUDIO_STT_OPENAI_API_KEY,
  39. AUDIO_TTS_OPENAI_API_BASE_URL,
  40. AUDIO_TTS_OPENAI_API_KEY,
  41. AUDIO_STT_ENGINE,
  42. AUDIO_STT_MODEL,
  43. AUDIO_TTS_ENGINE,
  44. AUDIO_TTS_MODEL,
  45. AUDIO_TTS_VOICE,
  46. AppConfig,
  47. )
  48. log = logging.getLogger(__name__)
  49. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  50. app = FastAPI()
  51. app.add_middleware(
  52. CORSMiddleware,
  53. allow_origins=["*"],
  54. allow_credentials=True,
  55. allow_methods=["*"],
  56. allow_headers=["*"],
  57. )
  58. app.state.config = AppConfig()
  59. app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
  60. app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
  61. app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
  62. app.state.config.STT_MODEL = AUDIO_STT_MODEL
  63. app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
  64. app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
  65. app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
  66. app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
  67. app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
  68. # setting device type for whisper model
  69. whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
  70. log.info(f"whisper_device_type: {whisper_device_type}")
  71. SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
  72. SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
  73. class TTSConfigForm(BaseModel):
  74. OPENAI_API_BASE_URL: str
  75. OPENAI_API_KEY: str
  76. ENGINE: str
  77. MODEL: str
  78. VOICE: str
  79. class STTConfigForm(BaseModel):
  80. OPENAI_API_BASE_URL: str
  81. OPENAI_API_KEY: str
  82. ENGINE: str
  83. MODEL: str
  84. class AudioConfigUpdateForm(BaseModel):
  85. tts: TTSConfigForm
  86. stt: STTConfigForm
  87. from pydub import AudioSegment
  88. from pydub.utils import mediainfo
  89. def is_mp4_audio(file_path):
  90. """Check if the given file is an MP4 audio file."""
  91. if not os.path.isfile(file_path):
  92. print(f"File not found: {file_path}")
  93. return False
  94. info = mediainfo(file_path)
  95. if (
  96. info.get("codec_name") == "aac"
  97. and info.get("codec_type") == "audio"
  98. and info.get("codec_tag_string") == "mp4a"
  99. ):
  100. return True
  101. return False
  102. def convert_mp4_to_wav(file_path, output_path):
  103. """Convert MP4 audio file to WAV format."""
  104. audio = AudioSegment.from_file(file_path, format="mp4")
  105. audio.export(output_path, format="wav")
  106. print(f"Converted {file_path} to {output_path}")
  107. @app.get("/config")
  108. async def get_audio_config(user=Depends(get_admin_user)):
  109. return {
  110. "tts": {
  111. "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
  112. "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
  113. "ENGINE": app.state.config.TTS_ENGINE,
  114. "MODEL": app.state.config.TTS_MODEL,
  115. "VOICE": app.state.config.TTS_VOICE,
  116. },
  117. "stt": {
  118. "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
  119. "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
  120. "ENGINE": app.state.config.STT_ENGINE,
  121. "MODEL": app.state.config.STT_MODEL,
  122. },
  123. }
  124. @app.post("/config/update")
  125. async def update_audio_config(
  126. form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
  127. ):
  128. app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
  129. app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
  130. app.state.config.TTS_ENGINE = form_data.tts.ENGINE
  131. app.state.config.TTS_MODEL = form_data.tts.MODEL
  132. app.state.config.TTS_VOICE = form_data.tts.VOICE
  133. app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
  134. app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
  135. app.state.config.STT_ENGINE = form_data.stt.ENGINE
  136. app.state.config.STT_MODEL = form_data.stt.MODEL
  137. return {
  138. "tts": {
  139. "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
  140. "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
  141. "ENGINE": app.state.config.TTS_ENGINE,
  142. "MODEL": app.state.config.TTS_MODEL,
  143. "VOICE": app.state.config.TTS_VOICE,
  144. },
  145. "stt": {
  146. "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
  147. "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
  148. "ENGINE": app.state.config.STT_ENGINE,
  149. "MODEL": app.state.config.STT_MODEL,
  150. },
  151. }
  152. @app.post("/speech")
  153. async def speech(request: Request, user=Depends(get_verified_user)):
  154. body = await request.body()
  155. name = hashlib.sha256(body).hexdigest()
  156. file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
  157. file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
  158. # Check if the file already exists in the cache
  159. if file_path.is_file():
  160. return FileResponse(file_path)
  161. headers = {}
  162. headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
  163. headers["Content-Type"] = "application/json"
  164. try:
  165. body = body.decode("utf-8")
  166. body = json.loads(body)
  167. body["model"] = app.state.config.TTS_MODEL
  168. body = json.dumps(body).encode("utf-8")
  169. except Exception as e:
  170. pass
  171. r = None
  172. try:
  173. r = requests.post(
  174. url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
  175. data=body,
  176. headers=headers,
  177. stream=True,
  178. )
  179. r.raise_for_status()
  180. # Save the streaming content to a file
  181. with open(file_path, "wb") as f:
  182. for chunk in r.iter_content(chunk_size=8192):
  183. f.write(chunk)
  184. with open(file_body_path, "w") as f:
  185. json.dump(json.loads(body.decode("utf-8")), f)
  186. # Return the saved file
  187. return FileResponse(file_path)
  188. except Exception as e:
  189. log.exception(e)
  190. error_detail = "Open WebUI: Server Connection Error"
  191. if r is not None:
  192. try:
  193. res = r.json()
  194. if "error" in res:
  195. error_detail = f"External: {res['error']['message']}"
  196. except:
  197. error_detail = f"External: {e}"
  198. raise HTTPException(
  199. status_code=r.status_code if r != None else 500,
  200. detail=error_detail,
  201. )
  202. @app.post("/transcriptions")
  203. def transcribe(
  204. file: UploadFile = File(...),
  205. user=Depends(get_current_user),
  206. ):
  207. log.info(f"file.content_type: {file.content_type}")
  208. if file.content_type not in ["audio/mpeg", "audio/wav"]:
  209. raise HTTPException(
  210. status_code=status.HTTP_400_BAD_REQUEST,
  211. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  212. )
  213. try:
  214. ext = file.filename.split(".")[-1]
  215. id = uuid.uuid4()
  216. filename = f"{id}.{ext}"
  217. file_dir = f"{CACHE_DIR}/audio/transcriptions"
  218. os.makedirs(file_dir, exist_ok=True)
  219. file_path = f"{file_dir}/{filename}"
  220. print(filename)
  221. contents = file.file.read()
  222. with open(file_path, "wb") as f:
  223. f.write(contents)
  224. f.close()
  225. if app.state.config.STT_ENGINE == "":
  226. from faster_whisper import WhisperModel
  227. whisper_kwargs = {
  228. "model_size_or_path": WHISPER_MODEL,
  229. "device": whisper_device_type,
  230. "compute_type": "int8",
  231. "download_root": WHISPER_MODEL_DIR,
  232. "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
  233. }
  234. log.debug(f"whisper_kwargs: {whisper_kwargs}")
  235. try:
  236. model = WhisperModel(**whisper_kwargs)
  237. except:
  238. log.warning(
  239. "WhisperModel initialization failed, attempting download with local_files_only=False"
  240. )
  241. whisper_kwargs["local_files_only"] = False
  242. model = WhisperModel(**whisper_kwargs)
  243. segments, info = model.transcribe(file_path, beam_size=5)
  244. log.info(
  245. "Detected language '%s' with probability %f"
  246. % (info.language, info.language_probability)
  247. )
  248. transcript = "".join([segment.text for segment in list(segments)])
  249. data = {"text": transcript.strip()}
  250. # save the transcript to a json file
  251. transcript_file = f"{file_dir}/{id}.json"
  252. with open(transcript_file, "w") as f:
  253. json.dump(data, f)
  254. print(data)
  255. return data
  256. elif app.state.config.STT_ENGINE == "openai":
  257. if is_mp4_audio(file_path):
  258. print("is_mp4_audio")
  259. os.rename(file_path, file_path.replace(".wav", ".mp4"))
  260. # Convert MP4 audio file to WAV format
  261. convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
  262. headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
  263. files = {"file": (filename, open(file_path, "rb"))}
  264. data = {"model": app.state.config.STT_MODEL}
  265. print(files, data)
  266. r = None
  267. try:
  268. r = requests.post(
  269. url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
  270. headers=headers,
  271. files=files,
  272. data=data,
  273. )
  274. r.raise_for_status()
  275. data = r.json()
  276. # save the transcript to a json file
  277. transcript_file = f"{file_dir}/{id}.json"
  278. with open(transcript_file, "w") as f:
  279. json.dump(data, f)
  280. print(data)
  281. return data
  282. except Exception as e:
  283. log.exception(e)
  284. error_detail = "Open WebUI: Server Connection Error"
  285. if r is not None:
  286. try:
  287. res = r.json()
  288. if "error" in res:
  289. error_detail = f"External: {res['error']['message']}"
  290. except:
  291. error_detail = f"External: {e}"
  292. raise HTTPException(
  293. status_code=r.status_code if r != None else 500,
  294. detail=error_detail,
  295. )
  296. except Exception as e:
  297. log.exception(e)
  298. raise HTTPException(
  299. status_code=status.HTTP_400_BAD_REQUEST,
  300. detail=ERROR_MESSAGES.DEFAULT(e),
  301. )