From 5ec66effc2eba6939f0dd90e9cd2ab4d245358db Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 30 May 2024 14:50:06 +0000 Subject: [PATCH] added whisper type to cli --- scraibe/cli.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scraibe/cli.py b/scraibe/cli.py index eece1bb..ee40c8b 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -35,6 +35,10 @@ def cli(): parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None, help="List of audio files to transcribe.") + parser.add_argument("--whisper-type", type=str, default="whisper", + choices=["whisper", "whisperx"], + help="Type of Whisper model to use ('whisper' or 'whisperx').") + parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") @@ -92,8 +96,10 @@ def cli(): set_num_threads(arg_dict.pop("num_threads")) class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), + 'whisper_type':arg_dict.pop("whisper_type"), 'dia_model': arg_dict.pop("diarization_directory"), - 'use_auth_token': arg_dict.pop("hf_token")} + 'use_auth_token': arg_dict.pop("hf_token"), + } if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")