From cd35ad8903b63353c01145223598ae09fad8d0a8 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 30 Jun 2023 18:41:43 +0200 Subject: [PATCH] solved path issues --- autotranscript/misc.py | 49 ++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/autotranscript/misc.py b/autotranscript/misc.py index 1c14198..1eaf34f 100644 --- a/autotranscript/misc.py +++ b/autotranscript/misc.py @@ -1,4 +1,3 @@ - from pyannote.audio import Pipeline from whisper import Whisper, load_model import os @@ -6,15 +5,18 @@ import glob from warnings import warn import yaml -WHISPER_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__), - "models", "whisper")) +CACHE_DIR = os.getenv( + "AUTOT_CACHE", + os.path.expanduser("~/.cache/torch/models"), +) -PYANNOTE_DEFAULT_PATH = os.path.relpath(os.path.join(os.path.dirname(__file__), - "models", "pyannote", - "speaker_diarization", "config.yaml")) +WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") +PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -def config_diarization_yaml(file, path_to_segmentation = None, path_to_embedding = None): +PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") + +def config_diarization_yaml(file, path_to_segmentation = None): """ Configure diarization pipeline from yaml file to use the model offline and avoid manuel file manipulation. @@ -28,30 +30,25 @@ def config_diarization_yaml(file, path_to_segmentation = None, path_to_embedding if path_to_segmentation: yml["pipeline"]["params"]["segmentation"] = path_to_segmentation else: - yml["pipeline"]["params"]["segmentation"] = os.path.relpath(os.path.join( - os.path.dirname(__file__), - "models", "pyannote", - "segmentation", - "pytorch_model.bin")) + yml["pipeline"]["params"]["segmentation"] = os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") - if path_to_embedding: - yml["pipeline"]["params"]["embedding"] = path_to_embedding - else: - yml["pipeline"]["params"]["embedding"] = os.path.relpath( - os.path.join( - os.path.dirname(__file__), - "models", "pyannote", - "speechbrain", - "spkrec-ecapa-voxceleb", - "embedding_model.ckpt")) + # if path_to_embedding: + # yml["pipeline"]["params"]["embedding"] = path_to_embedding + # else: + # yml["pipeline"]["params"]["embedding"] = os.path.relpath( + # os.path.join( + # os.path.dirname(__file__), + # "models", "pyannote", + # "speechbrain", + # "spkrec-ecapa-voxceleb", + # "embedding_model.ckpt")) if not os.path.exists(yml["pipeline"]["params"]["segmentation"]): raise FileNotFoundError(f"Segmentation model not found at {yml['pipeline']['params']['segmentation']}") - if not os.path.exists(yml["pipeline"]["params"]["embedding"]): - raise FileNotFoundError(f"Embedding model not found at {yml['pipeline']['params']['embedding']}") + # if not os.path.exists(yml["pipeline"]["params"]["embedding"]): + # raise FileNotFoundError(f"Embedding model not found at {yml['pipeline']['params']['embedding']}") with open(file, "w") as stream: yaml.dump(yml, stream) - stream.close() - + stream.close()