diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index f90bcdb..3e6047c 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -27,7 +27,9 @@ Usage: diarisation_output = model.diarization("path/to/audiofile.wav") """ +import warnings import os +import yaml from pathlib import Path from typing import TypeVar, Union @@ -213,7 +215,39 @@ class Diariser: model = 'pyannote/speaker-diarization' elif not os.path.exists(model) and use_auth_token is not None: model = 'pyannote/speaker-diarization' - + elif os.path.exists(model) and not use_auth_token: + # check if model can be found locally nearby the config file + with open(model, 'r') as file: + config = yaml.safe_load(file) + + path_to_model = config['pipeline']['params']['segmentation'] + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. "\ + "Trying to find it nearby the config file.") + + pwd = file.split("/")[:-1] + path_to_model = os.path.join(pwd, "pytorch_model.bin") + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) + _model = Pipeline.from_pretrained(model, use_auth_token = use_auth_token, cache_dir = cache_dir,