diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 3e6047c..c62bda0 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -219,34 +219,34 @@ class Diariser: # check if model can be found locally nearby the config file with open(model, 'r') as file: config = yaml.safe_load(file) + + path_to_model = config['pipeline']['params']['segmentation'] + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. "\ + "Trying to find it nearby the config file.") - path_to_model = config['pipeline']['params']['segmentation'] + pwd = model.split("/")[:-1] + path_to_model = os.path.join(pwd, "pytorch_model.bin") if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. "\ - "Trying to find it nearby the config file.") - - pwd = file.split("/")[:-1] - path_to_model = os.path.join(pwd, "pytorch_model.bin") - - if not os.path.exists(path_to_model): - warnings.warn(f"Model not found at {path_to_model}. \ - 'Trying to find it nearby .bin files instead.") - # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] - if len(bin_files) == 1: - path_to_model = os.path.join(pwd, bin_files[0]) - else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") - - warnings.warn(f"Found model at {path_to_model} overwriting config file.") - - config['pipeline']['params']['segmentation'] = path_to_model - - with open(model, 'w') as file: - yaml.dump(config, file) + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) _model = Pipeline.from_pretrained(model, use_auth_token = use_auth_token,