diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 161dae5..8523940 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -19,7 +19,6 @@ Constants: - TOKEN_PATH (str): Path to the Pyannote token. - PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. - PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. -- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work. Usage: from .diarisation import Diariser @@ -39,8 +38,10 @@ from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device from torch.cuda import is_available, current_device +from huggingface_hub import HfApi +from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -184,7 +185,7 @@ class Diariser: @classmethod def load_model(cls, - model: str = PYANNOTE_FALLBACK_CONFIG, + model: str = PYANNOTE_DEFAULT_CONFIG, use_auth_token: str = None, cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, @@ -211,64 +212,65 @@ class Diariser: Returns: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ - try: - hf_model = PYANNOTE_DEFAULT_CONFIG - # if not use_auth_token: - # use_auth_token = cls._get_token() - _model = Pipeline.from_pretrained( - hf_model, use_auth_token=use_auth_token, - cache_dir=cache_dir, hparams_file=hparams_file - ) - except: - print(f'Trying fallback to config file.. ') - _model = None - if _model is None: - - if cache_token and use_auth_token is not None: - cls._save_token(use_auth_token) - - if not os.path.exists(model) and use_auth_token is None: - use_auth_token = cls._get_token() - - elif os.path.exists(model) and not use_auth_token: - # check if model can be found locally nearby the config file - with open(model, 'r') as file: - config = yaml.safe_load(file) - - path_to_model = config['pipeline']['params']['segmentation'] + if isinstance(model, str) and os.path.exists(model): + # check if model can be found locally nearby the config file + with open(model, 'r') as file: + config = yaml.safe_load(file) + + path_to_model = config['pipeline']['params']['segmentation'] + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. " + "Trying to find it nearby the config file.") + + pwd = model.split("/")[:-1] + pwd = "/".join(pwd) + + path_to_model = os.path.join(pwd, "pytorch_model.bin") if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. "\ - "Trying to find it nearby the config file.") - - pwd = model.split("/")[:-1] - pwd = "/".join(pwd) - - path_to_model = os.path.join(pwd, "pytorch_model.bin") + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + raise FileNotFoundError - if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. \ - 'Trying to find it nearby .bin files instead.") - # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] - if len(bin_files) == 1: - path_to_model = os.path.join(pwd, bin_files[0]) - else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") - - warnings.warn(f"Found model at {path_to_model} overwriting config file.") - - config['pipeline']['params']['segmentation'] = path_to_model - - with open(model, 'w') as file: - yaml.dump(config, file) - - _model = Pipeline.from_pretrained(model, - use_auth_token = use_auth_token, - cache_dir = cache_dir, - hparams_file = hparams_file,) + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) + elif isinstance(model, tuple): + try: + _model = model[0] + HfApi().model_info(_model) + model = _model + use_auth_token = None + except RepositoryNotFoundError: + print(f'{model[0]} not found on Huggingface, \ + trying {model[1]}') + _model = model[1] + HfApi().model_info(_model) + model = _model + if cache_token and use_auth_token is not None: + cls._save_token(use_auth_token) + + if not os.path.exists(model) and use_auth_token is None: + use_auth_token = cls._get_token() + else: + raise FileNotFoundError(f'No local model or directory found at {model}.') + + _model = Pipeline.from_pretrained(model, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + hparams_file=hparams_file,) # try to move the model to the device if device is None: diff --git a/scraibe/misc.py b/scraibe/misc.py index 549ee67..c1d5484 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -13,10 +13,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR: WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe' -PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ +PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else 'pyannote/speaker-diarization-3.1' + else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file.