added ndarray datatype to input of transcribe

This commit is contained in:
Jaikinator
2023-06-27 10:21:21 +02:00
parent 2308a9337c
commit d882d80d1d
+17 -9
View File
@@ -10,7 +10,7 @@ from glob import iglob
from subprocess import run from subprocess import run
from warnings import warn from warnings import warn
import argparse import argparse
from numpy import ndarray
diarisation = TypeVar('diarisation') diarisation = TypeVar('diarisation')
@@ -53,7 +53,7 @@ class AutoTranscribe:
print("AutoTranscribe initialized all models successfully loaded.") print("AutoTranscribe initialized all models successfully loaded.")
def transcribe(self, audiofile : Union[str, torch.Tensor], def transcribe(self, audiofile : Union[str, torch.Tensor, ndarray],
remove_original : bool = False, remove_original : bool = False,
*args, **kwargs) -> Transcript: *args, **kwargs) -> Transcript:
""" """
@@ -140,7 +140,7 @@ class AutoTranscribe:
@staticmethod @staticmethod
def get_audiofile(audiofile : Union[str, torch.Tensor], def get_audiofile(audiofile : Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor: *args, **kwargs) -> AudioProcessor:
""" """
Get audiofile as TorchAudioProcessor Get audiofile as TorchAudioProcessor
@@ -155,9 +155,12 @@ class AutoTranscribe:
if isinstance(audiofile, str): if isinstance(audiofile, str):
audiofile = AudioProcessor.from_file(audiofile) audiofile = AudioProcessor.from_file(audiofile)
if isinstance(audiofile, torch.Tensor): elif isinstance(audiofile, torch.Tensor):
audiofile = AudioProcessor(audiofile[0], audiofile[1]) audiofile = AudioProcessor(audiofile[0], audiofile[1])
elif isinstance(audiofile, ndarray):
audiofile = AudioProcessor(torch.tensor(audiofile[0]),
audiofile[1])
if not isinstance(audiofile, AudioProcessor): if not isinstance(audiofile, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \ raise ValueError(f'Audiofile must be of type AudioProcessor,' \
f'not {type(audiofile)}') f'not {type(audiofile)}')
@@ -191,9 +194,10 @@ def cli():
help="the path to save model files; uses ./models/whisper by default") help="the path to save model files; uses ./models/whisper by default")
parser.add_argument("--dia_dir", type=str, default = PYANNOTE_DEFAULT_PATH) parser.add_argument("--dia_dir", type=str, default = PYANNOTE_DEFAULT_PATH)
parser.add_argument("--htoken", default="", type=str, help="HuggingFace token for private model download")
parser.add_argument("--allow_download", type= bool, default=True, parser.add_argument("--local", type=str2bool, default=False,
help="whether to allow model download if model is not found locally") help="whether to allow model download if model is not found locally")
parser.add_argument("--device", parser.add_argument("--device",
default="cuda" if torch.cuda.is_available() else "cpu", default="cuda" if torch.cuda.is_available() else "cpu",
help="device to use for PyTorch inference") help="device to use for PyTorch inference")
@@ -219,11 +223,12 @@ def cli():
# fmt: on # fmt: on
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
model_name: str = args.pop("wmodel") model_name: str = args.pop("wmodel")
model_dir: str = args.pop("wmodel_dir") model_dir: str = args.pop("wmodel_dir")
output_dir: str = args.pop("output_dir") output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format") output_format: str = args.pop("output_format")
local :str = args.pop("allow_download") local :str = args.pop("local")
task = args.pop("task") task = args.pop("task")
device: str = args.pop("device") device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
@@ -234,7 +239,10 @@ def cli():
wkwargs = {"download_root": model_dir, wkwargs = {"download_root": model_dir,
"local": local, "local": local,
"device": device} "device": device}
diarisation_kwargs = {"local": local}
diarisation_kwargs = {"local": local,
"token" : args.pop("htoken")}
model = AutoTranscribe(whisper_model= model_name, model = AutoTranscribe(whisper_model= model_name,
whisper_kwargs= wkwargs, whisper_kwargs= wkwargs,
dia_model= args.pop("dia_dir"), dia_model= args.pop("dia_dir"),