added compatibility with torchaudio

This commit is contained in:
Jaikinator
2023-06-13 08:25:58 +02:00
parent a5693490df
commit 157851f8fa
+68 -57
View File
@@ -1,109 +1,108 @@
from typing import Any, Union
from pydub import AudioSegment
import torch
from torchaudio import load, save
import os
from warn import warn
from warnings import warn
import torch
from pydub import AudioSegment
from torchaudio import load, save
class AudioProcessor:
    """
    Audio helper built on pydub.

    Decodes a file into a pydub ``AudioSegment`` and exposes it as a torch
    tensor plus frame rate, with helpers to re-export the audio in another
    format or at a different playback speed.
    """

    def __init__(self, audio_file: str):
        """
        Load ``audio_file`` with pydub and derive the path-related attributes.

        :param audio_file: path to the audio (or video) file to decode
        """
        # splitext/basename handle extensions of any length; slicing off the
        # last four characters only worked for three-letter extensions.
        stem, ext = os.path.splitext(os.path.basename(audio_file))
        self.audio_file_type = ext.lstrip('.')
        # An empty format lets pydub/ffmpeg detect the container on its own.
        self.audio = AudioSegment.from_file(
            audio_file, format=self.audio_file_type or None)
        # Alias for code that still refers to the segment by its earlier name.
        self.audio_file = self.audio
        self.audio_file_path = audio_file
        # Evaluate the property once: every access redoes the conversion.
        self.waveform, self.sr = self.pydub_to_tensor
        self.audiofilename = stem
        self.coreaudiofile = stem
        self.audiofilefolder = os.path.dirname(audio_file)

    @property
    def pydub_to_tensor(self):
        """
        Convert the pydub audio segment into a torch tensor.

        The samples are reshaped to ``[n_frames, channels]`` and divided by
        ``2 ** (8 * sample_width - 1)``, i.e. scaled by the full-scale value
        of the segment's sample width.

        :return: tuple ``(waveform, sample_rate)``
        """
        audio = self.audio
        samples = torch.Tensor(audio.get_array_of_samples()).reshape(
            (-1, audio.channels))
        full_scale = 1 << (8 * audio.sample_width - 1)
        return samples / full_scale, audio.frame_rate

    def convert_audio(self, path: str, remove_orginal: bool = False,
                      *args, **kwargs) -> None:
        """
        Export the loaded audio to ``path``, optionally in another format.

        Can be used to ensure that the audio file is in the correct format
        for the Whisper model.

        :param path: path to save the file to
        :param remove_orginal: delete the previously backing file after export
        :param args: positional arguments for pydub.AudioSegment.export
        :param kwargs: keyword arguments for pydub.AudioSegment.export,
            e.g. format
        :return: None
        """
        self.audio.export(path, *args, **kwargs)
        # Never delete the file that was just written: exporting over the
        # current file and then removing "the original" would lose the audio.
        same_file = (os.path.abspath(path)
                     == os.path.abspath(self.audio_file_path))
        if remove_orginal and not same_file:
            os.remove(self.audio_file_path)
            print(f'File {self.audio_file_path} removed')
        self.audio_file_path = path

    def to_mp3(self, *args, **kwargs) -> None:
        """
        Convert the audio file to an mp3 file (deprecated).

        Forwards to :meth:`convert_audio` with ``format`` forced to ``"mp3"``
        unless the given format already names mp3.

        :param args: positional arguments for convert_audio
        :param kwargs: keyword arguments for convert_audio
        :return: None
        """
        # warnings.warn takes the message first and the category second.
        warn("This function is deprecated, please use convert_audio instead",
             DeprecationWarning, stacklevel=2)
        # .get() so that omitting ``format`` no longer raises KeyError.
        if "mp3" not in (kwargs.get("format") or ""):
            kwargs["format"] = "mp3"
        self.convert_audio(*args, **kwargs)

    def to_wav(self, *args, **kwargs) -> None:
        """
        Convert the audio file to a wav file (deprecated).

        Forwards to :meth:`convert_audio` with ``format`` forced to ``"wav"``
        unless the given format already names wav.

        :param args: positional arguments for convert_audio
        :param kwargs: keyword arguments for convert_audio
        :return: None
        """
        warn("This function is deprecated, please use convert_audio instead",
             DeprecationWarning, stacklevel=2)
        if "wav" not in (kwargs.get("format") or ""):
            kwargs["format"] = "wav"
        self.convert_audio(*args, **kwargs)

    def slower_mp3(self, path: str,
                   speed: float = 0.75,
                   type: str = "mp3") -> "AudioSegment":
        """
        Export a copy of the audio played back at a different speed.

        The raw samples are reused unchanged and only the frame rate is
        overridden with ``int(frame_rate * speed)``.

        :param path: path to save the slowed file to
        :param speed: factor applied to the frame rate
        :param type: export format passed to pydub
        :return: the re-spawned audio segment that was exported
        """
        sound = self.audio
        slow_sound = sound._spawn(sound.raw_data, overrides={
            "frame_rate": int(sound.frame_rate * speed)
        })
        slow_sound.export(path, format=type)
        return slow_sound
class TorchAudioProcessor:
"""
Audio Processor using PyTorchaudio instead of PyDub
@@ -137,6 +136,19 @@ class TorchAudioProcessor:
return cls(audio, sr)
@classmethod
def from_ffmpeg(cls, file: str, *args, **kwargs) -> 'TorchAudioProcessor':
    """
    Alternate constructor that decodes the file through pydub.

    The file is read with AudioProcessor (pydub, which uses ffmpeg) instead
    of SoX (which is used by torchaudio), and the resulting waveform and
    sample rate are handed to the regular constructor.

    :param file: audio file
    :param args: accepted but not used here
    :param kwargs: accepted but not used here
    :return: TorchAudioProcessor
    """
    # NOTE(review): AudioProcessor.pydub_to_tensor reshapes the samples to
    # (-1, channels); confirm this layout is what this class expects.
    decoded = AudioProcessor(file)
    waveform, sample_rate = decoded.waveform, decoded.sr
    return cls(waveform, sample_rate)
def cut(self, start: float, end: float) -> torch.Tensor:
"""
Cut audio file
@@ -156,13 +168,12 @@ class TorchAudioProcessor:
:return: None
"""
if "format" not in kwargs:
kwargs["format"] = file.split('.')[-1]
kwargs["format"] = path.split('.')[-1]
save(file, self.waveform, self.sr, *args, **kwargs)
save(path, self.waveform, self.sr, *args, **kwargs)
def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
def __str__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'