main.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. from fastapi import (
  3. FastAPI,
  4. Request,
  5. Depends,
  6. HTTPException,
  7. status,
  8. UploadFile,
  9. File,
  10. Form,
  11. )
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from faster_whisper import WhisperModel
  14. from constants import ERROR_MESSAGES
  15. from utils.utils import (
  16. decode_token,
  17. get_current_user,
  18. get_verified_user,
  19. get_admin_user,
  20. )
  21. from utils.misc import calculate_sha256
  22. from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL_NAME
  23. app = FastAPI()
  24. app.add_middleware(
  25. CORSMiddleware,
  26. allow_origins=["*"],
  27. allow_credentials=True,
  28. allow_methods=["*"],
  29. allow_headers=["*"],
  30. )
  31. @app.post("/transcribe")
  32. def transcribe(
  33. file: UploadFile = File(...),
  34. user=Depends(get_current_user),
  35. ):
  36. print(file.content_type)
  37. if file.content_type not in ["audio/mpeg", "audio/wav"]:
  38. raise HTTPException(
  39. status_code=status.HTTP_400_BAD_REQUEST,
  40. detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
  41. )
  42. try:
  43. filename = file.filename
  44. file_path = f"{UPLOAD_DIR}/{filename}"
  45. contents = file.file.read()
  46. with open(file_path, "wb") as f:
  47. f.write(contents)
  48. f.close()
  49. model_name = os.getenv('WHISPER_MODEL', WHISPER_MODEL_NAME)
  50. download_root = os.getenv('WHISPER_DIR', f"{CACHE_DIR}/whisper/models")
  51. model = WhisperModel(
  52. model_name,
  53. device="cpu",
  54. compute_type="int8",
  55. download_root=download_root,
  56. )
  57. segments, info = model.transcribe(file_path, beam_size=5)
  58. print(
  59. "Detected language '%s' with probability %f"
  60. % (info.language, info.language_probability)
  61. )
  62. transcript = "".join([segment.text for segment in list(segments)])
  63. return {"text": transcript.strip()}
  64. except Exception as e:
  65. print(e)
  66. raise HTTPException(
  67. status_code=status.HTTP_400_BAD_REQUEST,
  68. detail=ERROR_MESSAGES.DEFAULT(e),
  69. )