added function to controll paths to pyannote models

This commit is contained in:
Jaikinator
2023-06-28 15:31:52 +02:00
parent 9a767228f7
commit de3a6cd4d1
+37 -5
View File
@@ -6,15 +6,15 @@ import glob
from warnings import warn
import yaml
WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(__file__),
"models", "whisper")
WHISPER_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__),
"models", "whisper"))
PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(__file__),
PYANNOTE_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__),
"models", "pyannote",
"speaker_diarization", "config.yaml")
"speaker_diarization", "config.yaml"))
def config_diarization_yaml(file):
def config_diarization_yaml(file, path_to_segmentation = None, path_to_embedding = None):
"""
Configure diarization pipeline from yaml file to use the model offline
and avoid manuel file manipulation.
@@ -22,4 +22,36 @@ def config_diarization_yaml(file):
:param file: yaml file
:type file: yaml
"""
with open(file, "r") as stream:
yml = yaml.safe_load(stream)
stream.close()
if path_to_segmentation:
yml["pipeline"]["params"]["segmentation"] = path_to_segmentation
else:
yml["pipeline"]["params"]["segmentation"] = os.path.relpath(os.path.join(
os.path.dirname(__file__),
"models", "pyannote",
"segmentation",
"pytorch_model.bin"))
if path_to_embedding:
yml["pipeline"]["params"]["embedding"] = path_to_embedding
else:
yml["pipeline"]["params"]["embedding"] = os.path.relpath(
os.path.join(
os.path.dirname(__file__),
"models", "pyannote",
"speechbrain",
"spkrec-ecapa-voxceleb",
"embedding_model.ckpt"))
if not os.path.exists(yml["pipeline"]["params"]["segmentation"]):
raise FileNotFoundError(f"Segmentation model not found at {yml['pipeline']['params']['segmentation']}")
if not os.path.exists(yml["pipeline"]["params"]["embedding"]):
raise FileNotFoundError(f"Embedding model not found at {yml['pipeline']['params']['embedding']}")
with open(file, "w") as stream:
yaml.dump(yml, stream)
stream.close()