Added numpy ndarray as an accepted input type for transcribe() and get_audiofile()
This commit is contained in:
@@ -10,7 +10,7 @@ from glob import iglob
|
|||||||
from subprocess import run
|
from subprocess import run
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
import argparse
|
import argparse
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
diarisation = TypeVar('diarisation')
|
diarisation = TypeVar('diarisation')
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
print("AutoTranscribe initialized all models successfully loaded.")
|
print("AutoTranscribe initialized all models successfully loaded.")
|
||||||
|
|
||||||
def transcribe(self, audiofile : Union[str, torch.Tensor],
|
def transcribe(self, audiofile : Union[str, torch.Tensor, ndarray],
|
||||||
remove_original : bool = False,
|
remove_original : bool = False,
|
||||||
*args, **kwargs) -> Transcript:
|
*args, **kwargs) -> Transcript:
|
||||||
"""
|
"""
|
||||||
@@ -140,7 +140,7 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_audiofile(audiofile : Union[str, torch.Tensor],
|
def get_audiofile(audiofile : Union[str, torch.Tensor, ndarray],
|
||||||
*args, **kwargs) -> AudioProcessor:
|
*args, **kwargs) -> AudioProcessor:
|
||||||
"""
|
"""
|
||||||
Get audiofile as TorchAudioProcessor
|
Get audiofile as TorchAudioProcessor
|
||||||
@@ -155,8 +155,11 @@ class AutoTranscribe:
|
|||||||
if isinstance(audiofile, str):
|
if isinstance(audiofile, str):
|
||||||
audiofile = AudioProcessor.from_file(audiofile)
|
audiofile = AudioProcessor.from_file(audiofile)
|
||||||
|
|
||||||
if isinstance(audiofile, torch.Tensor):
|
elif isinstance(audiofile, torch.Tensor):
|
||||||
audiofile = AudioProcessor(audiofile[0], audiofile[1])
|
audiofile = AudioProcessor(audiofile[0], audiofile[1])
|
||||||
|
elif isinstance(audiofile, ndarray):
|
||||||
|
audiofile = AudioProcessor(torch.tensor(audiofile[0]),
|
||||||
|
audiofile[1])
|
||||||
|
|
||||||
if not isinstance(audiofile, AudioProcessor):
|
if not isinstance(audiofile, AudioProcessor):
|
||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||||
@@ -191,9 +194,10 @@ def cli():
|
|||||||
help="the path to save model files; uses ./models/whisper by default")
|
help="the path to save model files; uses ./models/whisper by default")
|
||||||
|
|
||||||
parser.add_argument("--dia_dir", type=str, default = PYANNOTE_DEFAULT_PATH)
|
parser.add_argument("--dia_dir", type=str, default = PYANNOTE_DEFAULT_PATH)
|
||||||
|
parser.add_argument("--htoken", default="", type=str, help="HuggingFace token for private model download")
|
||||||
parser.add_argument("--allow_download", type= bool, default=True,
|
parser.add_argument("--local", type=str2bool, default=False,
|
||||||
help="whether to allow model download if model is not found locally")
|
help="whether to allow model download if model is not found locally")
|
||||||
|
|
||||||
parser.add_argument("--device",
|
parser.add_argument("--device",
|
||||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
help="device to use for PyTorch inference")
|
help="device to use for PyTorch inference")
|
||||||
@@ -219,11 +223,12 @@ def cli():
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
args = parser.parse_args().__dict__
|
args = parser.parse_args().__dict__
|
||||||
|
|
||||||
model_name: str = args.pop("wmodel")
|
model_name: str = args.pop("wmodel")
|
||||||
model_dir: str = args.pop("wmodel_dir")
|
model_dir: str = args.pop("wmodel_dir")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
output_format: str = args.pop("output_format")
|
output_format: str = args.pop("output_format")
|
||||||
local :str = args.pop("allow_download")
|
local :str = args.pop("local")
|
||||||
task = args.pop("task")
|
task = args.pop("task")
|
||||||
device: str = args.pop("device")
|
device: str = args.pop("device")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
@@ -234,7 +239,10 @@ def cli():
|
|||||||
wkwargs = {"download_root": model_dir,
|
wkwargs = {"download_root": model_dir,
|
||||||
"local": local,
|
"local": local,
|
||||||
"device": device}
|
"device": device}
|
||||||
diarisation_kwargs = {"local": local}
|
|
||||||
|
diarisation_kwargs = {"local": local,
|
||||||
|
"token" : args.pop("htoken")}
|
||||||
|
|
||||||
model = AutoTranscribe(whisper_model= model_name,
|
model = AutoTranscribe(whisper_model= model_name,
|
||||||
whisper_kwargs= wkwargs,
|
whisper_kwargs= wkwargs,
|
||||||
dia_model= args.pop("dia_dir"),
|
dia_model= args.pop("dia_dir"),
|
||||||
|
|||||||
Reference in New Issue
Block a user