From 713dd3bfd5861e517d6660ff74614019fe2307df Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 16 Jun 2023 12:10:11 +0200 Subject: [PATCH] added cuda support --- autotranscript/audio.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/autotranscript/audio.py b/autotranscript/audio.py index 35b6f99..ea11fe8 100644 --- a/autotranscript/audio.py +++ b/autotranscript/audio.py @@ -9,13 +9,28 @@ class AudioProcessor: Audio Processor using PyTorchaudio instead of PyDub """ - def __init__(self, waveform: torch.Tensor, sr : torch.Tensor) -> None: + def __init__(self, waveform: torch.Tensor, sr : torch.Tensor, + *args, **kwargs) -> None: """ Initialise audio processor :param waveform: waveform :param sr: sample rate + :param args: additional arguments + :param kwargs: additional keyword arguments + example: + - device: device to use for processing + if cuda is available, cuda is used """ - self.waveform = waveform + + if "device" in kwargs: + device = kwargs["device"] + else: + if torch.cuda.is_available(): + device = "cuda" + else: + device = "cpu" + + self.waveform = waveform.to(device) self.sr = sr if not isinstance(self.sr, int):