Auto fixes from PEP8, fixes from flake8.
This commit is contained in:
+56
-52
@@ -37,15 +37,16 @@ from pyannote.audio import Pipeline
|
||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
from torch import device as torch_device
|
||||
from torch.cuda import is_available, current_device
|
||||
from torch.cuda import is_available
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||
Annotation = TypeVar('Annotation')
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), '.pyannotetoken')
|
||||
os.path.realpath(__file__)), '.pyannotetoken')
|
||||
|
||||
|
||||
class Diariser:
|
||||
"""
|
||||
@@ -55,12 +56,12 @@ class Diariser:
|
||||
Args:
|
||||
model: The pretrained model to use for diarization.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, model) -> None:
|
||||
|
||||
self.model = model
|
||||
|
||||
def diarization(self, audiofile : Union[str, Tensor, dict] ,
|
||||
def diarization(self, audiofile: Union[str, Tensor, dict],
|
||||
*args, **kwargs) -> Annotation:
|
||||
"""
|
||||
Perform speaker diarization on the provided audio file,
|
||||
@@ -79,15 +80,15 @@ class Diariser:
|
||||
to the diarization process.
|
||||
"""
|
||||
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||
|
||||
diarization = self.model(audiofile,*args, **kwargs)
|
||||
|
||||
diarization = self.model(audiofile, *args, **kwargs)
|
||||
|
||||
out = self.format_diarization_output(diarization)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def format_diarization_output(dia : Annotation) -> dict:
|
||||
def format_diarization_output(dia: Annotation) -> dict:
|
||||
"""
|
||||
Formats the raw diarization output into a more usable structure for this project.
|
||||
|
||||
@@ -99,14 +100,14 @@ class Diariser:
|
||||
as keys and a list of tuples representing segments as values.
|
||||
"""
|
||||
|
||||
dia_list = list(dia.itertracks(yield_label=True))
|
||||
dia_list = list(dia.itertracks(yield_label=True))
|
||||
diarization_output = {"speakers": [], "segments": []}
|
||||
|
||||
normalized_output = []
|
||||
index_start_speaker = 0
|
||||
index_end_speaker = 0
|
||||
current_speaker = str()
|
||||
|
||||
|
||||
###
|
||||
# Sometimes two consecutive speakers are the same
|
||||
# This loop removes these duplicates
|
||||
@@ -115,40 +116,39 @@ class Diariser:
|
||||
if len(dia_list) == 1:
|
||||
normalized_output.append([0, 0, dia_list[0][2]])
|
||||
else:
|
||||
|
||||
|
||||
for i, (_, _, speaker) in enumerate(dia_list):
|
||||
|
||||
|
||||
if i == 0:
|
||||
current_speaker = speaker
|
||||
|
||||
|
||||
if speaker != current_speaker:
|
||||
|
||||
index_end_speaker = i - 1
|
||||
|
||||
normalized_output.append([index_start_speaker,
|
||||
index_end_speaker,
|
||||
current_speaker])
|
||||
index_end_speaker,
|
||||
current_speaker])
|
||||
|
||||
index_start_speaker = i
|
||||
current_speaker = speaker
|
||||
|
||||
|
||||
if i == len(dia_list) - 1:
|
||||
|
||||
index_end_speaker = i
|
||||
|
||||
normalized_output.append([index_start_speaker,
|
||||
index_end_speaker,
|
||||
current_speaker])
|
||||
|
||||
|
||||
normalized_output.append([index_start_speaker,
|
||||
index_end_speaker,
|
||||
current_speaker])
|
||||
|
||||
for outp in normalized_output:
|
||||
start = dia_list[outp[0]][0].start
|
||||
end = dia_list[outp[1]][0].end
|
||||
start = dia_list[outp[0]][0].start
|
||||
end = dia_list[outp[1]][0].end
|
||||
|
||||
diarization_output["segments"].append([start, end])
|
||||
diarization_output["speakers"].append(outp[2])
|
||||
return diarization_output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_token():
|
||||
"""
|
||||
@@ -161,14 +161,14 @@ class Diariser:
|
||||
Returns:
|
||||
str: The Huggingface token.
|
||||
"""
|
||||
|
||||
|
||||
if os.path.exists(TOKEN_PATH):
|
||||
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
||||
token = file.read()
|
||||
else:
|
||||
raise ValueError('No token found.' \
|
||||
'Please create a token at https://huggingface.co/settings/token' \
|
||||
f'and save it in a file called {TOKEN_PATH}')
|
||||
raise ValueError('No token found.'
|
||||
'Please create a token at https://huggingface.co/settings/token'
|
||||
f'and save it in a file called {TOKEN_PATH}')
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
@@ -182,18 +182,17 @@ class Diariser:
|
||||
"""
|
||||
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
|
||||
file.write(token)
|
||||
|
||||
|
||||
@classmethod
|
||||
def load_model(cls,
|
||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||
use_auth_token: str = None,
|
||||
cache_token: bool = False,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None,
|
||||
device: str = None,
|
||||
*args, **kwargs
|
||||
) -> Pipeline:
|
||||
|
||||
def load_model(cls,
|
||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||
use_auth_token: str = None,
|
||||
cache_token: bool = False,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None,
|
||||
device: str = None,
|
||||
*args, **kwargs
|
||||
) -> Pipeline:
|
||||
"""
|
||||
Loads a pretrained model from pyannote.audio,
|
||||
either from a local cache or some online repository.
|
||||
@@ -237,16 +236,18 @@ class Diariser:
|
||||
'deprecated and will be removed in future versions.',
|
||||
category=DeprecationWarning)
|
||||
# list elements with the ending .bin
|
||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
||||
bin_files = [f for f in os.listdir(
|
||||
pwd) if f.endswith(".bin")]
|
||||
if len(bin_files) == 1:
|
||||
path_to_model = os.path.join(pwd, bin_files[0])
|
||||
else:
|
||||
warnings.warn("Found more than one .bin file. "\
|
||||
"or none. Please specify the path to the model " \
|
||||
"or setup a huggingface token.")
|
||||
warnings.warn("Found more than one .bin file. "
|
||||
"or none. Please specify the path to the model "
|
||||
"or setup a huggingface token.")
|
||||
raise FileNotFoundError
|
||||
|
||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
||||
warnings.warn(
|
||||
f"Found model at {path_to_model} overwriting config file.")
|
||||
|
||||
config['pipeline']['params']['segmentation'] = path_to_model
|
||||
|
||||
@@ -270,22 +271,24 @@ class Diariser:
|
||||
if use_auth_token is None:
|
||||
use_auth_token = cls._get_token()
|
||||
else:
|
||||
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
||||
raise FileNotFoundError(
|
||||
f'No local model or directory found at {model}.')
|
||||
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
hparams_file=hparams_file,)
|
||||
if _model is None:
|
||||
raise ValueError('Unable to load model either from local cache' \
|
||||
'or from huggingface.co models. Please check your token' \
|
||||
'or your local model path')
|
||||
raise ValueError('Unable to load model either from local cache'
|
||||
'or from huggingface.co models. Please check your token'
|
||||
'or your local model path')
|
||||
|
||||
# try to move the model to the device
|
||||
if device is None:
|
||||
device = "cuda" if is_available() else "cpu"
|
||||
|
||||
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
||||
# torch_device is renamed from torch.device to avoid name conflict
|
||||
_model = _model.to(torch_device(device))
|
||||
|
||||
return cls(_model)
|
||||
|
||||
@@ -302,9 +305,10 @@ class Diariser:
|
||||
"""
|
||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||
|
||||
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
||||
|
||||
diarisation_kwargs = {k: v for k,
|
||||
v in kwargs.items() if k in _possible_kwargs}
|
||||
|
||||
return diarisation_kwargs
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"Diarisation(model={self.model})"
|
||||
|
||||
Reference in New Issue
Block a user