unified documentation

This commit is contained in:
Jaikinator
2023-08-23 13:17:13 +02:00
parent a21bc32f7d
commit d2c57866df
2 changed files with 173 additions and 113 deletions
+95 -54
View File
@@ -1,7 +1,32 @@
"""
Diarisation class.
This class is used to diarize an audio file using a pretrained model
Diarisation Class
=================
This class serves as the heart of the speaker diarization system, responsible for identifying
and segmenting individual speakers from a given audio file. It leverages a pretrained model
from pyannote.audio, providing an accessible interface for audio processing tasks such as
speaker separation and timestamping.
By encapsulating the complexities of the underlying model, it allows for straightforward
integration into various applications, ranging from transcription services to voice assistants.
Available Classes:
- Diariser: Main class for performing speaker diarization.
Includes methods for loading models, processing audio files,
and formatting the diarization output.
Constants:
- TOKEN_PATH (str): Path to the Pyannote token.
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
Usage:
from .diarisation import Diariser
model = Diariser.load_model(model="path/to/model/config.yaml")
diarisation_output = model.diarization("path/to/audiofile.wav")
"""
import os
from pathlib import Path
from typing import TypeVar, Union
@@ -10,7 +35,7 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname(
@@ -18,11 +43,13 @@ TOKEN_PATH = os.path.join(os.path.dirname(
class Diariser:
"""
Diarisation class
This class is used to diarize an audio file using a pretrained model
from pyannote.audio.
:param model: model to use for diarization
Handles the diarization process of an audio file using a pretrained model
from pyannote.audio. Diarization is the task of determining "who spoke when."
Args:
model: The pretrained model to use for diarization.
"""
def __init__(self, model) -> None:
self.model = model
@@ -30,11 +57,20 @@ class Diariser:
def diarization(self, audiofile : Union[str, Tensor, dict] ,
*args, **kwargs) -> Annotation:
"""
Diarization of audio file
:param audiofile: path to audio file or torch.Tensor
:param args: args for diarization model
:param kwargs: kwargs for diarization model
:return: diarization
Perform speaker diarization on the provided audio file,
effectively separating different speakers
and providing a timestamp for each segment.
Args:
audiofile: The path to the audio file or a torch.Tensor
containing the audio data.
args: Additional arguments for the diarization model.
kwargs: Additional keyword arguments for the diarization model.
Returns:
dict: A dictionary containing speaker names,
segments, and other information related
to the diarization process.
"""
kwargs = self._get_diarisation_kwargs(**kwargs)
@@ -47,10 +83,14 @@ class Diariser:
@staticmethod
def format_diarization_output(dia : Annotation) -> dict:
"""
Format diarization output to a list of tuples
:param dia: diarization output
:return: dict with speaker names as keys and list of tuples
as values and list of different speakers
Formats the raw diarization output into a more usable structure for this project.
Args:
dia: Raw diarization output.
Returns:
dict: A structured representation of the diarization, with speaker names
as keys and a list of tuples representing segments as values.
"""
dia_list = list(dia.itertracks(yield_label=True))
@@ -103,10 +143,14 @@ class Diariser:
@staticmethod
def _get_token():
"""
Get token from .pyannotetoken.txt
:raises ValueError: No token found
:return: Huggingface token
:rtype: str
Retrieves the Huggingface token from a local file. This token is required
for accessing certain online resources.
Raises:
ValueError: If the token is not found.
Returns:
str: The Huggingface token.
"""
if os.path.exists(TOKEN_PATH):
@@ -121,12 +165,13 @@ class Diariser:
@staticmethod
def _save_token(token):
    """
    Save the Huggingface token to the local token file.

    The token text is written to ``TOKEN_PATH``, replacing any previous
    content of that file.

    Args:
        token: The Huggingface token to save.
    """
    # Write mode is required here: a handle opened with 'r' rejects write()
    # and cannot create the file when it does not exist yet.
    with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
        file.write(token)
@classmethod
@@ -137,22 +182,21 @@ class Diariser:
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None
) -> Pipeline:
"""
Load modules from pyannote
Parameters
----------
model : str
pyannote model
default: /models/pyannote/speaker_diarization/config.yaml
token : str
HUGGINGFACE_TOKEN
local : bool
If true, load from local cache
Returns
-------
Pipeline Object
"""
Loads a pretrained model from pyannote.audio,
either from a local cache or online repository.
Args:
model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml
token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters.
Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""
if cache_token and token is not None:
@@ -161,38 +205,35 @@ class Diariser:
if not os.path.exists(model) and token is None:
token = cls._get_token()
model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model,
use_auth_token = token,
cache_dir = cache_dir,
hparams_file = hparams_file,)
if model is None:
if _model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \
'or your local model path')
return cls(_model)
@staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict:
    """
    Filter keyword arguments down to those known to the pyannote pipeline.

    A keyword is kept only if its name appears among the variable names of
    ``SpeakerDiarization.apply``; every other keyword is silently dropped.

    Args:
        kwargs: Arbitrary keyword arguments supplied by the caller.

    Returns:
        dict: The subset of ``kwargs`` whose names were recognised.
    """
    # NOTE(review): co_varnames lists local variables as well as parameters,
    # so this filter may accept names that are not real arguments — confirm.
    _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
    return {k: v for k, v in kwargs.items() if k in _possible_kwargs}
def __repr__(self):
    """Return the debug representation, which embeds the wrapped model."""
    model = self.model
    return f"Diarisation(model={model})"
def __str__(self):
    """Return the human-readable form, built from the wrapped model."""
    wrapped = self.model
    return f"Diarisation(model={wrapped})"