updated diarisation file to better handle tokens

This commit is contained in:
Jaikinator
2023-07-10 13:27:54 +02:00
parent abd733b2ae
commit a71475c3eb
2 changed files with 50 additions and 29 deletions
+45 -24
View File
@@ -1,13 +1,21 @@
from pyannote.audio import Pipeline """
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization Diarisation class.
from torch import Tensor This class is used to diarize an audio file using a pretrained model
"""
import os import os
from pathlib import Path from pathlib import Path
from typing import TypeVar, Union from typing import TypeVar, Union
import json
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname(
os.path.realpath(__file__)), '.pyannotetoken')
class Diariser: class Diariser:
""" """
Diarisation class Diarisation class
@@ -15,7 +23,7 @@ class Diariser:
from pyannote.audio. from pyannote.audio.
:param model: model to use for diarization :param model: model to use for diarization
""" """
def __init__(self, model,*args,**kwargs) -> None: def __init__(self, model) -> None:
self.model = model self.model = model
@@ -92,36 +100,40 @@ class Diariser:
diarization_output["speakers"].append(outp[2]) diarization_output["speakers"].append(outp[2])
return diarization_output return diarization_output
def save(self, path : str, *args, **kwargs) -> None:
"""
Save diarization output to a file
:param path: path to save file
:type path: str
"""
with open(path, "w") as f:
json.dump(self.transcript, f, *args, **kwargs)
@staticmethod @staticmethod
def _get_token(): def _get_token():
# check ig .pyannotetoken.txt exists """
path = os.path.join(os.path.dirname( Get token from .pyannotetoken.txt
os.path.realpath(__file__)), '.pyannotetoken') :raises ValueError: No token found
if os.path.exists(path): :return: Huggingface token
with open(path, 'r') as f: :rtype: str
token = f.read() """
if os.path.exists(TOKEN_PATH):
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
token = file.read()
else: else:
raise ValueError('No token found.' \ raise ValueError('No token found.' \
'Please create a token at https://huggingface.co/settings/token' \ 'Please create a token at https://huggingface.co/settings/token' \
'and save it in a file called .pyannotetoken.txt') f'and save it in a file called {TOKEN_PATH}')
return token return token
@staticmethod
def _save_token(token):
"""
Save token to .pyannotetoken.txt
:param token: Huggingface token
:type token: str
"""
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
file.write(token)
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG, model: str = PYANNOTE_DEFAULT_CONFIG,
token: str = None, token: str = None,
cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None hparams_file: Union[str, Path] = None
) -> Pipeline: ) -> Pipeline:
@@ -142,14 +154,23 @@ class Diariser:
------- -------
Pipeline Object Pipeline Object
""" """
if cache_token and token is not None:
cls._save_token(token)
if not os.path.exists(model) and token is None: if not os.path.exists(model) and token is None:
token = cls._get_token() token = cls._get_token()
model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token = token, use_auth_token = token,
cache_dir = cache_dir, cache_dir = cache_dir,
hparams_file = hparams_file,) hparams_file = hparams_file,)
if model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \
'or your local model path')
return cls(_model) return cls(_model)
@staticmethod @staticmethod