Files
scribe/autotranscript/diarisation.py
T
2023-07-10 13:27:54 +02:00

199 lines
6.3 KiB
Python

"""
Diarisation class.
This class is used to diarize an audio file using a pretrained model
"""
import os
from pathlib import Path
from typing import TypeVar, Union
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
# Generic placeholder used in annotations for the diarisation result object.
Annotation = TypeVar('Annotation')

# Location of the cached Hugging Face token: a file named ``.pyannotetoken``
# in the directory that contains this module (symlinks resolved).
TOKEN_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '.pyannotetoken',
)
class Diariser:
    """
    Speaker diarisation wrapper around a pretrained pyannote.audio pipeline.

    The wrapped model is called on an audio input and its result is
    converted into a plain dict of speaker turns
    (see :meth:`format_diarization_output`).

    :param model: callable diarisation model, e.g. the pipeline created in
        :meth:`load_model`
    """

    def __init__(self, model) -> None:
        self.model = model

    def diarization(self, audiofile: Union[str, Tensor, dict],
                    *args, **kwargs) -> dict:
        """
        Run the diarisation model on an audio input.

        :param audiofile: audio input, passed unchanged to the model
        :param args: extra positional arguments forwarded to the model
        :param kwargs: keyword arguments for the model; names that are not
            parameters of ``SpeakerDiarization.apply`` are dropped
        :return: dict with the keys ``"speakers"`` and ``"segments"``
        """
        kwargs = self._get_diarisation_kwargs(**kwargs)
        diarization = self.model(audiofile, *args, **kwargs)
        return self.format_diarization_output(diarization)

    @staticmethod
    def format_diarization_output(dia: Annotation) -> dict:
        """
        Collapse a diarisation result into speaker turns.

        Consecutive tracks carrying the same speaker label are merged into
        one turn that starts with the first track of the run and ends with
        the last track of the run.

        :param dia: object whose ``itertracks(yield_label=True)`` yields
            ``(segment, track, label)`` triples; each segment must expose
            ``start`` and ``end``
        :return: dict with two parallel lists: ``"speakers"`` holds one
            label per turn and ``"segments"`` the matching ``[start, end]``
            pair; both are empty when ``dia`` has no tracks
        """
        speakers = []
        segments = []
        for segment, _, speaker in dia.itertracks(yield_label=True):
            if speakers and speakers[-1] == speaker:
                # Same speaker as the previous track: extend the open turn
                # so that it ends with the latest track of the run.
                segments[-1][1] = segment.end
            else:
                speakers.append(speaker)
                segments.append([segment.start, segment.end])
        return {"speakers": speakers, "segments": segments}

    @staticmethod
    def _get_token():
        """
        Read the cached Hugging Face token from ``TOKEN_PATH``.

        :raises ValueError: the token file is missing or empty
        :return: Huggingface token without surrounding whitespace
        :rtype: str
        """
        token = ''
        if os.path.exists(TOKEN_PATH):
            with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
                # Editors usually append a newline, which must not end up
                # in the token that is sent for authentication.
                token = file.read().strip()
        if not token:
            raise ValueError(
                'No token found. '
                'Please create a token at '
                'https://huggingface.co/settings/tokens '
                f'and save it in a file called {TOKEN_PATH}')
        return token

    @staticmethod
    def _save_token(token):
        """
        Write the token to ``TOKEN_PATH``, replacing any previous content.

        :param token: Huggingface token
        :type token: str
        """
        # 'w' creates the file if necessary; read mode cannot be written to.
        with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
            file.write(token)

    @classmethod
    def load_model(cls,
                   model: str = PYANNOTE_DEFAULT_CONFIG,
                   token: Union[str, None] = None,
                   cache_token: bool = False,
                   cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
                   hparams_file: Union[str, Path, None] = None
                   ) -> 'Diariser':
        """
        Load a pyannote pipeline and wrap it in a :class:`Diariser`.

        :param model: path of a local pipeline configuration or name of a
            hosted pipeline; defaults to ``PYANNOTE_DEFAULT_CONFIG``
        :param token: Huggingface token; when omitted and ``model`` is not
            an existing path, the token cached in ``TOKEN_PATH`` is used
        :param cache_token: if true and ``token`` is given, store the token
            in ``TOKEN_PATH`` for later calls
        :param cache_dir: cache directory passed to
            ``Pipeline.from_pretrained``
        :param hparams_file: hyper-parameter file passed to
            ``Pipeline.from_pretrained``
        :raises ValueError: no cached token is available when one is
            needed, or the pipeline could not be loaded
        :return: Diariser wrapping the loaded pipeline
        """
        if cache_token and token is not None:
            cls._save_token(token)

        # NOTE: a ``model`` that is not an existing path combined with no
        # explicit token falls back to the hosted default pipeline.
        if not os.path.exists(model) and token is None:
            token = cls._get_token()
            model = 'pyannote/speaker-diarization'

        _model = Pipeline.from_pretrained(model,
                                          use_auth_token=token,
                                          cache_dir=cache_dir,
                                          hparams_file=hparams_file)
        # The loader may hand back None instead of raising (e.g. when the
        # download is refused), so the *loaded object* has to be checked.
        if _model is None:
            raise ValueError(
                'Unable to load model either from local cache '
                'or from huggingface.co models. Please check your token '
                'or your local model path')
        return cls(_model)

    @staticmethod
    def _get_diarisation_kwargs(**kwargs) -> dict:
        """
        Keep only the kwargs that ``SpeakerDiarization.apply`` accepts.

        :return: kwargs for pyannote diarization model
        :rtype: dict
        """
        code = SpeakerDiarization.apply.__code__
        # co_varnames lists the parameters first and the local variables
        # afterwards; only the leading parameter names are valid kwargs.
        n_params = code.co_argcount + code.co_kwonlyargcount
        accepted = set(code.co_varnames[:n_params]) - {'self'}
        return {key: value for key, value in kwargs.items()
                if key in accepted}

    def __repr__(self):
        return f"Diarisation(model={self.model})"

    def __str__(self):
        return f"Diarisation(model={self.model})"