From d2c57866df503a7aae4d4c5004caae223443bb74 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Wed, 23 Aug 2023 13:17:13 +0200 Subject: [PATCH] unified documentation --- autotranscript/audio.py | 137 +++++++++++++++++-------------- autotranscript/diarisation.py | 149 ++++++++++++++++++++++------------ 2 files changed, 173 insertions(+), 113 deletions(-) diff --git a/autotranscript/audio.py b/autotranscript/audio.py index 7944a73..04feb1d 100644 --- a/autotranscript/audio.py +++ b/autotranscript/audio.py @@ -1,34 +1,63 @@ +""" +Audio Processor Module +======================= + +This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files. +It includes functionalities to load, cut, and manage audio waveforms, offering efficient and +flexible audio processing. + +Available Classes: +- AudioProcessor: Processes audio waveforms and provides methods for loading, + cutting, and handling audio. + +Usage: + from .audio import AudioProcessor + + processor = AudioProcessor.from_file("path/to/audiofile.wav") + cut_waveform = processor.cut(start=1.0, end=5.0) + +Constants: +- SAMPLE_RATE (int): Default sample rate for processing. +- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform. +""" + +from subprocess import CalledProcessError, run import numpy as np import torch -from subprocess import CalledProcessError, run -from typing import Union + SAMPLE_RATE = 16000 +NORMALIZATION_FACTOR = 32768.0 class AudioProcessor: """ - Audio Processor using PyTorchaudio instead of PyDub + Audio Processor class that leverages PyTorchaudio to provide functionalities + for loading, cutting, and handling audio waveforms. + + Attributes: + waveform: torch.Tensor + The audio waveform tensor. + sr: int + The sample rate of the audio. 
""" def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE, *args, **kwargs) -> None: + """ - Initialise audio processor - :param waveform: waveform - :param sr: sample rate - :param args: additional arguments - :param kwargs: additional keyword arguments - example: - - device: device to use for processing - if cuda is available, cuda is used + Initialize the AudioProcessor object. + + Args: + waveform (torch.Tensor): The audio waveform tensor. + sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. + args: Additional arguments. + kwargs: Additional keyword arguments, e.g., device to use for processing. + If CUDA is available, it defaults to CUDA. + + Raises: + ValueError: If the provided sample rate is not of type int. """ - if "device" in kwargs: - device = kwargs["device"] - else: - if torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" + device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") self.waveform = waveform.to(device) self.sr = sr @@ -40,9 +69,13 @@ class AudioProcessor: @classmethod def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor': """ - Load audio file - :param file: audio file - :return: AudioProcessor + Create an AudioProcessor instance from an audio file. + + Args: + file (str): The audio file path. + + Returns: + AudioProcessor: An instance of the AudioProcessor class containing the loaded audio. """ audio, sr = cls.load_audio(file , *args, **kwargs) @@ -54,42 +87,37 @@ class AudioProcessor: def cut(self, start: float, end: float) -> torch.Tensor: """ - Cut audio file - :param start: start time in seconds - :param end: end time in seconds - :return: AudioProcessor + Cut a segment from the audio waveform between the specified start and end times. + + Args: + start (float): Start time in seconds. + end (float): End time in seconds. + + Returns: + torch.Tensor: The cut waveform segment. 
""" - if isinstance(start, float): - start = torch.Tensor([start]) - if isinstance(end, float): - end = torch.Tensor([end]) - - sr = torch.Tensor([self.sr]) - - start = int(start * sr) - end = torch.ceil(end * sr) - - return self.waveform[start:end.to(int)] + start = int(start * self.sr) + end = int(torch.ceil(end * self.sr)) + return self.waveform[start:end] @staticmethod def load_audio(file: str, sr: int = SAMPLE_RATE): """ - Open an audio file and read as mono waveform, resampling as necessary + Open an audio file and read it as a mono waveform, resampling if necessary. + This method ensures compatibility with pyannote.audio + and requires the ffmpeg CLI in PATH. - Changed from original function at whisper.audio.load_audio to ensure - compatibility with pyannote.audio - Parameters - ---------- - file: str - The audio file to open + Args: + file (str): The audio file to open. + sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE. - sr: int - The sample rate to resample the audio if necessary + Returns: + tuple: A NumPy array containing the audio waveform in float32 dtype + and the sample rate. - Returns - ------- - A NumPy array containing the audio waveform, in float32 dtype. + Raises: + RuntimeError: If failed to load audio. """ # This launches a subprocess to decode audio while down-mixing # and resampling as necessary. Requires the ffmpeg CLI in PATH. 
@@ -111,18 +139,9 @@ class AudioProcessor: except CalledProcessError as e: raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e - out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR return out , sr def __repr__(self) -> str: - return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' - - def __str__(self) -> str: - return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' - - -if __name__ == "__main__": - - print("Testing AudioProcessor") - print(AudioProcessor.from_file("tests/test.wav")) \ No newline at end of file + return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' \ No newline at end of file diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index 5359e3e..0770ea9 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -1,7 +1,32 @@ """ -Diarisation class. -This class is used to diarize an audio file using a pretrained model +Diarisation Class +================= + +This class serves as the heart of the speaker diarization system, responsible for identifying +and segmenting individual speakers from a given audio file. It leverages a pretrained model +from pyannote.audio, providing an accessible interface for audio processing tasks such as +speaker separation, and timestamping. + +By encapsulating the complexities of the underlying model, it allows for straightforward +integration into various applications, ranging from transcription services to voice assistants. + +Available Classes: +- Diariser: Main class for performing speaker diarization. + Includes methods for loading models, processing audio files, + and formatting the diarization output. + +Constants: +- TOKEN_PATH (str): Path to the Pyannote token. +- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. 
+- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. + +Usage: + from .diarisation import Diariser + + model = Diariser.load_model(model="path/to/model/config.yaml") + diarisation_output = model.diarization("path/to/audiofile.wav") """ + import os from pathlib import Path from typing import TypeVar, Union @@ -10,7 +35,7 @@ from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor -from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -18,11 +43,13 @@ TOKEN_PATH = os.path.join(os.path.dirname( class Diariser: """ - Diarisation class - This class is used to diarize an audio file using a pretrained model - from pyannote.audio. - :param model: model to use for diarization + Handles the diarization process of an audio file using a pretrained model + from pyannote.audio. Diarization is the task of determining "who spoke when." + + Args: + model: The pretrained model to use for diarization. """ + def __init__(self, model) -> None: self.model = model @@ -30,11 +57,20 @@ class Diariser: def diarization(self, audiofile : Union[str, Tensor, dict] , *args, **kwargs) -> Annotation: """ - Diarization of audio file - :param audiofile: path to audio file or torch.Tensor - :param args: args for diarization model - :param kwargs: kwargs for diarization model - :return: diarization + Perform speaker diarization on the provided audio file, + effectively separating different speakers + and providing a timestamp for each segment. + + Args: + audiofile: The path to the audio file or a torch.Tensor + containing the audio data. + args: Additional arguments for the diarization model. + kwargs: Additional keyword arguments for the diarization model. 
+ + Returns: + dict: A dictionary containing speaker names, + segments, and other information related + to the diarization process. """ kwargs = self._get_diarisation_kwargs(**kwargs) @@ -47,10 +83,14 @@ class Diariser: @staticmethod def format_diarization_output(dia : Annotation) -> dict: """ - Format diarization output to a list of tuples - :param dia: diarization output - :return: dict with speaker names as keys and list of tuples - as values and list of different speakers + Formats the raw diarization output into a more usable structure for this project. + + Args: + dia: Raw diarization output. + + Returns: + dict: A structured representation of the diarization, with speaker names + as keys and a list of tuples representing segments as values. """ dia_list = list(dia.itertracks(yield_label=True)) @@ -103,10 +143,14 @@ class Diariser: @staticmethod def _get_token(): """ - Get token from .pyannotetoken.txt - :raises ValueError: No token found - :return: Huggingface token - :rtype: str + Retrieves the Huggingface token from a local file. This token is required + for accessing certain online resources. + + Raises: + ValueError: If the token is not found. + + Returns: + str: The Huggingface token. """ if os.path.exists(TOKEN_PATH): @@ -121,12 +165,13 @@ class Diariser: @staticmethod def _save_token(token): """ - Save token to .pyannotetoken.txt + Saves the provided Huggingface token to a local file. This facilitates future + access to online resources without needing to repeatedly authenticate. - :param token: Huggingface token - :type token: str + Args: + token: The Huggingface token to save. 
""" - with open(TOKEN_PATH, 'r', encoding="utf-8") as file: + with open(TOKEN_PATH, 'w', encoding="utf-8") as file: file.write(token) @classmethod @@ -137,22 +182,21 @@ class Diariser: cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None ) -> Pipeline: - """ - Load modules from pyannote - - Parameters - ---------- - model : str - pyannote model - default: /models/pyannote/speaker_diarization/config.yaml - token : str - HUGGINGFACE_TOKEN - local : bool - If true, load from local cache - Returns - ------- - Pipeline Object + """ + Loads a pretrained model from pyannote.audio, + either from a local cache or online repository. + + Args: + model: Path or identifier for the pyannote model. + default: /models/pyannote/speaker_diarization/config.yaml + token: Optional HUGGINGFACE_TOKEN for authenticated access. + cache_token: Whether to cache the token locally for future use. + cache_dir: Directory for caching models. + hparams_file: Path to a YAML file containing hyperparameters. + + Returns: + Diariser: A Diariser instance wrapping the loaded pyannote.audio Pipeline. """ if cache_token and token is not None: @@ -161,38 +205,35 @@ class Diariser: if not os.path.exists(model) and token is None: token = cls._get_token() model = 'pyannote/speaker-diarization' - + _model = Pipeline.from_pretrained(model, use_auth_token = token, cache_dir = cache_dir, hparams_file = hparams_file,) - if model is None: + if _model is None: raise ValueError('Unable to load model either from local cache' \ 'or from huggingface.co models. Please check your token' \ 'or your local model path') + return cls(_model) @staticmethod def _get_diarisation_kwargs(**kwargs) -> dict: """ - Get kwargs for pyannote diarization model - Ensure that kwargs are valid - :return: kwargs for pyannote diarization model - :rtype: dict + Validates and extracts the keyword arguments for the pyannote diarization model. 
+ + Ensures that the provided keyword arguments match the expected parameters, + filtering out any invalid or unnecessary arguments. + + Returns: + dict: A dictionary containing the validated keyword arguments. """ _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames - - diarisation_kwargs = dict() - - for k in kwargs.keys(): - if k in _possible_kwargs: - diarisation_kwargs[k] = kwargs[k] + + diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} return diarisation_kwargs def __repr__(self): return f"Diarisation(model={self.model})" - - def __str__(self): - return f"Diarisation(model={self.model})"