import argparse
import os
from warnings import warn

import torch


def _str2bool(string: str) -> bool:
    """Convert the literal strings ``"True"`` / ``"False"`` to a bool.

    Used as an argparse ``type=`` callable; argparse turns the ValueError
    into a regular usage error.

    Raises:
        ValueError: if *string* is neither ``"True"`` nor ``"False"``.
    """
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")


def _output_stem(audio_path: str) -> str:
    """Return the file name of *audio_path* without directory or final extension.

    Only the last extension is stripped, so ``talk.part1.mp3`` and
    ``talk.part2.mp3`` map to distinct stems instead of both becoming ``talk``.
    """
    return os.path.splitext(os.path.basename(audio_path))[0]


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the command line interface."""
    # Imported here rather than at module level so they are only loaded
    # when the CLI is actually used.
    from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
    from .transcriber import WHISPER_DEFAULT_PATH

    # fmt: off
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("audio", nargs="+", type=str,
                        help="audio file(s) to transcribe")

    parser.add_argument("--wmodel", default="medium",
                        help="name of the Whisper model to use")
    parser.add_argument("--wmodel_dir", type=str, default=WHISPER_DEFAULT_PATH,
                        help="the path to save model files; uses ./models/whisper by default")

    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="device to use for PyTorch inference")
    parser.add_argument("--threads", type=int, default=0,
                        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    parser.add_argument("--output_dir", "-o", type=str, default=".",
                        help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="txt",
                        choices=["txt", "json", "md", "html"],
                        help="format of the output file (ignored by --task diarize, which always writes json)")

    # NOTE: --verbose is parsed but not forwarded to anything yet.
    parser.add_argument("--verbose", type=_str2bool, default=True,
                        help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe",
                        choices=["transcribe", "diarize", "wtranscribe"],
                        help="whether to perfrom transcription and diazation or only one of them")
    parser.add_argument("--language", type=str, default=None,
                        choices=sorted(LANGUAGES) + sorted(k.title() for k in TO_LANGUAGE_CODE),
                        help="language spoken in the audio, specify None to perform language detection")
    # fmt: on

    return parser


def cli() -> None:
    """Command line entry point: process one or more audio files.

    Depending on ``--task`` each audio file is

    * ``transcribe``  - passed to ``AutoTranscribe.transcribe`` and the result
      saved as ``<output_dir>/<stem>.<output_format>``;
    * ``diarize``     - passed to ``model.diariser.diarization`` and the result
      saved as ``<output_dir>/<stem>.json``;
    * ``wtranscribe`` - passed to ``model.transcriber.transcribe`` and written
      with the writer returned by whisper's ``get_writer``.

    The output directory is created if it does not exist.
    """
    from whisper.utils import get_writer

    parser = _build_parser()
    args = vars(parser.parse_args())

    model_name: str = args.pop("wmodel")
    model_dir: str = args.pop("wmodel_dir")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    task: str = args.pop("task")
    device: str = args.pop("device")
    audio_files = args.pop("audio")

    os.makedirs(output_dir, exist_ok=True)

    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)

    writer = None
    if task == "wtranscribe":
        # Resolve the writer before loading the model, so an unsupported
        # format is reported as a usage error instead of a KeyError raised
        # after the model has been loaded.
        try:
            writer = get_writer(output_format, output_dir)
        except KeyError:
            parser.error(f"output format {output_format!r} is not available "
                         f"for task 'wtranscribe'")

    wkwargs = {"download_root": model_dir,
               "device": device,
               "language": args.pop("language")}

    model = AutoTranscribe(whisper_model=model_name, whisper_kwargs=wkwargs)

    if task == "transcribe":
        for audio in audio_files:
            out = model.transcribe(audio)
            spath = os.path.join(output_dir,
                                 f"{_output_stem(audio)}.{output_format}")
            out.save(spath)

    elif task == "diarize":
        warn("Diarization is still in beta and may not work as expected.",
             RuntimeWarning)
        for audio in audio_files:
            out = model.diariser.diarization(audio)
            spath = os.path.join(output_dir, f"{_output_stem(audio)}.json")
            out.save(spath)
            # Report success only after the file has actually been written.
            print(f"diairization results saved to {spath}")

    elif task == "wtranscribe":
        warn("whisper transcription is poorly supported and may not work as expected. "
             "It is recommendet to use the whisper cli directly",
             RuntimeWarning)
        for audio in audio_files:
            out = model.transcriber.transcribe(audio, diarisation=True)
            writer(out, audio)


# This module uses package-relative imports, so it has to be started as a
# module (``python -m autotranscript.autotranscript``), not as a plain script.
if __name__ == "__main__":
    cli()