main.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. import logging
  3. from fastapi import (
  4. FastAPI,
  5. Request,
  6. Depends,
  7. HTTPException,
  8. status,
  9. UploadFile,
  10. File,
  11. Form,
  12. )
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from faster_whisper import WhisperModel
  15. from constants import ERROR_MESSAGES
  16. from utils.utils import (
  17. decode_token,
  18. get_current_user,
  19. get_verified_user,
  20. get_admin_user,
  21. )
  22. from utils.misc import calculate_sha256
  23. from config import (
  24. SRC_LOG_LEVELS,
  25. CACHE_DIR,
  26. UPLOAD_DIR,
  27. WHISPER_MODEL,
  28. WHISPER_MODEL_DIR,
  29. WHISPER_MODEL_AUTO_UPDATE,
  30. DEVICE_TYPE,
  31. )
  32. log = logging.getLogger(__name__)
  33. log.setLevel(SRC_LOG_LEVELS["AUDIO"])
  34. app = FastAPI()
  35. app.add_middleware(
  36. CORSMiddleware,
  37. allow_origins=["*"],
  38. allow_credentials=True,
  39. allow_methods=["*"],
  40. allow_headers=["*"],
  41. )
  42. # setting device type for whisper model
  43. whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
  44. log.info(f"whisper_device_type: {whisper_device_type}")
  45. @app.post("/transcribe")
  46. def transcribe(
  47. file: UploadFile = File(...),
  48. user=Depends(get_current_user),
  49. ):
  50. log.info(f"file.content_type: {file.content_type}")
  51. if file.content_type not in ["audio/mpeg", "audio/wav"]:
  52. raise HTTPException(
  53. status_code=status.HTTP_400_BAD_REQUEST,
  54. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  55. )
  56. try:
  57. filename = file.filename
  58. file_path = f"{UPLOAD_DIR}/{filename}"
  59. contents = file.file.read()
  60. with open(file_path, "wb") as f:
  61. f.write(contents)
  62. f.close()
  63. whisper_kwargs = {
  64. "model_size_or_path": WHISPER_MODEL,
  65. "device": whisper_device_type,
  66. "compute_type": "int8",
  67. "download_root": WHISPER_MODEL_DIR,
  68. "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
  69. }
  70. log.debug(f"whisper_kwargs: {whisper_kwargs}")
  71. try:
  72. model = WhisperModel(**whisper_kwargs)
  73. except:
  74. log.debug("WhisperModel initialization failed, attempting download with local_files_only=False")
  75. whisper_kwargs["local_files_only"] = False
  76. model = WhisperModel(**whisper_kwargs)
  77. segments, info = model.transcribe(file_path, beam_size=5)
  78. log.info(
  79. "Detected language '%s' with probability %f"
  80. % (info.language, info.language_probability)
  81. )
  82. transcript = "".join([segment.text for segment in list(segments)])
  83. return {"text": transcript.strip()}
  84. except Exception as e:
  85. log.exception(e)
  86. raise HTTPException(
  87. status_code=status.HTTP_400_BAD_REQUEST,
  88. detail=ERROR_MESSAGES.DEFAULT(e),
  89. )