updated diarisation file to better handle tokens

This commit is contained in:
Jaikinator
2023-07-10 13:27:54 +02:00
parent abd733b2ae
commit a71475c3eb
2 changed files with 50 additions and 29 deletions
+1 -1
View File
@@ -6,5 +6,5 @@ from .transcript_exporter import *
from .diarisation import *
from .version import get_version as _get_version
from .misc import *
__version__ = _get_version()
+49 -28
View File
@@ -1,13 +1,21 @@
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
"""
Diarisation class.
This class is used to diarize an audio file using a pretrained model
"""
import os
from pathlib import Path
from typing import TypeVar, Union
import json
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
# Generic placeholder type variable named "Annotation".
Annotation = TypeVar("Annotation")

# Absolute path of the token file, located in the same directory as this
# module (symlinks in the module path are resolved first).
TOKEN_PATH = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '.pyannotetoken',
)
class Diariser:
"""
Diarisation class
@@ -15,7 +23,7 @@ class Diariser:
from pyannote.audio.
:param model: model to use for diarization
"""
def __init__(self, model,*args,**kwargs) -> None:
def __init__(self, model) -> None:
self.model = model
@@ -29,7 +37,7 @@ class Diariser:
:return: diarization
"""
kwargs = self._get_diarisation_kwargs(**kwargs)
diarization = self.model(audiofile,*args, **kwargs)
out = self.format_diarization_output(diarization)
@@ -52,7 +60,7 @@ class Diariser:
index_start_speaker = 0
index_end_speaker = 0
current_speaker = str()
###
# Sometimes two consecutive speakers are the same
# This loop removes these duplicates
@@ -91,37 +99,41 @@ class Diariser:
diarization_output["segments"].append([start, end])
diarization_output["speakers"].append(outp[2])
return diarization_output
def save(self, path: str, *args, **kwargs) -> None:
    """
    Save diarization output to a JSON file.

    The content of ``self.transcript`` is serialised with ``json.dump``.

    :param path: path to save file
    :type path: str
    :param args: extra positional arguments forwarded to ``json.dump``
    :param kwargs: extra keyword arguments forwarded to ``json.dump``
    """
    # NOTE(review): ``self.transcript`` is not assigned in the visible part
    # of this class -- confirm it is set before ``save`` is called.
    # An explicit UTF-8 encoding keeps the written file independent of the
    # platform's locale encoding (matters when ensure_ascii=False is passed).
    with open(path, "w", encoding="utf-8") as file:
        json.dump(self.transcript, file, *args, **kwargs)
@staticmethod
def _get_token():
    """
    Get the Hugging Face token stored in the file at ``TOKEN_PATH``.

    :raises ValueError: No token found
    :return: Huggingface token
    :rtype: str
    """
    try:
        with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
            # Token files usually end with a newline; strip it so the
            # returned value is the bare token.
            return file.read().strip()
    except FileNotFoundError as err:
        # Each literal ends with a space so the concatenated message
        # (and the URL inside it) stays readable.
        raise ValueError(
            'No token found. '
            'Please create a token at https://huggingface.co/settings/tokens '
            f'and save it in a file called {TOKEN_PATH}') from err
@staticmethod
def _save_token(token):
    """
    Save token to the file at ``TOKEN_PATH``.

    Any previously stored token is overwritten.

    :param token: Huggingface token
    :type token: str
    """
    # Write mode is required here: a read-only handle ('r') cannot be
    # written to and fails outright when the file does not exist yet.
    with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
        file.write(token)
@classmethod
def load_model(cls,
               model: str = PYANNOTE_DEFAULT_CONFIG,
               token: Union[str, None] = None,
               cache_token: bool = False,
               cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
               hparams_file: Union[str, Path, None] = None
               ) -> "Diariser":
    """
    Load a diarisation pipeline and wrap it in this class.

    Parameters
    ----------
    model : str
        Passed to ``Pipeline.from_pretrained``. If it is not an existing
        local path and no token is given, the cached token is read and
        ``'pyannote/speaker-diarization'`` is loaded instead.
    token : str, optional
        Hugging Face token forwarded as ``use_auth_token``.
    cache_token : bool
        If True and ``token`` is given, the token is written to the
        token file before loading.
    cache_dir : Path or str
        Forwarded to ``Pipeline.from_pretrained`` as ``cache_dir``.
    hparams_file : str or Path, optional
        Forwarded to ``Pipeline.from_pretrained`` as ``hparams_file``.

    Raises
    ------
    ValueError
        If no pipeline could be loaded, or if a token is needed but no
        cached token file exists.

    Returns
    -------
    Instance of this class wrapping the loaded Pipeline object
    """
    if cache_token and token is not None:
        cls._save_token(token)

    # No usable local model and no explicit token: fall back to the cached
    # token and the hosted default pipeline.
    if not os.path.exists(model) and token is None:
        token = cls._get_token()
        model = 'pyannote/speaker-diarization'

    _model = Pipeline.from_pretrained(model,
                                      use_auth_token=token,
                                      cache_dir=cache_dir,
                                      hparams_file=hparams_file)

    # Check the loaded pipeline, not the ``model`` name (which is always
    # set at this point), so that a failed load is actually reported.
    if _model is None:
        raise ValueError('Unable to load model either from local cache '
                         'or from huggingface.co models. Please check your token '
                         'or your local model path')

    return cls(_model)
@staticmethod