removed torch device from AudioProcessor class

This commit is contained in:
Schmieder, Jacob
2024-10-08 12:01:36 +00:00
parent a4b8546033
commit 6fadf3d851
+4 -10
View File
@@ -41,26 +41,20 @@ class AudioProcessor:
The sample rate of the audio.
"""
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
*args, **kwargs) -> None:
def __init__(self, waveform: torch.Tensor,
sr: int = SAMPLE_RATE) -> None:
"""
Initialize the AudioProcessor object.
Args:
waveform (torch.Tensor): The audio waveform tensor.
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
args: Additional arguments.
kwargs: Additional keyword arguments, e.g., device to use for processing.
If CUDA is available, it defaults to CUDA.
Raises:
ValueError: If the provided sample rate is not of type int.
"""
device = kwargs.get(
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.waveform = waveform
self.sr = sr
if not isinstance(self.sr, int):
@@ -147,6 +141,6 @@ class AudioProcessor:
np.float32) / NORMALIZATION_FACTOR
return out, sr
def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'