diff --git a/scraibe/cli.py b/scraibe/cli.py index eece1bb..ee40c8b 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -35,6 +35,10 @@ def cli(): parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None, help="List of audio files to transcribe.") + parser.add_argument("--whisper-type", type=str, default="whisper", + choices=["whisper", "whisperx"], + help="Type of Whisper model to use ('whisper' or 'whisperx').") + parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") @@ -92,8 +96,10 @@ def cli(): set_num_threads(arg_dict.pop("num_threads")) class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), + 'whisper_type':arg_dict.pop("whisper_type"), 'dia_model': arg_dict.pop("diarization_directory"), - 'use_auth_token': arg_dict.pop("hf_token")} + 'use_auth_token': arg_dict.pop("hf_token"), + } if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")