diff --git a/scraibe/cli.py b/scraibe/cli.py index b6f2c17..ee40c8b 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -5,17 +5,12 @@ The function includes arguments for specifying the audio files, model paths, output formats, and other options necessary for transcription. """ import os -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json - -from .autotranscript import Scraibe -from .misc import ParseKwargs - - +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from torch.cuda import is_available from torch import set_num_threads - +from .autotranscript import Scraibe def cli(): """ @@ -37,22 +32,13 @@ def cli(): parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - group = parser.add_mutually_exclusive_group() - parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None, help="List of audio files to transcribe.") - group.add_argument('--start-server', action='store_true', - help='Start the Gradio app.' - 'If set, all other arguments are ignored' - 'besides --server-config or --server-kwargs.') - - parser.add_argument("--server-config", type=str, default=None, - help="Path to the configy.yml file.") - - parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, - help='Keyword arguments for the Gradio app.') - + parser.add_argument("--whisper-type", type=str, default="whisper", + choices=["whisper", "whisperx"], + help="Type of Whisper model to use ('whisper' or 'whisperx').") + parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") @@ -83,7 +69,7 @@ def cli(): parser.add_argument("--verbose-output", type=str2bool, default=True, help="Enable or disable progress and debug messages.") - parser.add_argument("--task", type=str, default='autotranscribe', # unifinished code + parser.add_argument("--task", type=str, default='autotranscribe', choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate", 'transcribe'], help="Choose to perform transcription, diarization, or translation. \ @@ -104,91 +90,65 @@ def cli(): out_format = arg_dict.pop("output_format") - # seup server arg: - start_server = arg_dict.pop("start_server") - task = arg_dict.pop("task") if args.num_threads > 0: set_num_threads(arg_dict.pop("num_threads")) class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), + 'whisper_type':arg_dict.pop("whisper_type"), 'dia_model': arg_dict.pop("diarization_directory"), - 'use_auth_token': arg_dict.pop("hf_token")} + 'use_auth_token': arg_dict.pop("hf_token"), + } if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") + - if not start_server: + model = Scraibe(**class_kwargs) - model = Scraibe(**class_kwargs) + if arg_dict["audio_files"]: + audio_files = arg_dict.pop("audio_files") - if arg_dict["audio_files"]: - audio_files = arg_dict.pop("audio_files") + if task == "autotranscribe" or task == "autotranscribe+translate": + for audio in audio_files: + if task == "autotranscribe+translate": + task = "translate" + else: + task = "transcribe" - if task == "autotranscribe" or task == "autotranscribe+translate": - for audio in audio_files: - if task == "autotranscribe+translate": - task = "translate" - else: - task = "transcribe" + out = model.autotranscribe(audio, task=task, language=arg_dict.pop( + "language"), verbose=arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + print(f'Saving {basename}.{out_format} to {out_folder}') + out.save(os.path.join( + out_folder, f"{basename}.{out_format}")) - out = model.autotranscribe(audio, task=task, language=arg_dict.pop( - "language"), verbose=arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - print(f'Saving {basename}.{out_format} to {out_folder}') - out.save(os.path.join( - out_folder, f"{basename}.{out_format}")) + elif task == "diarization": + for audio in audio_files: + if arg_dict.pop("verbose_output"): + print("Verbose not implemented for diarization.") - elif task == "diarization": - for audio in audio_files: - if arg_dict.pop("verbose_output"): - print("Verbose not implemented for diarization.") + out = model.diarization(audio) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") - out = model.diarization(audio) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") + print(f'Saving {basename}.{out_format} to {out_folder}') - print(f'Saving {basename}.{out_format} to {out_folder}') + with open(path, "w") as f: + json.dump(json.dumps(out, indent=1), f) - with open(path, "w") as f: - json.dump(json.dumps(out, indent=1), f) + elif task == "transcribe" or task == "translate": - elif task == "transcribe" or task == "translate": - - for audio in audio_files: - - out = model.transcribe(audio, task=task, - language=arg_dict.pop("language"), - verbose=arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") - with open(path, "w") as f: - f.write(out) - - else: # unfinished code - raise NotImplementedError("Currently not Working") - import subprocess - import sys - - execute_path = os.path.join( - os.path.dirname(__file__), "app/app_starter.py") - - config = arg_dict.pop("server_config") - server_kwargs = arg_dict.pop("server_kwargs") - - if not config: - subprocess.run([sys.executable, execute_path, - f"--server-kwargs={server_kwargs}"]) - elif not server_kwargs: - subprocess.run([sys.executable, execute_path, - f"--server-config={config}"]) - elif not config and not server_kwargs: - subprocess.run([sys.executable, execute_path]) - else: - subprocess.run([sys.executable, execute_path, - f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) + for audio in audio_files: + out = model.transcribe(audio, task=task, + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + with open(path, "w") as f: + f.write(out) if __name__ == "__main__": cli()