From de3a6cd4d17a7a9261706ad514a10abaa2d60758 Mon Sep 17 00:00:00 2001
From: Jaikinator
Date: Wed, 28 Jun 2023 15:31:52 +0200
Subject: [PATCH] added function to control paths to pyannote models

---
Reviewer notes (text in this section is not part of the commit message):
 * This patch was recovered from a whitespace-collapsed copy. The indentation
   of context and removed lines is reconstructed and must be confirmed
   against autotranscript/misc.py before applying (or apply with
   --ignore-whitespace).
 * The "index" hashes are carried over from the original patch and do not
   describe the revised post-image.
 * "manuel" in the existing docstring is left untouched because it is a
   context line and has to match the target file.
 * Revised relative to the original submission: redundant stream.close()
   calls removed (the with-blocks already close the file), explicit utf-8
   encoding, repeated nested lookups replaced by one local, new parameters
   documented, PEP 8 spacing for keyword defaults, no trailing blank line.

 autotranscript/misc.py | 43 +++++++++++++++++++++++++++++++++++------
 1 file changed, 37 insertions(+), 6 deletions(-)

diff --git a/autotranscript/misc.py b/autotranscript/misc.py
index 716852e..1c14198 100644
--- a/autotranscript/misc.py
+++ b/autotranscript/misc.py
@@ -6,15 +6,15 @@ import glob
 from warnings import warn
 import yaml
 
-WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(__file__),
-                                    "models", "whisper")
+WHISPER_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__),
+                                    "models", "whisper"))
 
-PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(__file__),
+PYANNOTE_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__),
                                      "models", "pyannote",
-                                     "speaker_diarization", "config.yaml")
+                                     "speaker_diarization", "config.yaml"))
 
 
-def config_diarization_yaml(file):
+def config_diarization_yaml(file, path_to_segmentation=None, path_to_embedding=None):
     """
     Configure diarization pipeline from yaml file to use the model offline and avoid manuel file manipulation.
 
@@ -22,4 +22,35 @@ def config_diarization_yaml(file):
     :param file: yaml file
     :type file: yaml
+    :param path_to_segmentation: path to the segmentation model; a falsy value
+        selects models/pyannote/segmentation/pytorch_model.bin next to this module
+    :type path_to_segmentation: str, optional
+    :param path_to_embedding: path to the embedding model; a falsy value selects
+        models/pyannote/speechbrain/spkrec-ecapa-voxceleb/embedding_model.ckpt
+        next to this module
+    :type path_to_embedding: str, optional
+    :raises FileNotFoundError: if the segmentation or embedding model is missing
     """
-    
\ No newline at end of file
+    with open(file, "r", encoding="utf-8") as stream:
+        yml = yaml.safe_load(stream)
+    params = yml["pipeline"]["params"]
+
+    # Defaults live next to this module and are written relative to the current
+    # working directory (os.path.relpath).
+    module_dir = os.path.dirname(__file__)
+    params["segmentation"] = path_to_segmentation or os.path.relpath(
+        os.path.join(module_dir, "models", "pyannote", "segmentation",
+                     "pytorch_model.bin"))
+    params["embedding"] = path_to_embedding or os.path.relpath(
+        os.path.join(module_dir, "models", "pyannote", "speechbrain",
+                     "spkrec-ecapa-voxceleb", "embedding_model.ckpt"))
+
+    # Validate before reopening the file for writing so a bad path leaves it untouched.
+    if not os.path.exists(params["segmentation"]):
+        raise FileNotFoundError(
+            f"Segmentation model not found at {params['segmentation']}")
+    if not os.path.exists(params["embedding"]):
+        raise FileNotFoundError(
+            f"Embedding model not found at {params['embedding']}")
+
+    with open(file, "w", encoding="utf-8") as stream:
+        yaml.dump(yml, stream)