added cuda support
This commit is contained in:
+17
-2
@@ -9,13 +9,28 @@ class AudioProcessor:
|
|||||||
Audio Processor using PyTorchaudio instead of PyDub
|
Audio Processor using PyTorchaudio instead of PyDub
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, waveform: torch.Tensor, sr : torch.Tensor) -> None:
|
def __init__(self, waveform: torch.Tensor, sr : torch.Tensor,
|
||||||
|
*args, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Initialise audio processor
|
Initialise audio processor
|
||||||
:param waveform: waveform
|
:param waveform: waveform
|
||||||
:param sr: sample rate
|
:param sr: sample rate
|
||||||
|
:param args: additional arguments
|
||||||
|
:param kwargs: additional keyword arguments
|
||||||
|
example:
|
||||||
|
- device: device to use for processing
|
||||||
|
if cuda is available, cuda is used
|
||||||
"""
|
"""
|
||||||
self.waveform = waveform
|
|
||||||
|
if "device" in kwargs:
|
||||||
|
device = kwargs["device"]
|
||||||
|
else:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
self.waveform = waveform.to(device)
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
|
|
||||||
if not isinstance(self.sr, int):
|
if not isinstance(self.sr, int):
|
||||||
|
|||||||
Reference in New Issue
Block a user