removed torch device from AudioProcessor class

This commit is contained in:
Schmieder, Jacob
2024-10-08 12:01:36 +00:00
parent a4b8546033
commit 6fadf3d851
+3 -9
View File
@@ -41,26 +41,20 @@ class AudioProcessor:
The sample rate of the audio. The sample rate of the audio.
""" """
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor,
*args, **kwargs) -> None: sr: int = SAMPLE_RATE) -> None:
""" """
Initialize the AudioProcessor object. Initialize the AudioProcessor object.
Args: Args:
waveform (torch.Tensor): The audio waveform tensor. waveform (torch.Tensor): The audio waveform tensor.
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
args: Additional arguments.
kwargs: Additional keyword arguments, e.g., device to use for processing.
If CUDA is available, it defaults to CUDA.
Raises: Raises:
ValueError: If the provided sample rate is not of type int. ValueError: If the provided sample rate is not of type int.
""" """
device = kwargs.get( self.waveform = waveform
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
if not isinstance(self.sr, int): if not isinstance(self.sr, int):