Added function to control paths to pyannote models

This commit is contained in:
Jaikinator
2023-06-28 15:31:52 +02:00
parent 9a767228f7
commit de3a6cd4d1
+37 -5
View File
@@ -6,15 +6,15 @@ import glob
from warnings import warn from warnings import warn
import yaml import yaml
# Default Whisper model directory: the "models/whisper" folder next to this
# module, expressed relative to the current working directory (the default
# start of os.path.relpath) at import time.
WHISPER_DEFAULT_PATH = os.path.relpath(
    os.path.join(os.path.dirname(__file__), "models", "whisper")
)

# Default pyannote speaker-diarization pipeline config, located under the
# "models/pyannote" folder next to this module; also relative to the current
# working directory at import time.
PYANNOTE_DEFAULT_PATH = os.path.relpath(
    os.path.join(
        os.path.dirname(__file__),
        "models", "pyannote", "speaker_diarization", "config.yaml",
    )
)
def config_diarization_yaml(file, path_to_segmentation=None, path_to_embedding=None):
    """
    Configure diarization pipeline from yaml file to use the model offline
    and avoid manual file manipulation.

    The ``segmentation`` and ``embedding`` entries under
    ``pipeline -> params`` of the yaml file are set to local model paths and
    the file is rewritten in place. The file is only written after both
    model paths have been verified to exist.

    :param file: path of the pipeline yaml file; it is read and then overwritten
    :type file: str
    :param path_to_segmentation: path to the segmentation model. If omitted
        (or falsy), ``models/pyannote/segmentation/pytorch_model.bin`` next
        to this module is used, relative to the current working directory.
    :type path_to_segmentation: str or None
    :param path_to_embedding: path to the embedding model. If omitted
        (or falsy), ``models/pyannote/speechbrain/spkrec-ecapa-voxceleb/
        embedding_model.ckpt`` next to this module is used, relative to the
        current working directory.
    :type path_to_embedding: str or None
    :raises FileNotFoundError: if the resolved segmentation or embedding
        model path does not exist
    """
    pyannote_dir = os.path.join(os.path.dirname(__file__), "models", "pyannote")

    # The context manager closes the file; no explicit close() is needed.
    with open(file, "r", encoding="utf-8") as stream:
        yml = yaml.safe_load(stream)

    # NOTE(review): assumes the yaml already contains pipeline -> params;
    # a missing key surfaces as KeyError, exactly as before.
    params = yml["pipeline"]["params"]

    # A caller-supplied path wins; otherwise fall back to the default model
    # location under this module's "models/pyannote" directory.
    params["segmentation"] = path_to_segmentation or os.path.relpath(
        os.path.join(pyannote_dir, "segmentation", "pytorch_model.bin"))
    params["embedding"] = path_to_embedding or os.path.relpath(
        os.path.join(pyannote_dir, "speechbrain", "spkrec-ecapa-voxceleb",
                     "embedding_model.ckpt"))

    # Validate before writing so a bad path never ends up in the config file.
    if not os.path.exists(params["segmentation"]):
        raise FileNotFoundError(f"Segmentation model not found at {params['segmentation']}")
    if not os.path.exists(params["embedding"]):
        raise FileNotFoundError(f"Embedding model not found at {params['embedding']}")

    with open(file, "w", encoding="utf-8") as stream:
        yaml.dump(yml, stream)