added pytorch audio support

This commit is contained in:
Jaikinator
2023-06-12 16:38:19 +02:00
parent 6870d03f6b
commit edbe7ebb1d
+90 -15
View File
@@ -1,9 +1,13 @@
from typing import Union from typing import Any, Union
from pydub import AudioSegment from pydub import AudioSegment
import torch
from torchaudio import load, save
import os import os
from warn import warn
class AudioProcessor: class AudioProcessor:
def __init__(self, audio_file:str): def __init__(self, audio_file:str):
self.audio_file_path = audio_file self.audio_file_path = audio_file
self.audio_file = AudioSegment.from_file(audio_file, format=audio_file.split('.')[-1]) self.audio_file = AudioSegment.from_file(audio_file, format=audio_file.split('.')[-1])
@@ -13,12 +17,11 @@ class AudioProcessor:
self.audio_file_type = audio_file.split('.')[-1] self.audio_file_type = audio_file.split('.')[-1]
def save(self, path: str, remove_orginal: bool = True , *args, **kwargs) -> None:
def convert_audio(self, savefolder: str = "", savename: str = "", type: str = "wav", remove_orginal: bool = True):
""" """
Convert video file or other audio files to mp3 file, ensures that the audio file is in the correct format for the Convert and saves video file or other audio files to a different file type,
Whisper model Can be used to ensure that the audio file is in the correct format for the Whisper model
:param file: path to audio or video file :param path : path to save file
:param remove_orginal: remove original file :param remove_orginal: remove original file
:return: mp3 file path :return: mp3 file path
""" """
@@ -36,16 +39,11 @@ class AudioProcessor:
self.audio_file.export(savepath, format=type) self.audio_file.export(savepath, format=type)
print(f'Converted {self.audiofilename} to {type}')
if remove_orginal: if remove_orginal:
os.remove(self.audio_file_path) os.remove(self.audio_file_path)
print(f'File {self.audio_file_path} removed') print(f'File {self.audio_file_path} removed')
self.audio_file_path = savepath
self.audio_file = AudioSegment.from_file(savepath, format=type)
return self
def to_mp3(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True): def to_mp3(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True):
""" """
@@ -54,18 +52,29 @@ class AudioProcessor:
:param remove_orginal: remove original file :param remove_orginal: remove original file
:return: mp3 file path :return: mp3 file path
""" """
return self.convert_audio(savefolder = savefolder, savename = savename, type="mp3", remove_orginal=remove_orginal) warn(DeprecationWarning, "This function is deprecated, please use convert_audio instead")
return self.convert_audio(savefolder = savefolder,
savename = savename,
type="mp3",
remove_orginal=remove_orginal)
def to_wav(self, savefolder: str = "",
           savename: str = "",
           remove_orginal: bool = True):
    """
    Deprecated shortcut that converts the audio to wav via ``convert_audio``.

    All arguments are forwarded unchanged to ``self.convert_audio`` together
    with ``type="wav"``.
    :param savefolder: forwarded to ``convert_audio`` as ``savefolder``
    :param savename: forwarded to ``convert_audio`` as ``savename``
    :param remove_orginal: remove original file
    :return: whatever ``self.convert_audio`` returns
             (NOTE(review): ``convert_audio`` is not fully visible here - confirm)
    """
    # Local import: the standard-library module is ``warnings``; importing it
    # here keeps this method independent of the file-level ``warn`` import.
    import warnings

    # ``warnings.warn`` takes the message first and the category second;
    # stacklevel=2 attributes the warning to the caller of ``to_wav``.
    warnings.warn("This function is deprecated, please use convert_audio instead",
                  DeprecationWarning,
                  stacklevel=2)
    return self.convert_audio(savefolder=savefolder,
                              savename=savename,
                              type="wav",
                              remove_orginal=remove_orginal)
def slower_mp3(self, savefolder: str = "", savename: str = "", speed: float = 0.75, type: str = "mp3"): def slower_mp3(self, savefolder: str = "",
speed: float = 0.75,
type: str = "mp3"):
""" """
Slow down mp3 file Slow down mp3 file
:param file: mp3 file :param file: mp3 file
@@ -91,3 +100,69 @@ class AudioProcessor:
slow_sound.export(save_path, format=type) slow_sound.export(save_path, format=type)
return slow_sound return slow_sound
class TorchAudioProcessor:
    """
    Audio Processor using PyTorchaudio instead of PyDub
    """

    def __init__(self, waveform: torch.Tensor, sr: int) -> None:
        """
        Initialise audio processor
        :param waveform: audio samples; ``cut`` slices the second dimension,
                         so a (channels, samples) layout is assumed
        :param sr: sample rate in samples per second (anything ``int()`` accepts)
        """
        self.waveform = waveform
        self.sr = sr

    @classmethod
    def from_file(cls, file: str, *args, **kwargs) -> 'TorchAudioProcessor':
        """
        Load audio file
        :param file: audio file
        :param args: extra positional arguments passed on to ``torchaudio.load``
        :param kwargs: extra keyword arguments passed on to ``torchaudio.load``;
                       ``format`` defaults to the file extension when there is one
        :return: TorchAudioProcessor
        :raises FileNotFoundError: if ``file`` does not exist
        """
        if not os.path.exists(file):
            raise FileNotFoundError(f'File {file} not found')

        # splitext only looks at the last path component, so a dot in a
        # directory name or a file without extension yields no bogus format.
        extension = os.path.splitext(file)[1].lstrip('.')
        if extension:
            kwargs.setdefault("format", extension)

        audio, sr = load(file, *args, **kwargs)
        return cls(audio, sr)

    def cut(self, start: float, end: float) -> torch.Tensor:
        """
        Cut audio file
        :param start: start time in seconds
        :param end: end time in seconds
        :return: slice of the waveform between ``start`` and ``end``
        :raises ValueError: if ``start`` is negative or ``end`` lies before ``start``
        """
        import math  # local import keeps the block self-contained

        if start < 0 or end < start:
            raise ValueError(f'Invalid cut range: start={start}, end={end}')

        # seconds * samples-per-second = sample index
        sample_rate = int(self.sr)
        first_sample = int(start * sample_rate)
        # round up so a partially covered last sample is still included
        last_sample = math.ceil(end * sample_rate)
        return self.waveform[:, first_sample:last_sample]

    def save(self, path: str, *args, **kwargs) -> None:
        """
        Save audio file
        :param path: path to save file
        :param args: extra positional arguments passed on to ``torchaudio.save``
        :param kwargs: extra keyword arguments passed on to ``torchaudio.save``;
                       ``format`` defaults to the extension of ``path`` when there is one
        :return: None
        """
        extension = os.path.splitext(path)[1].lstrip('.')
        if extension:
            kwargs.setdefault("format", extension)

        # Inside a method body the bare name ``save`` resolves to the
        # module-level torchaudio function, not to this method.
        save(path, self.waveform, int(self.sr), *args, **kwargs)

    def __repr__(self) -> str:
        # len() of the waveform is the size of its first dimension
        return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'

    def __str__(self) -> str:
        return repr(self)