unified docstrings

This commit is contained in:
Jaikinator
2023-08-23 15:39:20 +02:00
parent cab50cba70
commit 18e89fad99
+23 -18
View File
@@ -1,36 +1,41 @@
import os import os
import yaml import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
os.path.expanduser("~/.cache/torch/models"), os.path.expanduser("~/.cache/torch/models"),
) )
if CACHE_DIR != PYANNOTE_CACHE_DIR:
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
def config_diarization_yaml(file, path_to_segmentation = None):
"""
Configure diarization pipeline from yaml file to use the model offline
and avoid manuel file manipulation.
:param file: yaml file def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
:type file: yaml """Configure diarization pipeline from a YAML file.
This function updates the YAML file to use the given segmentation model
offline, and avoids manual file manipulation.
Args:
file_path (str): Path to the YAML file.
path_to_segmentation (str, optional): Optional path to the segmentation model.
Raises:
FileNotFoundError: If the segmentation model file is not found.
""" """
with open(file, "r") as stream: with open(file_path, "r") as stream:
yml = yaml.safe_load(stream) yml = yaml.safe_load(stream)
stream.close()
if path_to_segmentation:
yml["pipeline"]["params"]["segmentation"] = path_to_segmentation
else:
yml["pipeline"]["params"]["segmentation"] = os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
if not os.path.exists(yml["pipeline"]["params"]["segmentation"]): segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
raise FileNotFoundError(f"Segmentation model not found at {yml['pipeline']['params']['segmentation']}") yml["pipeline"]["params"]["segmentation"] = segmentation_path
with open(file, "w") as stream: if not os.path.exists(segmentation_path):
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
with open(file_path, "w") as stream:
yaml.dump(yml, stream) yaml.dump(yml, stream)
stream.close()