diff --git a/autotranscript/audio.py b/autotranscript/audio.py index 35b6f99..ea11fe8 100644 --- a/autotranscript/audio.py +++ b/autotranscript/audio.py @@ -9,13 +9,28 @@ class AudioProcessor: Audio Processor using PyTorchaudio instead of PyDub """ - def __init__(self, waveform: torch.Tensor, sr : torch.Tensor) -> None: + def __init__(self, waveform: torch.Tensor, sr : torch.Tensor, + *args, **kwargs) -> None: """ Initialise audio processor :param waveform: waveform :param sr: sample rate + :param args: additional arguments + :param kwargs: additional keyword arguments + example: + - device: device to use for processing + if cuda is available, cuda is used """ - self.waveform = waveform + + if "device" in kwargs: + device = kwargs["device"] + else: + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + self.waveform = waveform.to(device) self.sr = sr if not isinstance(self.sr, int):