From ca42d631cdeefc9cef1b37c9de02be9af31230a5 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 12 Jun 2023 11:50:20 +0200 Subject: [PATCH] added deprecated warning --- autotranscript/misc.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/autotranscript/misc.py b/autotranscript/misc.py index 91008fd..065e45d 100644 --- a/autotranscript/misc.py +++ b/autotranscript/misc.py @@ -3,20 +3,14 @@ from pyannote.audio import Pipeline from whisper import Whisper, load_model import os import glob +from warnings import warn -def get_whisper_default_path() -> str: - """ - Get default path for whisper models +WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), + "models", "whisper") - Returns - ------- - str - path - """ - _path = os.path.dirname(os.path.dirname(__file__)) - return os.path.join(_path, "models", "whisper") - -WHISPER_DEFAULT_PATH = get_whisper_default_path() +PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), + "models", "pyannote", + "speaker_diarization", "config.yaml") def load_whisper_model(model: str ="medium", local : bool = False, download_root: str = WHISPER_DEFAULT_PATH) -> Whisper: """ @@ -52,9 +46,9 @@ def load_whisper_model(model: str ="medium", local : bool = False, download_root ------- Whisper Object """ - + warn("load_whisper_model is deprecated. Use Transcriptor.load_model() instead.", DeprecationWarning) if local: - available_models = [os.path.basename(x) for x in glob.glob(os.path.join(WHISPER_DEFAULT_PATH, "*"))] + available_models = [os.path.basename(x) for x in glob.glob(os.path.join(download_root, "*"))] for i, module in enumerate(available_models): available_models[i] = module.split(".")[0] @@ -62,9 +56,12 @@ def load_whisper_model(model: str ="medium", local : bool = False, download_root if model not in available_models: raise RuntimeError("Model not found. Consider downloading the model first. By deactivating the local flag, the model will be downloaded automatically.") - return load_model(model, download_root=WHISPER_DEFAULT_PATH) + return load_model(model, download_root=download_root) -def load_pyannote_model(model: str, token: str = "", local : bool = True) -> Pipeline: +def load_pyannote_model(model: str = PYANNOTE_DEFAULT_PATH, + token: str = "", + local : bool = True, + *args, **kwargs) -> Pipeline: """ Load modules from pyannote @@ -72,6 +69,7 @@ def load_pyannote_model(model: str, token: str = "", local : bool = True) -> Pip ---------- model : str pyannote model + default: /models/pyannote/speaker_diarization/config.yaml token : str HUGGINGFACE_TOKEN local : bool @@ -81,8 +79,8 @@ def load_pyannote_model(model: str, token: str = "", local : bool = True) -> Pip ------- Pipeline Object """ - + warn("load_pyannote_model is deprecated. Use Diarisation.load_model() instead.", DeprecationWarning) if local: - return Pipeline.from_pretrained(model) + return Pipeline.from_pretrained(model,*args, **kwargs) else: - return Pipeline.from_pretrained(model, use_auth_token = token) + return Pipeline.from_pretrained(model, use_auth_token = token, *args, **kwargs)