removed torch device from AudioProcessor class
This commit is contained in:
+4
-10
@@ -41,26 +41,20 @@ class AudioProcessor:
|
|||||||
The sample rate of the audio.
|
The sample rate of the audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
def __init__(self, waveform: torch.Tensor,
|
||||||
*args, **kwargs) -> None:
|
sr: int = SAMPLE_RATE) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the AudioProcessor object.
|
Initialize the AudioProcessor object.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
waveform (torch.Tensor): The audio waveform tensor.
|
waveform (torch.Tensor): The audio waveform tensor.
|
||||||
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
|
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
|
||||||
args: Additional arguments.
|
|
||||||
kwargs: Additional keyword arguments, e.g., device to use for processing.
|
|
||||||
If CUDA is available, it defaults to CUDA.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the provided sample rate is not of type int.
|
ValueError: If the provided sample rate is not of type int.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device = kwargs.get(
|
self.waveform = waveform
|
||||||
"device", "cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
self.waveform = waveform.to(device)
|
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
|
|
||||||
if not isinstance(self.sr, int):
|
if not isinstance(self.sr, int):
|
||||||
@@ -147,6 +141,6 @@ class AudioProcessor:
|
|||||||
np.float32) / NORMALIZATION_FACTOR
|
np.float32) / NORMALIZATION_FACTOR
|
||||||
|
|
||||||
return out, sr
|
return out, sr
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||||
|
|||||||
Reference in New Issue
Block a user