unified documentation
This commit is contained in:
+78
-59
@@ -1,34 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Audio Processor Module
|
||||||
|
=======================
|
||||||
|
|
||||||
|
This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files.
|
||||||
|
It includes functionalities to load, cut, and manage audio waveforms, offering efficient and
|
||||||
|
flexible audio processing.
|
||||||
|
|
||||||
|
Available Classes:
|
||||||
|
- AudioProcessor: Processes audio waveforms and provides methods for loading,
|
||||||
|
cutting, and handling audio.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from .audio_import import AudioProcessor
|
||||||
|
|
||||||
|
processor = AudioProcessor.from_file("path/to/audiofile.wav")
|
||||||
|
cut_waveform = processor.cut(start=1.0, end=5.0)
|
||||||
|
|
||||||
|
Constants:
|
||||||
|
- SAMPLE_RATE (int): Default sample rate for processing.
|
||||||
|
- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from subprocess import CalledProcessError, run
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from subprocess import CalledProcessError, run
|
|
||||||
from typing import Union
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
|
NORMALIZATION_FACTOR = 32768.0
|
||||||
|
|
||||||
class AudioProcessor:
|
class AudioProcessor:
|
||||||
"""
|
"""
|
||||||
Audio Processor using PyTorchaudio instead of PyDub
|
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
||||||
|
for loading, cutting, and handling audio waveforms.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
waveform: torch.Tensor
|
||||||
|
The audio waveform tensor.
|
||||||
|
sr: int
|
||||||
|
The sample rate of the audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
|
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
|
||||||
*args, **kwargs) -> None:
|
*args, **kwargs) -> None:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Initialise audio processor
|
Initialize the AudioProcessor object.
|
||||||
:param waveform: waveform
|
|
||||||
:param sr: sample rate
|
Args:
|
||||||
:param args: additional arguments
|
waveform (torch.Tensor): The audio waveform tensor.
|
||||||
:param kwargs: additional keyword arguments
|
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
|
||||||
example:
|
args: Additional arguments.
|
||||||
- device: device to use for processing
|
kwargs: Additional keyword arguments, e.g., device to use for processing.
|
||||||
if cuda is available, cuda is used
|
If CUDA is available, it defaults to CUDA.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the provided sample rate is not of type int.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "device" in kwargs:
|
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
device = kwargs["device"]
|
|
||||||
else:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = "cuda"
|
|
||||||
else:
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
self.waveform = waveform.to(device)
|
self.waveform = waveform.to(device)
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
@@ -40,9 +69,13 @@ class AudioProcessor:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
||||||
"""
|
"""
|
||||||
Load audio file
|
Create an AudioProcessor instance from an audio file.
|
||||||
:param file: audio file
|
|
||||||
:return: AudioProcessor
|
Args:
|
||||||
|
file (str): The audio file path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio, sr = cls.load_audio(file , *args, **kwargs)
|
audio, sr = cls.load_audio(file , *args, **kwargs)
|
||||||
@@ -54,42 +87,37 @@ class AudioProcessor:
|
|||||||
|
|
||||||
def cut(self, start: float, end: float) -> torch.Tensor:
|
def cut(self, start: float, end: float) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Cut audio file
|
Cut a segment from the audio waveform between the specified start and end times.
|
||||||
:param start: start time in seconds
|
|
||||||
:param end: end time in seconds
|
Args:
|
||||||
:return: AudioProcessor
|
start (float): Start time in seconds.
|
||||||
|
end (float): End time in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The cut waveform segment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(start, float):
|
start = int(start * self.sr)
|
||||||
start = torch.Tensor([start])
|
end = int(torch.ceil(end * self.sr))
|
||||||
if isinstance(end, float):
|
return self.waveform[start:end]
|
||||||
end = torch.Tensor([end])
|
|
||||||
|
|
||||||
sr = torch.Tensor([self.sr])
|
|
||||||
|
|
||||||
start = int(start * sr)
|
|
||||||
end = torch.ceil(end * sr)
|
|
||||||
|
|
||||||
return self.waveform[start:end.to(int)]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||||
"""
|
"""
|
||||||
Open an audio file and read as mono waveform, resampling as necessary
|
Open an audio file and read it as a mono waveform, resampling if necessary.
|
||||||
|
This method ensures compatibility with pyannote.audio
|
||||||
|
and requires the ffmpeg CLI in PATH.
|
||||||
|
|
||||||
Changed from original function at whisper.audio.load_audio to ensure
|
Args:
|
||||||
compatibility with pyannote.audio
|
file (str): The audio file to open.
|
||||||
Parameters
|
sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE.
|
||||||
----------
|
|
||||||
file: str
|
|
||||||
The audio file to open
|
|
||||||
|
|
||||||
sr: int
|
Returns:
|
||||||
The sample rate to resample the audio if necessary
|
tuple: A NumPy array containing the audio waveform in float32 dtype
|
||||||
|
and the sample rate.
|
||||||
|
|
||||||
Returns
|
Raises:
|
||||||
-------
|
RuntimeError: If failed to load audio.
|
||||||
A NumPy array containing the audio waveform, in float32 dtype.
|
|
||||||
"""
|
"""
|
||||||
# This launches a subprocess to decode audio while down-mixing
|
# This launches a subprocess to decode audio while down-mixing
|
||||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||||
@@ -111,18 +139,9 @@ class AudioProcessor:
|
|||||||
except CalledProcessError as e:
|
except CalledProcessError as e:
|
||||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||||
|
|
||||||
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
|
||||||
|
|
||||||
return out , sr
|
return out , sr
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
print("Testing AudioProcessor")
|
|
||||||
print(AudioProcessor.from_file("tests/test.wav"))
|
|
||||||
@@ -1,7 +1,32 @@
|
|||||||
"""
|
"""
|
||||||
Diarisation class.
|
Diarisation Class
|
||||||
This class is used to diarize an audio file using a pretrained model
|
=================
|
||||||
|
|
||||||
|
This class serves as the heart of the speaker diarization system, responsible for identifying
|
||||||
|
and segmenting individual speakers from a given audio file. It leverages a pretrained model
|
||||||
|
from pyannote.audio, providing an accessible interface for audio processing tasks such as
|
||||||
|
speaker separation, and timestamping.
|
||||||
|
|
||||||
|
By encapsulating the complexities of the underlying model, it allows for straightforward
|
||||||
|
integration into various applications, ranging from transcription services to voice assistants.
|
||||||
|
|
||||||
|
Available Classes:
|
||||||
|
- Diariser: Main class for performing speaker diarization.
|
||||||
|
Includes methods for loading models, processing audio files,
|
||||||
|
and formatting the diarization output.
|
||||||
|
|
||||||
|
Constants:
|
||||||
|
- TOKEN_PATH (str): Path to the Pyannote token.
|
||||||
|
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
|
||||||
|
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from .diarisation import Diariser
|
||||||
|
|
||||||
|
model = Diariser.load_model(model="path/to/model/config.yaml")
|
||||||
|
diarisation_output = model.diarization("path/to/audiofile.wav")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar, Union
|
from typing import TypeVar, Union
|
||||||
@@ -10,7 +35,7 @@ from pyannote.audio import Pipeline
|
|||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
@@ -18,11 +43,13 @@ TOKEN_PATH = os.path.join(os.path.dirname(
|
|||||||
|
|
||||||
class Diariser:
|
class Diariser:
|
||||||
"""
|
"""
|
||||||
Diarisation class
|
Handles the diarization process of an audio file using a pretrained model
|
||||||
This class is used to diarize an audio file using a pretrained model
|
from pyannote.audio. Diarization is the task of determining "who spoke when."
|
||||||
from pyannote.audio.
|
|
||||||
:param model: model to use for diarization
|
Args:
|
||||||
|
model: The pretrained model to use for diarization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model) -> None:
|
def __init__(self, model) -> None:
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -30,11 +57,20 @@ class Diariser:
|
|||||||
def diarization(self, audiofile : Union[str, Tensor, dict] ,
|
def diarization(self, audiofile : Union[str, Tensor, dict] ,
|
||||||
*args, **kwargs) -> Annotation:
|
*args, **kwargs) -> Annotation:
|
||||||
"""
|
"""
|
||||||
Diarization of audio file
|
Perform speaker diarization on the provided audio file,
|
||||||
:param audiofile: path to audio file or torch.Tensor
|
effectively separating different speakers
|
||||||
:param args: args for diarization model
|
and providing a timestamp for each segment.
|
||||||
:param kwargs: kwargs for diarization model
|
|
||||||
:return: diarization
|
Args:
|
||||||
|
audiofile: The path to the audio file or a torch.Tensor
|
||||||
|
containing the audio data.
|
||||||
|
args: Additional arguments for the diarization model.
|
||||||
|
kwargs: Additional keyword arguments for the diarization model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing speaker names,
|
||||||
|
segments, and other information related
|
||||||
|
to the diarization process.
|
||||||
"""
|
"""
|
||||||
kwargs = self._get_diarisation_kwargs(**kwargs)
|
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||||
|
|
||||||
@@ -47,10 +83,14 @@ class Diariser:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def format_diarization_output(dia : Annotation) -> dict:
|
def format_diarization_output(dia : Annotation) -> dict:
|
||||||
"""
|
"""
|
||||||
Format diarization output to a list of tuples
|
Formats the raw diarization output into a more usable structure for this project.
|
||||||
:param dia: diarization output
|
|
||||||
:return: dict with speaker names as keys and list of tuples
|
Args:
|
||||||
as values and list of different speakers
|
dia: Raw diarization output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A structured representation of the diarization, with speaker names
|
||||||
|
as keys and a list of tuples representing segments as values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dia_list = list(dia.itertracks(yield_label=True))
|
dia_list = list(dia.itertracks(yield_label=True))
|
||||||
@@ -103,10 +143,14 @@ class Diariser:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_token():
|
def _get_token():
|
||||||
"""
|
"""
|
||||||
Get token from .pyannotetoken.txt
|
Retrieves the Huggingface token from a local file. This token is required
|
||||||
:raises ValueError: No token found
|
for accessing certain online resources.
|
||||||
:return: Huggingface token
|
|
||||||
:rtype: str
|
Raises:
|
||||||
|
ValueError: If the token is not found.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The Huggingface token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if os.path.exists(TOKEN_PATH):
|
if os.path.exists(TOKEN_PATH):
|
||||||
@@ -121,12 +165,13 @@ class Diariser:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_token(token):
|
def _save_token(token):
|
||||||
"""
|
"""
|
||||||
Save token to .pyannotetoken.txt
|
Saves the provided Huggingface token to a local file. This facilitates future
|
||||||
|
access to online resources without needing to repeatedly authenticate.
|
||||||
|
|
||||||
:param token: Huggingface token
|
Args:
|
||||||
:type token: str
|
token: The Huggingface token to save.
|
||||||
"""
|
"""
|
||||||
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
|
||||||
file.write(token)
|
file.write(token)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -137,22 +182,21 @@ class Diariser:
|
|||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None
|
hparams_file: Union[str, Path] = None
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""
|
|
||||||
Load modules from pyannote
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model : str
|
|
||||||
pyannote model
|
|
||||||
default: /models/pyannote/speaker_diarization/config.yaml
|
|
||||||
token : str
|
|
||||||
HUGGINGFACE_TOKEN
|
|
||||||
local : bool
|
|
||||||
If true, load from local cache
|
|
||||||
|
|
||||||
Returns
|
"""
|
||||||
-------
|
Loads a pretrained model from pyannote.audio,
|
||||||
Pipeline Object
|
either from a local cache or online repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Path or identifier for the pyannote model.
|
||||||
|
default: /models/pyannote/speaker_diarization/config.yaml
|
||||||
|
token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
||||||
|
cache_token: Whether to cache the token locally for future use.
|
||||||
|
cache_dir: Directory for caching models.
|
||||||
|
hparams_file: Path to a YAML file containing hyperparameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if cache_token and token is not None:
|
if cache_token and token is not None:
|
||||||
@@ -161,38 +205,35 @@ class Diariser:
|
|||||||
if not os.path.exists(model) and token is None:
|
if not os.path.exists(model) and token is None:
|
||||||
token = cls._get_token()
|
token = cls._get_token()
|
||||||
model = 'pyannote/speaker-diarization'
|
model = 'pyannote/speaker-diarization'
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token = token,
|
use_auth_token = token,
|
||||||
cache_dir = cache_dir,
|
cache_dir = cache_dir,
|
||||||
hparams_file = hparams_file,)
|
hparams_file = hparams_file,)
|
||||||
|
|
||||||
if model is None:
|
if _model is None:
|
||||||
raise ValueError('Unable to load model either from local cache' \
|
raise ValueError('Unable to load model either from local cache' \
|
||||||
'or from huggingface.co models. Please check your token' \
|
'or from huggingface.co models. Please check your token' \
|
||||||
'or your local model path')
|
'or your local model path')
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_diarisation_kwargs(**kwargs) -> dict:
|
def _get_diarisation_kwargs(**kwargs) -> dict:
|
||||||
"""
|
"""
|
||||||
Get kwargs for pyannote diarization model
|
Validates and extracts the keyword arguments for the pyannote diarization model.
|
||||||
Ensure that kwargs are valid
|
|
||||||
:return: kwargs for pyannote diarization model
|
Ensures that the provided keyword arguments match the expected parameters,
|
||||||
:rtype: dict
|
filtering out any invalid or unnecessary arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary containing the validated keyword arguments.
|
||||||
"""
|
"""
|
||||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||||
|
|
||||||
diarisation_kwargs = dict()
|
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|
||||||
for k in kwargs.keys():
|
|
||||||
if k in _possible_kwargs:
|
|
||||||
diarisation_kwargs[k] = kwargs[k]
|
|
||||||
|
|
||||||
return diarisation_kwargs
|
return diarisation_kwargs
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Diarisation(model={self.model})"
|
return f"Diarisation(model={self.model})"
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return f"Diarisation(model={self.model})"
|
|
||||||
|
|||||||
Reference in New Issue
Block a user