fixed typo
This commit is contained in:
+25
-25
@@ -219,34 +219,34 @@ class Diariser:
|
|||||||
# check if model can be found locally nearby the config file
|
# check if model can be found locally nearby the config file
|
||||||
with open(model, 'r') as file:
|
with open(model, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
path_to_model = config['pipeline']['params']['segmentation']
|
||||||
|
|
||||||
|
if not os.path.exists(path_to_model):
|
||||||
|
warnings.warn(f"Model not found at {path_to_model}. "\
|
||||||
|
"Trying to find it nearby the config file.")
|
||||||
|
|
||||||
path_to_model = config['pipeline']['params']['segmentation']
|
pwd = model.split("/")[:-1]
|
||||||
|
path_to_model = os.path.join(pwd, "pytorch_model.bin")
|
||||||
|
|
||||||
if not os.path.exists(path_to_model):
|
if not os.path.exists(path_to_model):
|
||||||
warnings.warn(f"Model not found at {path_to_model}. "\
|
warnings.warn(f"Model not found at {path_to_model}. \
|
||||||
"Trying to find it nearby the config file.")
|
'Trying to find it nearby .bin files instead.")
|
||||||
|
# list elementes with the ending .bin
|
||||||
pwd = file.split("/")[:-1]
|
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
||||||
path_to_model = os.path.join(pwd, "pytorch_model.bin")
|
if len(bin_files) == 1:
|
||||||
|
path_to_model = os.path.join(pwd, bin_files[0])
|
||||||
if not os.path.exists(path_to_model):
|
else:
|
||||||
warnings.warn(f"Model not found at {path_to_model}. \
|
warnings.warn("Found more than one .bin file. "\
|
||||||
'Trying to find it nearby .bin files instead.")
|
"or none. Please specify the path to the model " \
|
||||||
# list elementes with the ending .bin
|
"or setup a huggingface token.")
|
||||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
|
||||||
if len(bin_files) == 1:
|
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
||||||
path_to_model = os.path.join(pwd, bin_files[0])
|
|
||||||
else:
|
config['pipeline']['params']['segmentation'] = path_to_model
|
||||||
warnings.warn("Found more than one .bin file. "\
|
|
||||||
"or none. Please specify the path to the model " \
|
with open(model, 'w') as file:
|
||||||
"or setup a huggingface token.")
|
yaml.dump(config, file)
|
||||||
|
|
||||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
|
||||||
|
|
||||||
config['pipeline']['params']['segmentation'] = path_to_model
|
|
||||||
|
|
||||||
with open(model, 'w') as file:
|
|
||||||
yaml.dump(config, file)
|
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token = use_auth_token,
|
use_auth_token = use_auth_token,
|
||||||
|
|||||||
Reference in New Issue
Block a user