diff --git a/autotranscript/__init__.py b/autotranscript/__init__.py index e6b02f3..20bcc93 100644 --- a/autotranscript/__init__.py +++ b/autotranscript/__init__.py @@ -6,5 +6,5 @@ from .transcript_exporter import * from .diarisation import * from .version import get_version as _get_version from .misc import * - + __version__ = _get_version() diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index bb364e9..5359e3e 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -1,13 +1,21 @@ -from pyannote.audio import Pipeline -from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization -from torch import Tensor +""" +Diarisation class. +This class is used to diarize an audio file using a pretrained model +""" import os from pathlib import Path from typing import TypeVar, Union -import json + +from pyannote.audio import Pipeline +from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization +from torch import Tensor + from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH Annotation = TypeVar('Annotation') +TOKEN_PATH = os.path.join(os.path.dirname( + os.path.realpath(__file__)), '.pyannotetoken') + class Diariser: """ Diarisation class @@ -15,7 +23,7 @@ class Diariser: from pyannote.audio. :param model: model to use for diarization """ - def __init__(self, model,*args,**kwargs) -> None: + def __init__(self, model) -> None: self.model = model @@ -29,7 +37,7 @@ class Diariser: :return: diarization """ kwargs = self._get_diarisation_kwargs(**kwargs) - + diarization = self.model(audiofile,*args, **kwargs) out = self.format_diarization_output(diarization) @@ -52,7 +60,7 @@ class Diariser: index_start_speaker = 0 index_end_speaker = 0 current_speaker = str() - + ### # Sometimes two consecutive speakers are the same # This loop removes these duplicates @@ -91,37 +99,41 @@ class Diariser: diarization_output["segments"].append([start, end]) diarization_output["speakers"].append(outp[2]) return diarization_output - - def save(self, path : str, *args, **kwargs) -> None: - """ - Save diarization output to a file - - :param path: path to save file - :type path: str - """ - with open(path, "w") as f: - json.dump(self.transcript, f, *args, **kwargs) - - @staticmethod def _get_token(): - # check ig .pyannotetoken.txt exists - path = os.path.join(os.path.dirname( - os.path.realpath(__file__)), '.pyannotetoken') - if os.path.exists(path): - with open(path, 'r') as f: - token = f.read() + """ + Get token from .pyannotetoken.txt + :raises ValueError: No token found + :return: Huggingface token + :rtype: str + """ + + if os.path.exists(TOKEN_PATH): + with open(TOKEN_PATH, 'r', encoding="utf-8") as file: + token = file.read() else: raise ValueError('No token found.' \ 'Please create a token at https://huggingface.co/settings/token' \ - 'and save it in a file called .pyannotetoken.txt') + f'and save it in a file called {TOKEN_PATH}') return token + + @staticmethod + def _save_token(token): + """ + Save token to .pyannotetoken.txt + + :param token: Huggingface token + :type token: str + """ + with open(TOKEN_PATH, 'r', encoding="utf-8") as file: + file.write(token) @classmethod def load_model(cls, model: str = PYANNOTE_DEFAULT_CONFIG, token: str = None, + cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None ) -> Pipeline: @@ -142,14 +154,23 @@ class Diariser: ------- Pipeline Object """ + + if cache_token and token is not None: + cls._save_token(token) + if not os.path.exists(model) and token is None: token = cls._get_token() - + model = 'pyannote/speaker-diarization' + _model = Pipeline.from_pretrained(model, use_auth_token = token, cache_dir = cache_dir, hparams_file = hparams_file,) - + + if model is None: + raise ValueError('Unable to load model either from local cache' \ + 'or from huggingface.co models. Please check your token' \ + 'or your local model path') return cls(_model) @staticmethod