unified documentation
This commit is contained in:
@@ -1,7 +1,32 @@
|
||||
"""
|
||||
Diarisation class.
|
||||
This class is used to diarize an audio file using a pretrained model
|
||||
Diarisation Class
|
||||
=================
|
||||
|
||||
This class serves as the heart of the speaker diarization system, responsible for identifying
|
||||
and segmenting individual speakers from a given audio file. It leverages a pretrained model
|
||||
from pyannote.audio, providing an accessible interface for audio processing tasks such as
|
||||
speaker separation and timestamping.
|
||||
|
||||
By encapsulating the complexities of the underlying model, it allows for straightforward
|
||||
integration into various applications, ranging from transcription services to voice assistants.
|
||||
|
||||
Available Classes:
|
||||
- Diariser: Main class for performing speaker diarization.
|
||||
Includes methods for loading models, processing audio files,
|
||||
and formatting the diarization output.
|
||||
|
||||
Constants:
|
||||
- TOKEN_PATH (str): Path to the Pyannote token.
|
||||
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
|
||||
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
|
||||
|
||||
Usage:
|
||||
from .diarisation import Diariser
|
||||
|
||||
model = Diariser.load_model(model="path/to/model/config.yaml")
|
||||
diarisation_output = model.diarization("path/to/audiofile.wav")
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TypeVar, Union
|
||||
@@ -10,7 +35,7 @@ from pyannote.audio import Pipeline
|
||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
|
||||
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
@@ -18,11 +43,13 @@ TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
|
||||
class Diariser:
|
||||
"""
|
||||
Diarisation class
|
||||
This class is used to diarize an audio file using a pretrained model
|
||||
from pyannote.audio.
|
||||
:param model: model to use for diarization
|
||||
Handles the diarization process of an audio file using a pretrained model
|
||||
from pyannote.audio. Diarization is the task of determining "who spoke when."
|
||||
|
||||
Args:
|
||||
model: The pretrained model to use for diarization.
|
||||
"""
|
||||
|
||||
def __init__(self, model) -> None:
|
||||
|
||||
self.model = model
|
||||
@@ -30,11 +57,20 @@ class Diariser:
|
||||
def diarization(self, audiofile : Union[str, Tensor, dict] ,
|
||||
*args, **kwargs) -> Annotation:
|
||||
"""
|
||||
Diarization of audio file
|
||||
:param audiofile: path to audio file or torch.Tensor
|
||||
:param args: args for diarization model
|
||||
:param kwargs: kwargs for diarization model
|
||||
:return: diarization
|
||||
Perform speaker diarization on the provided audio file,
|
||||
effectively separating different speakers
|
||||
and providing a timestamp for each segment.
|
||||
|
||||
Args:
|
||||
audiofile: The path to the audio file, a dict, or a torch.Tensor
|
||||
containing the audio data.
|
||||
args: Additional arguments for the diarization model.
|
||||
kwargs: Additional keyword arguments for the diarization model.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing speaker names,
|
||||
segments, and other information related
|
||||
to the diarization process.
|
||||
"""
|
||||
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||
|
||||
@@ -47,10 +83,14 @@ class Diariser:
|
||||
@staticmethod
|
||||
def format_diarization_output(dia : Annotation) -> dict:
|
||||
"""
|
||||
Format diarization output to a list of tuples
|
||||
:param dia: diarization output
|
||||
:return: dict with speaker names as keys and list of tuples
|
||||
as values and list of different speakers
|
||||
Formats the raw diarization output into a more usable structure for this project.
|
||||
|
||||
Args:
|
||||
dia: Raw diarization output.
|
||||
|
||||
Returns:
|
||||
dict: A structured representation of the diarization, with speaker names
|
||||
as keys and a list of tuples representing segments as values.
|
||||
"""
|
||||
|
||||
dia_list = list(dia.itertracks(yield_label=True))
|
||||
@@ -103,10 +143,14 @@ class Diariser:
|
||||
@staticmethod
|
||||
def _get_token():
|
||||
"""
|
||||
Get token from .pyannotetoken.txt
|
||||
:raises ValueError: No token found
|
||||
:return: Huggingface token
|
||||
:rtype: str
|
||||
Retrieves the Huggingface token from a local file. This token is required
|
||||
for accessing certain online resources.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token is not found.
|
||||
|
||||
Returns:
|
||||
str: The Huggingface token.
|
||||
"""
|
||||
|
||||
if os.path.exists(TOKEN_PATH):
|
||||
@@ -121,12 +165,13 @@ class Diariser:
|
||||
@staticmethod
|
||||
def _save_token(token):
|
||||
"""
|
||||
Save token to .pyannotetoken.txt
|
||||
Saves the provided Huggingface token to a local file. This facilitates future
|
||||
access to online resources without needing to repeatedly authenticate.
|
||||
|
||||
:param token: Huggingface token
|
||||
:type token: str
|
||||
Args:
|
||||
token: The Huggingface token to save.
|
||||
"""
|
||||
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
||||
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
|
||||
file.write(token)
|
||||
|
||||
@classmethod
|
||||
@@ -137,22 +182,21 @@ class Diariser:
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None
|
||||
) -> Pipeline:
|
||||
"""
|
||||
Load modules from pyannote
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : str
|
||||
pyannote model
|
||||
default: /models/pyannote/speaker_diarization/config.yaml
|
||||
token : str
|
||||
HUGGINGFACE_TOKEN
|
||||
local : bool
|
||||
If true, load from local cache
|
||||
|
||||
Returns
|
||||
-------
|
||||
Pipeline Object
|
||||
"""
|
||||
Loads a pretrained model from pyannote.audio,
|
||||
either from a local cache or online repository.
|
||||
|
||||
Args:
|
||||
model: Path or identifier for the pyannote model.
|
||||
default: /models/pyannote/speaker_diarization/config.yaml
|
||||
token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
||||
cache_token: Whether to cache the token locally for future use.
|
||||
cache_dir: Directory for caching models.
|
||||
hparams_file: Path to a YAML file containing hyperparameters.
|
||||
|
||||
Returns:
|
||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||
"""
|
||||
|
||||
if cache_token and token is not None:
|
||||
@@ -161,38 +205,35 @@ class Diariser:
|
||||
if not os.path.exists(model) and token is None:
|
||||
token = cls._get_token()
|
||||
model = 'pyannote/speaker-diarization'
|
||||
|
||||
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token = token,
|
||||
cache_dir = cache_dir,
|
||||
hparams_file = hparams_file,)
|
||||
|
||||
if model is None:
|
||||
if _model is None:
|
||||
raise ValueError('Unable to load model either from local cache' \
|
||||
'or from huggingface.co models. Please check your token' \
|
||||
'or your local model path')
|
||||
|
||||
return cls(_model)
|
||||
|
||||
@staticmethod
|
||||
def _get_diarisation_kwargs(**kwargs) -> dict:
|
||||
"""
|
||||
Get kwargs for pyannote diarization model
|
||||
Ensure that kwargs are valid
|
||||
:return: kwargs for pyannote diarization model
|
||||
:rtype: dict
|
||||
Validates and extracts the keyword arguments for the pyannote diarization model.
|
||||
|
||||
Ensures that the provided keyword arguments match the expected parameters,
|
||||
filtering out any invalid or unnecessary arguments.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the validated keyword arguments.
|
||||
"""
|
||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||
|
||||
diarisation_kwargs = dict()
|
||||
|
||||
for k in kwargs.keys():
|
||||
if k in _possible_kwargs:
|
||||
diarisation_kwargs[k] = kwargs[k]
|
||||
|
||||
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
||||
|
||||
return diarisation_kwargs
|
||||
|
||||
def __repr__(self):
|
||||
return f"Diarisation(model={self.model})"
|
||||
|
||||
def __str__(self):
|
||||
return f"Diarisation(model={self.model})"
|
||||
|
||||
Reference in New Issue
Block a user