Auto fixes from PEP8, fixes from flake8.

This commit is contained in:
Marko Henning
2024-05-15 15:18:17 +02:00
parent 9f526a8f3b
commit 4bcd28d0ea
15 changed files with 391 additions and 417 deletions
+23 -21
View File
@@ -28,6 +28,7 @@ import torch
SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768.0
class AudioProcessor:
"""
Audio Processor class that leverages PyTorchaudio to provide functionalities
@@ -39,10 +40,9 @@ class AudioProcessor:
sr: int
The sample rate of the audio.
"""
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
*args, **kwargs) -> None:
"""
Initialize the AudioProcessor object.
@@ -56,16 +56,17 @@ class AudioProcessor:
Raises:
ValueError: If the provided sample rate is not of type int.
"""
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
device = kwargs.get(
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.sr = sr
if not isinstance(self.sr, int):
raise ValueError("Sample rate should be a single value of type int," \
raise ValueError("Sample rate should be a single value of type int,"
f"not {len(self.sr)} and type {type(self.sr)}")
@classmethod
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
"""
@@ -77,14 +78,13 @@ class AudioProcessor:
Returns:
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
"""
audio, sr = cls.load_audio(file , *args, **kwargs)
audio, sr = cls.load_audio(file, *args, **kwargs)
audio = torch.from_numpy(audio)
return cls(audio, sr)
def cut(self, start: float, end: float) -> torch.Tensor:
"""
Cut a segment from the audio waveform between the specified start and end times.
@@ -96,7 +96,7 @@ class AudioProcessor:
Returns:
torch.Tensor: The cut waveform segment.
"""
start = int(start * self.sr)
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
end = int(np.ceil(end * self.sr))
@@ -140,11 +140,13 @@ class AudioProcessor:
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
raise RuntimeError(
f"Failed to load audio: {e.stderr.decode()}") from e
out = np.frombuffer(out, np.int16).flatten().astype(
np.float32) / NORMALIZATION_FACTOR
return out, sr
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
return out , sr
def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'