Added function to control paths to pyannote models
This commit is contained in:
+38
-6
@@ -6,15 +6,15 @@ import glob
|
|||||||
from warnings import warn
|
from warnings import warn
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(__file__),
|
WHISPER_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__),
|
||||||
"models", "whisper")
|
"models", "whisper"))
|
||||||
|
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(__file__),
|
PYANNOTE_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__),
|
||||||
"models", "pyannote",
|
"models", "pyannote",
|
||||||
"speaker_diarization", "config.yaml")
|
"speaker_diarization", "config.yaml"))
|
||||||
|
|
||||||
|
|
||||||
def config_diarization_yaml(file):
|
def config_diarization_yaml(file, path_to_segmentation = None, path_to_embedding = None):
|
||||||
"""
|
"""
|
||||||
Configure diarization pipeline from yaml file to use the model offline
|
Configure diarization pipeline from yaml file to use the model offline
|
||||||
and avoid manual file manipulation.
|
and avoid manual file manipulation.
|
||||||
@@ -22,4 +22,36 @@ def config_diarization_yaml(file):
|
|||||||
:param file: yaml file
|
:param file: yaml file
|
||||||
:type file: yaml
|
:type file: yaml
|
||||||
"""
|
"""
|
||||||
|
with open(file, "r") as stream:
|
||||||
|
yml = yaml.safe_load(stream)
|
||||||
|
stream.close()
|
||||||
|
if path_to_segmentation:
|
||||||
|
yml["pipeline"]["params"]["segmentation"] = path_to_segmentation
|
||||||
|
else:
|
||||||
|
yml["pipeline"]["params"]["segmentation"] = os.path.relpath(os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"models", "pyannote",
|
||||||
|
"segmentation",
|
||||||
|
"pytorch_model.bin"))
|
||||||
|
|
||||||
|
if path_to_embedding:
|
||||||
|
yml["pipeline"]["params"]["embedding"] = path_to_embedding
|
||||||
|
else:
|
||||||
|
yml["pipeline"]["params"]["embedding"] = os.path.relpath(
|
||||||
|
os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"models", "pyannote",
|
||||||
|
"speechbrain",
|
||||||
|
"spkrec-ecapa-voxceleb",
|
||||||
|
"embedding_model.ckpt"))
|
||||||
|
|
||||||
|
if not os.path.exists(yml["pipeline"]["params"]["segmentation"]):
|
||||||
|
raise FileNotFoundError(f"Segmentation model not found at {yml['pipeline']['params']['segmentation']}")
|
||||||
|
|
||||||
|
if not os.path.exists(yml["pipeline"]["params"]["embedding"]):
|
||||||
|
raise FileNotFoundError(f"Embedding model not found at {yml['pipeline']['params']['embedding']}")
|
||||||
|
|
||||||
|
with open(file, "w") as stream:
|
||||||
|
yaml.dump(yml, stream)
|
||||||
|
stream.close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user