From 6fadf3d851c06ffc130bfd4d6e758d7da5850830 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:01:36 +0000 Subject: [PATCH] removed torch device from AudioProcessor class --- scraibe/audio.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/scraibe/audio.py b/scraibe/audio.py index 7fbc6fb..4e5dd0f 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -41,26 +41,20 @@ class AudioProcessor: The sample rate of the audio. """ - def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, - *args, **kwargs) -> None: + def __init__(self, waveform: torch.Tensor, + sr: int = SAMPLE_RATE) -> None: """ Initialize the AudioProcessor object. Args: waveform (torch.Tensor): The audio waveform tensor. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. - args: Additional arguments. - kwargs: Additional keyword arguments, e.g., device to use for processing. - If CUDA is available, it defaults to CUDA. Raises: ValueError: If the provided sample rate is not of type int. """ - device = kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu") - - self.waveform = waveform.to(device) + self.waveform = waveform self.sr = sr if not isinstance(self.sr, int): @@ -147,6 +141,6 @@ class AudioProcessor: np.float32) / NORMALIZATION_FACTOR return out, sr - + def __repr__(self) -> str: return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'