unifyed documentation

This commit is contained in:
Jaikinator
2023-08-23 13:17:13 +02:00
parent a21bc32f7d
commit d2c57866df
2 changed files with 173 additions and 113 deletions
+78 -59
View File
@@ -1,34 +1,63 @@
"""
Audio Processor Module
=======================
This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files.
It includes functionalities to load, cut, and manage audio waveforms, offering efficient and
flexible audio processing.
Available Classes:
- AudioProcessor: Processes audio waveforms and provides methods for loading,
cutting, and handling audio.
Usage:
from .audio_import AudioProcessor
processor = AudioProcessor.from_file("path/to/audiofile.wav")
cut_waveform = processor.cut(start=1.0, end=5.0)
Constants:
- SAMPLE_RATE (int): Default sample rate for processing.
- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform.
"""
from subprocess import CalledProcessError, run
import numpy as np import numpy as np
import torch import torch
from subprocess import CalledProcessError, run
from typing import Union
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768.0
class AudioProcessor: class AudioProcessor:
""" """
Audio Processor using PyTorchaudio instead of PyDub Audio Processor class that leverages PyTorchaudio to provide functionalities
for loading, cutting, and handling audio waveforms.
Attributes:
waveform: torch.Tensor
The audio waveform tensor.
sr: int
The sample rate of the audio.
""" """
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
*args, **kwargs) -> None: *args, **kwargs) -> None:
""" """
Initialise audio processor Initialize the AudioProcessor object.
:param waveform: waveform
:param sr: sample rate Args:
:param args: additional arguments waveform (torch.Tensor): The audio waveform tensor.
:param kwargs: additional keyword arguments sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
example: args: Additional arguments.
- device: device to use for processing kwargs: Additional keyword arguments, e.g., device to use for processing.
if cuda is available, cuda is used If CUDA is available, it defaults to CUDA.
Raises:
ValueError: If the provided sample rate is not of type int.
""" """
if "device" in kwargs: device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
device = kwargs["device"]
else:
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
self.waveform = waveform.to(device) self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
@@ -40,9 +69,13 @@ class AudioProcessor:
@classmethod @classmethod
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor': def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
""" """
Load audio file Create an AudioProcessor instance from an audio file.
:param file: audio file
:return: AudioProcessor Args:
file (str): The audio file path.
Returns:
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
""" """
audio, sr = cls.load_audio(file , *args, **kwargs) audio, sr = cls.load_audio(file , *args, **kwargs)
@@ -54,42 +87,37 @@ class AudioProcessor:
def cut(self, start: float, end: float) -> torch.Tensor: def cut(self, start: float, end: float) -> torch.Tensor:
""" """
Cut audio file Cut a segment from the audio waveform between the specified start and end times.
:param start: start time in seconds
:param end: end time in seconds Args:
:return: AudioProcessor start (float): Start time in seconds.
end (float): End time in seconds.
Returns:
torch.Tensor: The cut waveform segment.
""" """
if isinstance(start, float): start = int(start * self.sr)
start = torch.Tensor([start]) end = int(torch.ceil(end * self.sr))
if isinstance(end, float): return self.waveform[start:end]
end = torch.Tensor([end])
sr = torch.Tensor([self.sr])
start = int(start * sr)
end = torch.ceil(end * sr)
return self.waveform[start:end.to(int)]
@staticmethod @staticmethod
def load_audio(file: str, sr: int = SAMPLE_RATE): def load_audio(file: str, sr: int = SAMPLE_RATE):
""" """
Open an audio file and read as mono waveform, resampling as necessary Open an audio file and read it as a mono waveform, resampling if necessary.
This method ensures compatibility with pyannote.audio
and requires the ffmpeg CLI in PATH.
Changed from original function at whisper.audio.load_audio to ensure Args:
compatibility with pyannote.audio file (str): The audio file to open.
Parameters sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE.
----------
file: str
The audio file to open
sr: int Returns:
The sample rate to resample the audio if necessary tuple: A NumPy array containing the audio waveform in float32 dtype
and the sample rate.
Returns Raises:
------- RuntimeError: If failed to load audio.
A NumPy array containing the audio waveform, in float32 dtype.
""" """
# This launches a subprocess to decode audio while down-mixing # This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH. # and resampling as necessary. Requires the ffmpeg CLI in PATH.
@@ -111,18 +139,9 @@ class AudioProcessor:
except CalledProcessError as e: except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
return out , sr return out , sr
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
def __str__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
if __name__ == "__main__":
print("Testing AudioProcessor")
print(AudioProcessor.from_file("tests/test.wav"))
+95 -54
View File
@@ -1,7 +1,32 @@
""" """
Diarisation class. Diarisation Class
This class is used to diarize an audio file using a pretrained model =================
This class serves as the heart of the speaker diarization system, responsible for identifying
and segmenting individual speakers from a given audio file. It leverages a pretrained model
from pyannote.audio, providing an accessible interface for audio processing tasks such as
speaker separation, and timestamping.
By encapsulating the complexities of the underlying model, it allows for straightforward
integration into various applications, ranging from transcription services to voice assistants.
Available Classes:
- Diariser: Main class for performing speaker diarization.
Includes methods for loading models, processing audio files,
and formatting the diarization output.
Constants:
- TOKEN_PATH (str): Path to the Pyannote token.
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
Usage:
from .diarisation import Diariser
model = Diariser.load_model(model="path/to/model/config.yaml")
diarisation_output = model.diarization("path/to/audiofile.wav")
""" """
import os import os
from pathlib import Path from pathlib import Path
from typing import TypeVar, Union from typing import TypeVar, Union
@@ -10,7 +35,7 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
@@ -18,11 +43,13 @@ TOKEN_PATH = os.path.join(os.path.dirname(
class Diariser: class Diariser:
""" """
Diarisation class Handles the diarization process of an audio file using a pretrained model
This class is used to diarize an audio file using a pretrained model from pyannote.audio. Diarization is the task of determining "who spoke when."
from pyannote.audio.
:param model: model to use for diarization Args:
model: The pretrained model to use for diarization.
""" """
def __init__(self, model) -> None: def __init__(self, model) -> None:
self.model = model self.model = model
@@ -30,11 +57,20 @@ class Diariser:
def diarization(self, audiofile : Union[str, Tensor, dict] , def diarization(self, audiofile : Union[str, Tensor, dict] ,
*args, **kwargs) -> Annotation: *args, **kwargs) -> Annotation:
""" """
Diarization of audio file Perform speaker diarization on the provided audio file,
:param audiofile: path to audio file or torch.Tensor effectively separating different speakers
:param args: args for diarization model and providing a timestamp for each segment.
:param kwargs: kwargs for diarization model
:return: diarization Args:
audiofile: The path to the audio file or a torch.Tensor
containing the audio data.
args: Additional arguments for the diarization model.
kwargs: Additional keyword arguments for the diarization model.
Returns:
dict: A dictionary containing speaker names,
segments, and other information related
to the diarization process.
""" """
kwargs = self._get_diarisation_kwargs(**kwargs) kwargs = self._get_diarisation_kwargs(**kwargs)
@@ -47,10 +83,14 @@ class Diariser:
@staticmethod @staticmethod
def format_diarization_output(dia : Annotation) -> dict: def format_diarization_output(dia : Annotation) -> dict:
""" """
Format diarization output to a list of tuples Formats the raw diarization output into a more usable structure for this project.
:param dia: diarization output
:return: dict with speaker names as keys and list of tuples Args:
as values and list of different speakers dia: Raw diarization output.
Returns:
dict: A structured representation of the diarization, with speaker names
as keys and a list of tuples representing segments as values.
""" """
dia_list = list(dia.itertracks(yield_label=True)) dia_list = list(dia.itertracks(yield_label=True))
@@ -103,10 +143,14 @@ class Diariser:
@staticmethod @staticmethod
def _get_token(): def _get_token():
""" """
Get token from .pyannotetoken.txt Retrieves the Huggingface token from a local file. This token is required
:raises ValueError: No token found for accessing certain online resources.
:return: Huggingface token
:rtype: str Raises:
ValueError: If the token is not found.
Returns:
str: The Huggingface token.
""" """
if os.path.exists(TOKEN_PATH): if os.path.exists(TOKEN_PATH):
@@ -121,12 +165,13 @@ class Diariser:
@staticmethod @staticmethod
def _save_token(token): def _save_token(token):
""" """
Save token to .pyannotetoken.txt Saves the provided Huggingface token to a local file. This facilitates future
access to online resources without needing to repeatedly authenticate.
:param token: Huggingface token Args:
:type token: str token: The Huggingface token to save.
""" """
with open(TOKEN_PATH, 'r', encoding="utf-8") as file: with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
file.write(token) file.write(token)
@classmethod @classmethod
@@ -137,22 +182,21 @@ class Diariser:
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None hparams_file: Union[str, Path] = None
) -> Pipeline: ) -> Pipeline:
"""
Load modules from pyannote
Parameters
----------
model : str
pyannote model
default: /models/pyannote/speaker_diarization/config.yaml
token : str
HUGGINGFACE_TOKEN
local : bool
If true, load from local cache
Returns """
------- Loads a pretrained model from pyannote.audio,
Pipeline Object either from a local cache or online repository.
Args:
model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml
token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters.
Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
""" """
if cache_token and token is not None: if cache_token and token is not None:
@@ -161,38 +205,35 @@ class Diariser:
if not os.path.exists(model) and token is None: if not os.path.exists(model) and token is None:
token = cls._get_token() token = cls._get_token()
model = 'pyannote/speaker-diarization' model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token = token, use_auth_token = token,
cache_dir = cache_dir, cache_dir = cache_dir,
hparams_file = hparams_file,) hparams_file = hparams_file,)
if model is None: if _model is None:
raise ValueError('Unable to load model either from local cache' \ raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \ 'or from huggingface.co models. Please check your token' \
'or your local model path') 'or your local model path')
return cls(_model) return cls(_model)
@staticmethod @staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict: def _get_diarisation_kwargs(**kwargs) -> dict:
""" """
Get kwargs for pyannote diarization model Validates and extracts the keyword arguments for the pyannote diarization model.
Ensure that kwargs are valid
:return: kwargs for pyannote diarization model Ensures that the provided keyword arguments match the expected parameters,
:rtype: dict filtering out any invalid or unnecessary arguments.
Returns:
dict: A dictionary containing the validated keyword arguments.
""" """
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
diarisation_kwargs = dict() diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
for k in kwargs.keys():
if k in _possible_kwargs:
diarisation_kwargs[k] = kwargs[k]
return diarisation_kwargs return diarisation_kwargs
def __repr__(self): def __repr__(self):
return f"Diarisation(model={self.model})" return f"Diarisation(model={self.model})"
def __str__(self):
return f"Diarisation(model={self.model})"