diff --git a/autotranscript/cli.py b/autotranscript/cli.py index 1507f3a..e4c8e45 100644 --- a/autotranscript/cli.py +++ b/autotranscript/cli.py @@ -6,7 +6,7 @@ output formats, and other options necessary for transcription. """ import os from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter -from turtle import st +import json from .transcriber import WHISPER_DEFAULT_PATH from .diarisation import PYANNOTE_DEFAULT_PATH @@ -85,7 +85,7 @@ def cli(): help="Enable or disable progress and debug messages.") parser.add_argument("--task", type=str, default= None, # unifinished code - choices=["autoranscribe", "diarize", "autotranscribe+translate", "translate"], + choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"], help="Choose to perform transcription, diarization, or translation. \ If set to translate, the language argument must be specified.") @@ -98,14 +98,15 @@ def cli(): arg_dict = vars(args) # configure output - - os.makedirs(arg_dict.pop("output_directory"), exist_ok=True) + out_folder = arg_dict.pop("output_directory") + os.makedirs(out_folder, exist_ok=True) out_format = arg_dict.pop("output_format") # seup server arg: start_server = arg_dict.pop("start_server") + task = arg_dict.pop("task") if args.num_threads > 0: set_num_threads(arg_dict.pop("num_threads")) @@ -115,29 +116,56 @@ def cli(): for k, v in arg_dict.items(): if v is not None: class_kwargs[k] = v - model = AutoTranscribe(**class_kwargs) - # if transcription_task == "transcribe": - # for audio in audio_files: - # out = model.transcribe(audio, language=spoken_language) - # basename = audio.split("/")[-1].split(".")[0] - # spath = f"{output_directory}/{basename}.{output_format}" - # out.save(spath) - - # # ... include other tasks here ... - # elif transcription_task == "diarize": - # # diarize code here - # pass - # elif transcription_task == "wtranscribe": - # # wtranscribe code here - # pass + if arg_dict["audio_files"]: + audio_files = args.pop("audio_files") + + if task == "autotranscribe" or task == "autotranscribe+translate": + for audio in audio_files: + if task == "autotranscribe+translate": + task = "translate" + else: + task = "transcribe" + + out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + out.save(os.path.join(out_folder, f"{basename}.{out_format}")) + + elif task == "diarization": + for audio in audio_files: + if arg_dict.pop("verbose_output"): + print(f"Verbose not implemented for diarization.") + + out = model.diarization(audio) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + if out_format == "txt": + with open(path, "w") as f: + f.write(out) + elif out_format == "json": + with open(path, "w") as f: + json.dump(json.dumps(out, indent= 3), f) + else: + raise ValueError(f"Unsupported output format for diarization{out_format}.") + elif task == "transcribe" or task == "translate": + + for audio in audio_files: + + out = model.transcribe(audio, task = task, + language=arg_dict.pop("language"), + verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + with open(path, "w") as f: + f.write(out) + - # if start_server: # unfinished code - # from .gradio_app import gradio_app - # gradio_app(model) + if start_server: # unfinished code + from .gradio_app import gradio_app + gradio_app(model) if __name__ == "__main__": cli() \ No newline at end of file