unified docstrings
This commit is contained in:
+24
-19
@@ -1,36 +1,41 @@
|
|||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
|
||||||
|
|
||||||
CACHE_DIR = os.getenv(
|
CACHE_DIR = os.getenv(
|
||||||
"AUTOT_CACHE",
|
"AUTOT_CACHE",
|
||||||
os.path.expanduser("~/.cache/torch/models"),
|
os.path.expanduser("~/.cache/torch/models"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
||||||
|
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
|
||||||
|
|
||||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||||
|
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
|
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
|
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
|
||||||
|
|
||||||
def config_diarization_yaml(file, path_to_segmentation = None):
|
|
||||||
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
|
"""Configure diarization pipeline from a YAML file.
|
||||||
|
|
||||||
|
This function updates the YAML file to use the given segmentation model
|
||||||
|
offline, and avoids manual file manipulation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): Path to the YAML file.
|
||||||
|
path_to_segmentation (str, optional): Optional path to the segmentation model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the segmentation model file is not found.
|
||||||
"""
|
"""
|
||||||
Configure diarization pipeline from yaml file to use the model offline
|
with open(file_path, "r") as stream:
|
||||||
and avoid manuel file manipulation.
|
yml = yaml.safe_load(stream)
|
||||||
|
|
||||||
:param file: yaml file
|
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
||||||
:type file: yaml
|
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
||||||
"""
|
|
||||||
with open(file, "r") as stream:
|
|
||||||
yml = yaml.safe_load(stream)
|
|
||||||
stream.close()
|
|
||||||
if path_to_segmentation:
|
|
||||||
yml["pipeline"]["params"]["segmentation"] = path_to_segmentation
|
|
||||||
else:
|
|
||||||
yml["pipeline"]["params"]["segmentation"] = os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
|
||||||
|
|
||||||
if not os.path.exists(yml["pipeline"]["params"]["segmentation"]):
|
if not os.path.exists(segmentation_path):
|
||||||
raise FileNotFoundError(f"Segmentation model not found at {yml['pipeline']['params']['segmentation']}")
|
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
|
||||||
|
|
||||||
with open(file, "w") as stream:
|
with open(file_path, "w") as stream:
|
||||||
yaml.dump(yml, stream)
|
yaml.dump(yml, stream)
|
||||||
stream.close()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user