added compatibility with torchaudio

This commit is contained in:
Jaikinator
2023-06-13 08:25:58 +02:00
parent a5693490df
commit 157851f8fa
+73 -62
View File
@@ -1,108 +1,107 @@
from typing import Any, Union
from pydub import AudioSegment
import torch
from torchaudio import load, save
import os import os
from warn import warn from warnings import warn
import torch
from pydub import AudioSegment
from torchaudio import load, save
class AudioProcessor:
    """
    Wrapper around a pydub ``AudioSegment`` loaded from a file.

    On construction the file is decoded with pydub and the samples are
    additionally exposed as a normalised torch tensor (``waveform``) together
    with the sample rate (``sr``).
    """

    def __init__(self, audio_file: str):
        """
        Load an audio file with pydub.

        :param audio_file: path to the audio (or video) file to load
        """
        # Take the format from the file extension; when the path has no
        # extension pass None so the decoder can probe the container itself
        # instead of receiving the whole path as a bogus format name.
        extension = os.path.splitext(audio_file)[1].lstrip('.')
        self.audio = AudioSegment.from_file(audio_file,
                                            format=extension or None)
        self.audio_file_path = audio_file
        # Evaluate the property once: each access converts every sample.
        self.waveform, self.sr = self.pydub_to_tensor

    @property
    def pydub_to_tensor(self) -> "tuple":
        """
        Convert the pydub audio segment into a float torch tensor of shape
        [number_of_frames, channels], scaled by the full-scale value of the
        segment's sample width so that values lie in [-1.0, 1.0).

        NOTE(review): torchaudio.load returns [channels, frames]; this layout
        is the transpose of that - confirm what consumers expect.

        :return: tuple (waveform_tensor, sample_rate)
        """
        audio = self.audio
        samples = torch.Tensor(audio.get_array_of_samples())
        # Explicit frame count instead of -1: reshape(-1, n) is ambiguous
        # (and raises) for an empty segment.
        frames = samples.numel() // audio.channels
        samples = samples.reshape((frames, audio.channels))
        # Largest magnitude representable with this sample width,
        # e.g. 32768 for 16-bit audio.
        full_scale = 1 << (8 * audio.sample_width - 1)
        return samples / full_scale, audio.frame_rate

    def convert_audio(self, path: str, remove_orginal: bool = False,
                      *args, **kwargs) -> None:
        """
        Convert and save the loaded audio to a different file type.
        Can be used to ensure that the audio file is in the correct format
        for the Whisper model.

        :param path: path to save file
        :param remove_orginal: remove original file after exporting
        :param args: arguments for pydub.AudioSegment.export
        :param kwargs: keyword arguments for pydub.AudioSegment.export
                        e.g. format
        :return: None
        """
        self.audio.export(path, *args, **kwargs)

        # Never delete the file that was just written: exporting onto the
        # source path and then removing "the original" would destroy the
        # result.
        if remove_orginal and (os.path.abspath(self.audio_file_path)
                               != os.path.abspath(path)):
            os.remove(self.audio_file_path)
            print(f'File {self.audio_file_path} removed')
            # The old file is gone, so track the exported file from now on.
            self.audio_file_path = path

    def _convert_deprecated(self, target_format: str,
                            args: tuple, kwargs: dict) -> None:
        """Emit the deprecation warning and export in ``target_format``."""
        # warnings.warn takes (message, category); stacklevel=3 attributes
        # the warning to the caller of to_mp3/to_wav.
        warn("This function is deprecated, please use convert_audio instead",
             DeprecationWarning, stacklevel=3)
        kwargs["format"] = target_format
        self.convert_audio(*args, **kwargs)

    def to_mp3(self, *args, **kwargs) -> None:
        """
        Convert audio file to mp3 file.

        Deprecated, use convert_audio instead. The export format is always
        forced to "mp3".

        :param args: positional arguments for convert_audio (path first)
        :param kwargs: keyword arguments for convert_audio
        :return: None
        """
        self._convert_deprecated("mp3", args, kwargs)

    def to_wav(self, *args, **kwargs) -> None:
        """
        Convert audio file to wav file.

        Deprecated, use convert_audio instead. The export format is always
        forced to "wav".

        :param args: positional arguments for convert_audio (path first)
        :param kwargs: keyword arguments for convert_audio
        :return: None
        """
        self._convert_deprecated("wav", args, kwargs)

    def slower_mp3(self, path: str,
                   speed: float = 0.75,
                   type: str = "mp3") -> "AudioSegment":
        """
        Slow down (or speed up) the audio and save it.

        The raw samples are kept and only the frame rate is scaled by
        ``speed``.

        :param path: path to save file
        :param speed: factor applied to the frame rate
        :param type: export format passed to pydub
        :return: the re-timed audio segment
        """
        sound = self.audio
        slow_sound = sound._spawn(sound.raw_data, overrides={
            "frame_rate": int(sound.frame_rate * speed)
        })
        slow_sound.export(path, format=type)
        return slow_sound
class TorchAudioProcessor: class TorchAudioProcessor:
""" """
@@ -136,6 +135,19 @@ class TorchAudioProcessor:
audio, sr = load(file , *args, **kwargs) audio, sr = load(file , *args, **kwargs)
return cls(audio, sr) return cls(audio, sr)
@classmethod
def from_ffmpeg(cls, file: str, *args, **kwargs) -> 'TorchAudioProcessor':
"""
Initialise audio processor using pydub audio segment.
pydub uses ffmpeg instead of SoX (which is used by torchaudio)
:param file: audio file
:return: TorchAudioProcessor
"""
audio = AudioProcessor(file)
return cls(audio.waveform, audio.sr)
def cut(self, start: float, end: float) -> torch.Tensor: def cut(self, start: float, end: float) -> torch.Tensor:
""" """
@@ -156,13 +168,12 @@ class TorchAudioProcessor:
:return: None :return: None
""" """
if "format" not in kwargs: if "format" not in kwargs:
kwargs["format"] = file.split('.')[-1] kwargs["format"] = path.split('.')[-1]
save(file, self.waveform, self.sr, *args, **kwargs) save(path, self.waveform, self.sr, *args, **kwargs)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
def __str__(self) -> str: def __str__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'