added deprecated warning

This commit is contained in:
Jaikinator
2023-06-12 11:50:20 +02:00
parent 7aa2ed667f
commit ca42d631cd
+17 -19
View File
@@ -3,20 +3,14 @@ from pyannote.audio import Pipeline
from whisper import Whisper, load_model from whisper import Whisper, load_model
import os import os
import glob import glob
from warnings import warn
def get_whisper_default_path() -> str: WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)),
""" "models", "whisper")
Get default path for whisper models
Returns PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)),
------- "models", "pyannote",
str "speaker_diarization", "config.yaml")
path
"""
_path = os.path.dirname(os.path.dirname(__file__))
return os.path.join(_path, "models", "whisper")
WHISPER_DEFAULT_PATH = get_whisper_default_path()
def load_whisper_model(model: str ="medium", local : bool = False, download_root: str = WHISPER_DEFAULT_PATH) -> Whisper: def load_whisper_model(model: str ="medium", local : bool = False, download_root: str = WHISPER_DEFAULT_PATH) -> Whisper:
""" """
@@ -52,9 +46,9 @@ def load_whisper_model(model: str ="medium", local : bool = False, download_root
------- -------
Whisper Object Whisper Object
""" """
warn("load_whisper_model is deprecated. Use Transcriptor.load_model() instead.", DeprecationWarning)
if local: if local:
available_models = [os.path.basename(x) for x in glob.glob(os.path.join(WHISPER_DEFAULT_PATH, "*"))] available_models = [os.path.basename(x) for x in glob.glob(os.path.join(download_root, "*"))]
for i, module in enumerate(available_models): for i, module in enumerate(available_models):
available_models[i] = module.split(".")[0] available_models[i] = module.split(".")[0]
@@ -62,9 +56,12 @@ def load_whisper_model(model: str ="medium", local : bool = False, download_root
if model not in available_models: if model not in available_models:
raise RuntimeError("Model not found. Consider downloading the model first. By deactivating the local flag, the model will be downloaded automatically.") raise RuntimeError("Model not found. Consider downloading the model first. By deactivating the local flag, the model will be downloaded automatically.")
return load_model(model, download_root=WHISPER_DEFAULT_PATH) return load_model(model, download_root=download_root)
def load_pyannote_model(model: str, token: str = "", local : bool = True) -> Pipeline: def load_pyannote_model(model: str = PYANNOTE_DEFAULT_PATH,
token: str = "",
local : bool = True,
*args, **kwargs) -> Pipeline:
""" """
Load modules from pyannote Load modules from pyannote
@@ -72,6 +69,7 @@ def load_pyannote_model(model: str, token: str = "", local : bool = True) -> Pip
---------- ----------
model : str model : str
pyannote model pyannote model
default: /models/pyannote/speaker_diarization/config.yaml
token : str token : str
HUGGINGFACE_TOKEN HUGGINGFACE_TOKEN
local : bool local : bool
@@ -81,8 +79,8 @@ def load_pyannote_model(model: str, token: str = "", local : bool = True) -> Pip
------- -------
Pipeline Object Pipeline Object
""" """
warn("load_pyannote_model is deprecated. Use Diarisation.load_model() instead.", DeprecationWarning)
if local: if local:
return Pipeline.from_pretrained(model) return Pipeline.from_pretrained(model,*args, **kwargs)
else: else:
return Pipeline.from_pretrained(model, use_auth_token = token) return Pipeline.from_pretrained(model, use_auth_token = token, *args, **kwargs)