From d882d80d1d381a2d19882b6d2c93145c15ac0220 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Tue, 27 Jun 2023 10:21:21 +0200 Subject: [PATCH] added ndarray datatype to input of transcribe --- autotranscript/autotranscript.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index d79b392..6f00888 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -10,7 +10,7 @@ from glob import iglob from subprocess import run from warnings import warn import argparse - +from numpy import ndarray diarisation = TypeVar('diarisation') @@ -53,7 +53,7 @@ class AutoTranscribe: print("AutoTranscribe initialized all models successfully loaded.") - def transcribe(self, audiofile : Union[str, torch.Tensor], + def transcribe(self, audiofile : Union[str, torch.Tensor, ndarray], remove_original : bool = False, *args, **kwargs) -> Transcript: """ @@ -140,7 +140,7 @@ class AutoTranscribe: @staticmethod - def get_audiofile(audiofile : Union[str, torch.Tensor], + def get_audiofile(audiofile : Union[str, torch.Tensor, ndarray], *args, **kwargs) -> AudioProcessor: """ Get audiofile as TorchAudioProcessor @@ -155,9 +155,12 @@ class AutoTranscribe: if isinstance(audiofile, str): audiofile = AudioProcessor.from_file(audiofile) - if isinstance(audiofile, torch.Tensor): + elif isinstance(audiofile, torch.Tensor): audiofile = AudioProcessor(audiofile[0], audiofile[1]) - + elif isinstance(audiofile, ndarray): + audiofile = AudioProcessor(torch.tensor(audiofile[0]), + audiofile[1]) + if not isinstance(audiofile, AudioProcessor): raise ValueError(f'Audiofile must be of type AudioProcessor,' \ f'not {type(audiofile)}') @@ -191,9 +194,10 @@ def cli(): help="the path to save model files; uses ./models/whisper by default") parser.add_argument("--dia_dir", type=str, default = PYANNOTE_DEFAULT_PATH) - - parser.add_argument("--allow_download", type= bool, default=True, + parser.add_argument("--htoken", default="", type=str, help="HuggingFace token for private model download") + parser.add_argument("--local", type=str2bool, default=False, help="whether to allow model download if model is not found locally") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") @@ -219,11 +223,12 @@ def cli(): # fmt: on args = parser.parse_args().__dict__ + model_name: str = args.pop("wmodel") model_dir: str = args.pop("wmodel_dir") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") - local :str = args.pop("allow_download") + local :str = args.pop("local") task = args.pop("task") device: str = args.pop("device") os.makedirs(output_dir, exist_ok=True) @@ -234,7 +239,10 @@ def cli(): wkwargs = {"download_root": model_dir, "local": local, "device": device} - diarisation_kwargs = {"local": local} + + diarisation_kwargs = {"local": local, + "token" : args.pop("htoken")} + model = AutoTranscribe(whisper_model= model_name, whisper_kwargs= wkwargs, dia_model= args.pop("dia_dir"),