Apply automatic PEP 8 formatting fixes and resolve flake8 warnings.
This commit is contained in:
+23
-21
@@ -28,6 +28,7 @@ import torch
|
||||
SAMPLE_RATE = 16000
|
||||
NORMALIZATION_FACTOR = 32768.0
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
||||
@@ -39,10 +40,9 @@ class AudioProcessor:
|
||||
sr: int
|
||||
The sample rate of the audio.
|
||||
"""
|
||||
|
||||
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
|
||||
|
||||
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
||||
*args, **kwargs) -> None:
|
||||
|
||||
"""
|
||||
Initialize the AudioProcessor object.
|
||||
|
||||
@@ -56,16 +56,17 @@ class AudioProcessor:
|
||||
Raises:
|
||||
ValueError: If the provided sample rate is not of type int.
|
||||
"""
|
||||
|
||||
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
device = kwargs.get(
|
||||
"device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.waveform = waveform.to(device)
|
||||
self.sr = sr
|
||||
|
||||
|
||||
if not isinstance(self.sr, int):
|
||||
raise ValueError("Sample rate should be a single value of type int," \
|
||||
raise ValueError("Sample rate should be a single value of type int,"
|
||||
f"not {len(self.sr)} and type {type(self.sr)}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
||||
"""
|
||||
@@ -77,14 +78,13 @@ class AudioProcessor:
|
||||
Returns:
|
||||
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
|
||||
"""
|
||||
|
||||
audio, sr = cls.load_audio(file , *args, **kwargs)
|
||||
|
||||
audio, sr = cls.load_audio(file, *args, **kwargs)
|
||||
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
|
||||
return cls(audio, sr)
|
||||
|
||||
|
||||
|
||||
def cut(self, start: float, end: float) -> torch.Tensor:
|
||||
"""
|
||||
Cut a segment from the audio waveform between the specified start and end times.
|
||||
@@ -96,7 +96,7 @@ class AudioProcessor:
|
||||
Returns:
|
||||
torch.Tensor: The cut waveform segment.
|
||||
"""
|
||||
|
||||
|
||||
start = int(start * self.sr)
|
||||
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
|
||||
end = int(np.ceil(end * self.sr))
|
||||
@@ -140,11 +140,13 @@ class AudioProcessor:
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
except CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
raise RuntimeError(
|
||||
f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
out = np.frombuffer(out, np.int16).flatten().astype(
|
||||
np.float32) / NORMALIZATION_FACTOR
|
||||
|
||||
return out, sr
|
||||
|
||||
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
|
||||
|
||||
return out , sr
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||
|
||||
Reference in New Issue
Block a user