From 4aecc63c3ab30503430b2b3c54d046c854805360 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 1 Sep 2023 12:58:38 +0200 Subject: [PATCH] bug fixes --- autotranscript/cli.py | 46 +++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/autotranscript/cli.py b/autotranscript/cli.py index e4c8e45..5fa0774 100644 --- a/autotranscript/cli.py +++ b/autotranscript/cli.py @@ -8,12 +8,8 @@ import os from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json -from .transcriber import WHISPER_DEFAULT_PATH -from .diarisation import PYANNOTE_DEFAULT_PATH from .autotranscript import AutoTranscribe -from whisper import available_models -from whisper.utils import get_writer from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from torch.cuda import is_available from torch import set_num_threads @@ -47,10 +43,10 @@ def cli(): help='Start the Gradio app.') parser.add_argument("--port", type=int, default= None, - help="Port to run the Gradio app on.") + help="Port to run the Gradio app on. Defaults to 7860.") - parser.add_argument("--server_name", type=str, default= "autotranscript", - help="Name of the Gradio app.") + parser.add_argument("--server_name", type=str, default= None, + help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.") parser.add_argument("--whisper_model_name", default="medium", help="Name of the Whisper model to use.") @@ -84,10 +80,11 @@ def cli(): parser.add_argument("--verbose_output", type=str2bool, default=True, help="Enable or disable progress and debug messages.") - parser.add_argument("--task", type=str, default= None, # unifinished code - choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"], + parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code + choices=["autotranscribe", "diarization", + "autotranscribe+translate", "translate", 'transcribe'], help="Choose to perform transcription, diarization, or translation. \ - If set to translate, the language argument must be specified.") + If set to translate, the output will be translated to English.") parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), @@ -96,7 +93,7 @@ def cli(): args = parser.parse_args() arg_dict = vars(args) - + # configure output out_folder = arg_dict.pop("output_directory") os.makedirs(out_folder, exist_ok=True) @@ -119,9 +116,10 @@ def cli(): model = AutoTranscribe(**class_kwargs) + if arg_dict["audio_files"]: - audio_files = args.pop("audio_files") + audio_files = arg_dict.pop("audio_files") if task == "autotranscribe" or task == "autotranscribe+translate": for audio in audio_files: @@ -132,6 +130,7 @@ def cli(): out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) basename = audio.split("/")[-1].split(".")[0] + print(f'Saving {basename}.{out_format} to {out_folder}') out.save(os.path.join(out_folder, f"{basename}.{out_format}")) elif task == "diarization": @@ -142,20 +141,18 @@ def cli(): out = model.diarization(audio) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") - if out_format == "txt": - with open(path, "w") as f: - f.write(out) - elif out_format == "json": - with open(path, "w") as f: - json.dump(json.dumps(out, indent= 3), f) - else: - raise ValueError(f"Unsupported output format for diarization{out_format}.") + + print(f'Saving {basename}.{out_format} to {out_folder}') + + with open(path, "w") as f: + json.dump(json.dumps(out, indent= 1), f) + elif task == "transcribe" or task == "translate": for audio in audio_files: - + out = model.transcribe(audio, task = task, - language=arg_dict.pop("language"), + language= arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") @@ -164,8 +161,9 @@ def cli(): if start_server: # unfinished code - from .gradio_app import gradio_app - gradio_app(model) + from .app.gradio_app import gradio_Interface + gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name) + if __name__ == "__main__": cli() \ No newline at end of file