fixed typo

This commit is contained in:
Jaikinator
2024-01-26 15:46:51 +01:00
parent c1ed0547b8
commit ff47d058c8
+22 -22
View File
@@ -220,33 +220,33 @@ class Diariser:
with open(model, 'r') as file:
config = yaml.safe_load(file)
path_to_model = config['pipeline']['params']['segmentation']
path_to_model = config['pipeline']['params']['segmentation']
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
pwd = model.split("/")[:-1]
path_to_model = os.path.join(pwd, "pytorch_model.bin")
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
else:
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")
pwd = file.split("/")[:-1]
path_to_model = os.path.join(pwd, "pytorch_model.bin")
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
else:
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")
config['pipeline']['params']['segmentation'] = path_to_model
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
config['pipeline']['params']['segmentation'] = path_to_model
with open(model, 'w') as file:
yaml.dump(config, file)
with open(model, 'w') as file:
yaml.dump(config, file)
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,