From f7927fd524bd6a6d7527d18dd1ac5013c0412f01 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Fri, 19 Apr 2024 17:36:34 +0200 Subject: [PATCH] Add default path to pyannote model with fallback option. --- scraibe/diarisation.py | 100 +++++++++++++++++++++++------------------ scraibe/misc.py | 3 +- 2 files changed, 58 insertions(+), 45 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 0f0e14a..161dae5 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -19,6 +19,7 @@ Constants: - TOKEN_PATH (str): Path to the Pyannote token. - PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. - PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. +- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work. Usage: from .diarisation import Diariser @@ -39,7 +40,7 @@ from torch import Tensor from torch import device as torch_device from torch.cuda import is_available, current_device -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -183,7 +184,7 @@ class Diariser: @classmethod def load_model(cls, - model: str = PYANNOTE_DEFAULT_CONFIG, + model: str = PYANNOTE_FALLBACK_CONFIG, use_auth_token: str = None, cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, @@ -210,53 +211,64 @@ class Diariser: Returns: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ + try: + hf_model = PYANNOTE_DEFAULT_CONFIG + # if not use_auth_token: + # use_auth_token = cls._get_token() + _model = Pipeline.from_pretrained( + hf_model, use_auth_token=use_auth_token, + cache_dir=cache_dir, hparams_file=hparams_file + ) + except: + print(f'Trying fallback to config file.. ') + _model = None + if _model is None: - - if cache_token and use_auth_token is not None: - cls._save_token(use_auth_token) - - if not os.path.exists(model) and use_auth_token is None: - use_auth_token = cls._get_token() - - elif os.path.exists(model) and not use_auth_token: - # check if model can be found locally nearby the config file - with open(model, 'r') as file: - config = yaml.safe_load(file) - - path_to_model = config['pipeline']['params']['segmentation'] - - if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. "\ - "Trying to find it nearby the config file.") + if cache_token and use_auth_token is not None: + cls._save_token(use_auth_token) - pwd = model.split("/")[:-1] - pwd = "/".join(pwd) + if not os.path.exists(model) and use_auth_token is None: + use_auth_token = cls._get_token() - path_to_model = os.path.join(pwd, "pytorch_model.bin") + elif os.path.exists(model) and not use_auth_token: + # check if model can be found locally nearby the config file + with open(model, 'r') as file: + config = yaml.safe_load(file) + + path_to_model = config['pipeline']['params']['segmentation'] if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. \ - 'Trying to find it nearby .bin files instead.") - # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] - if len(bin_files) == 1: - path_to_model = os.path.join(pwd, bin_files[0]) - else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") - - warnings.warn(f"Found model at {path_to_model} overwriting config file.") - - config['pipeline']['params']['segmentation'] = path_to_model - - with open(model, 'w') as file: - yaml.dump(config, file) - - _model = Pipeline.from_pretrained(model, - use_auth_token = use_auth_token, - cache_dir = cache_dir, - hparams_file = hparams_file,) + warnings.warn(f"Model not found at {path_to_model}. "\ + "Trying to find it nearby the config file.") + + pwd = model.split("/")[:-1] + pwd = "/".join(pwd) + + path_to_model = os.path.join(pwd, "pytorch_model.bin") + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) + + _model = Pipeline.from_pretrained(model, + use_auth_token = use_auth_token, + cache_dir = cache_dir, + hparams_file = hparams_file,) # try to move the model to the device if device is None: diff --git a/scraibe/misc.py b/scraibe/misc.py index 992e40c..549ee67 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -13,7 +13,8 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR: WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ +PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe' +PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ else 'pyannote/speaker-diarization-3.1'