added pytorch audio support
This commit is contained in:
@@ -1,9 +1,13 @@
|
|||||||
from typing import Union
|
from typing import Any, Union
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
import torch
|
||||||
|
from torchaudio import load, save
|
||||||
import os
|
import os
|
||||||
|
from warn import warn
|
||||||
|
|
||||||
class AudioProcessor:
|
class AudioProcessor:
|
||||||
def __init__(self, audio_file:str):
|
def __init__(self, audio_file:str):
|
||||||
|
|
||||||
self.audio_file_path = audio_file
|
self.audio_file_path = audio_file
|
||||||
self.audio_file = AudioSegment.from_file(audio_file, format=audio_file.split('.')[-1])
|
self.audio_file = AudioSegment.from_file(audio_file, format=audio_file.split('.')[-1])
|
||||||
|
|
||||||
@@ -12,15 +16,14 @@ class AudioProcessor:
|
|||||||
self.audiofilefolder = os.path.dirname(audio_file)
|
self.audiofilefolder = os.path.dirname(audio_file)
|
||||||
self.audio_file_type = audio_file.split('.')[-1]
|
self.audio_file_type = audio_file.split('.')[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def save(self, path: str, remove_orginal: bool = True , *args, **kwargs) -> None:
|
||||||
def convert_audio(self, savefolder: str = "", savename: str = "", type: str = "wav", remove_orginal: bool = True):
|
|
||||||
"""
|
"""
|
||||||
Convert video file or other audio files to mp3 file, ensures that the audio file is in the correct format for the
|
Convert and saves video file or other audio files to a different file type,
|
||||||
Whisper model
|
Can be used to ensure that the audio file is in the correct format for the Whisper model
|
||||||
:param file: path to audio or video file
|
:param path : path to save file
|
||||||
:param remove_orginal: remove original file
|
:param remove_orginal: remove original file
|
||||||
:return: mp3 file path
|
:return: mp3 file path
|
||||||
"""
|
"""
|
||||||
print(f'Converting {self.audiofilename} to .{type} file')
|
print(f'Converting {self.audiofilename} to .{type} file')
|
||||||
|
|
||||||
@@ -36,16 +39,11 @@ class AudioProcessor:
|
|||||||
|
|
||||||
self.audio_file.export(savepath, format=type)
|
self.audio_file.export(savepath, format=type)
|
||||||
|
|
||||||
print(f'Converted {self.audiofilename} to {type}')
|
|
||||||
|
|
||||||
if remove_orginal:
|
if remove_orginal:
|
||||||
os.remove(self.audio_file_path)
|
os.remove(self.audio_file_path)
|
||||||
print(f'File {self.audio_file_path} removed')
|
print(f'File {self.audio_file_path} removed')
|
||||||
|
|
||||||
self.audio_file_path = savepath
|
|
||||||
self.audio_file = AudioSegment.from_file(savepath, format=type)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def to_mp3(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True):
|
def to_mp3(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True):
|
||||||
"""
|
"""
|
||||||
@@ -54,18 +52,29 @@ class AudioProcessor:
|
|||||||
:param remove_orginal: remove original file
|
:param remove_orginal: remove original file
|
||||||
:return: mp3 file path
|
:return: mp3 file path
|
||||||
"""
|
"""
|
||||||
return self.convert_audio(savefolder = savefolder, savename = savename, type="mp3", remove_orginal=remove_orginal)
|
warn(DeprecationWarning, "This function is deprecated, please use convert_audio instead")
|
||||||
|
return self.convert_audio(savefolder = savefolder,
|
||||||
|
savename = savename,
|
||||||
|
type="mp3",
|
||||||
|
remove_orginal=remove_orginal)
|
||||||
|
|
||||||
def to_wav(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True):
|
def to_wav(self, savefolder: str = "",
|
||||||
|
savename: str = "",
|
||||||
|
remove_orginal: bool = True):
|
||||||
"""
|
"""
|
||||||
Convert audio file to wav file
|
Convert audio file to wav file
|
||||||
:param file: audio file
|
:param file: audio file
|
||||||
:param remove_orginal: remove original file
|
:param remove_orginal: remove original file
|
||||||
:return: wav file path
|
:return: wav file path
|
||||||
"""
|
"""
|
||||||
return self.convert_audio(savefolder = savefolder, savename = savename,type="wav", remove_orginal=remove_orginal)
|
warn(DeprecationWarning, "This function is deprecated, please use convert_audio instead")
|
||||||
|
return self.convert_audio(savefolder = savefolder,
|
||||||
|
savename = savename,type="wav",
|
||||||
|
remove_orginal=remove_orginal)
|
||||||
|
|
||||||
def slower_mp3(self, savefolder: str = "", savename: str = "", speed: float = 0.75, type: str = "mp3"):
|
def slower_mp3(self, savefolder: str = "",
|
||||||
|
speed: float = 0.75,
|
||||||
|
type: str = "mp3"):
|
||||||
"""
|
"""
|
||||||
Slow down mp3 file
|
Slow down mp3 file
|
||||||
:param file: mp3 file
|
:param file: mp3 file
|
||||||
@@ -90,4 +99,70 @@ class AudioProcessor:
|
|||||||
|
|
||||||
slow_sound.export(save_path, format=type)
|
slow_sound.export(save_path, format=type)
|
||||||
|
|
||||||
return slow_sound
|
return slow_sound
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAudioProcessor:
|
||||||
|
"""
|
||||||
|
Audio Processor using PyTorchaudio instead of PyDub
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, waveform: torch.Tensor, sr : torch.Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Initialise audio processor
|
||||||
|
:param waveform: waveform
|
||||||
|
:param sr: sample rate
|
||||||
|
"""
|
||||||
|
self.waveform = waveform
|
||||||
|
self.sr = sr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls, file: str, *args, **kwargs) -> 'TorchAudioProcessor':
|
||||||
|
"""
|
||||||
|
Load audio file
|
||||||
|
:param file: audio file
|
||||||
|
:return: AudioProcessor
|
||||||
|
"""
|
||||||
|
if not os.path.exists(file):
|
||||||
|
raise FileNotFoundError(f'File {file} not found')
|
||||||
|
|
||||||
|
if "format" not in kwargs:
|
||||||
|
kwargs["format"] = file.split('.')[-1]
|
||||||
|
|
||||||
|
audio, sr = load(file , *args, **kwargs)
|
||||||
|
|
||||||
|
return cls(audio, sr)
|
||||||
|
|
||||||
|
def cut(self, start: float, end: float) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Cut audio file
|
||||||
|
:param start: start time in seconds
|
||||||
|
:param end: end time in seconds
|
||||||
|
:return: AudioProcessor
|
||||||
|
"""
|
||||||
|
start = int(start / self.sr)
|
||||||
|
end = torch.ceil(end / self.sr)
|
||||||
|
|
||||||
|
return self.waveform[:, start:end]
|
||||||
|
|
||||||
|
def save(self, path: str, *args, **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Save audio file
|
||||||
|
:param path: path to save file
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
if "format" not in kwargs:
|
||||||
|
kwargs["format"] = file.split('.')[-1]
|
||||||
|
|
||||||
|
save(file, self.waveform, self.sr, *args, **kwargs)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||||
|
|
||||||
Reference in New Issue
Block a user