Merge pull request #71 from JSchmie/develop_hf_wrapper
Add default path to pyannote model with fallback option.
This commit is contained in:
+36
-18
@@ -38,6 +38,8 @@ from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch import device as torch_device
|
from torch import device as torch_device
|
||||||
from torch.cuda import is_available, current_device
|
from torch.cuda import is_available, current_device
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
@@ -185,7 +187,7 @@ class Diariser:
|
|||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||||
use_auth_token: str = None,
|
use_auth_token: str = None,
|
||||||
cache_token: bool = True,
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None,
|
hparams_file: Union[str, Path] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
@@ -194,11 +196,12 @@ class Diariser:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Loads a pretrained model from pyannote.audio,
|
Loads a pretrained model from pyannote.audio,
|
||||||
either from a local cache or online repository.
|
either from a local cache or some online repository.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Path or identifier for the pyannote model.
|
model: Path or identifier for the pyannote model.
|
||||||
default: /models/pyannote/speaker_diarization/config.yaml
|
default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
|
||||||
|
or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
|
||||||
use_auth_token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
use_auth_token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
||||||
cache_token: Whether to cache the token locally for future use.
|
cache_token: Whether to cache the token locally for future use.
|
||||||
cache_dir: Directory for caching models.
|
cache_dir: Directory for caching models.
|
||||||
@@ -210,15 +213,7 @@ class Diariser:
|
|||||||
Returns:
|
Returns:
|
||||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(model, str) and os.path.exists(model):
|
||||||
|
|
||||||
if cache_token and use_auth_token is not None:
|
|
||||||
cls._save_token(use_auth_token)
|
|
||||||
|
|
||||||
if not os.path.exists(model) and use_auth_token is None:
|
|
||||||
use_auth_token = cls._get_token()
|
|
||||||
|
|
||||||
elif os.path.exists(model) and not use_auth_token:
|
|
||||||
# check if model can be found locally nearby the config file
|
# check if model can be found locally nearby the config file
|
||||||
with open(model, 'r') as file:
|
with open(model, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
@@ -226,7 +221,7 @@ class Diariser:
|
|||||||
path_to_model = config['pipeline']['params']['segmentation']
|
path_to_model = config['pipeline']['params']['segmentation']
|
||||||
|
|
||||||
if not os.path.exists(path_to_model):
|
if not os.path.exists(path_to_model):
|
||||||
warnings.warn(f"Model not found at {path_to_model}. "\
|
warnings.warn(f"Model not found at {path_to_model}. "
|
||||||
"Trying to find it nearby the config file.")
|
"Trying to find it nearby the config file.")
|
||||||
|
|
||||||
pwd = model.split("/")[:-1]
|
pwd = model.split("/")[:-1]
|
||||||
@@ -237,6 +232,10 @@ class Diariser:
|
|||||||
if not os.path.exists(path_to_model):
|
if not os.path.exists(path_to_model):
|
||||||
warnings.warn(f"Model not found at {path_to_model}. \
|
warnings.warn(f"Model not found at {path_to_model}. \
|
||||||
'Trying to find it nearby .bin files instead.")
|
'Trying to find it nearby .bin files instead.")
|
||||||
|
warnings.warn(
|
||||||
|
'Searching for nearby files in a folder path is '
|
||||||
|
'deprecated and will be removed in future versions.',
|
||||||
|
category=DeprecationWarning)
|
||||||
# list elements with the ending .bin
|
# list elements with the ending .bin
|
||||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
||||||
if len(bin_files) == 1:
|
if len(bin_files) == 1:
|
||||||
@@ -245,6 +244,7 @@ class Diariser:
|
|||||||
warnings.warn("Found more than one .bin file. "\
|
warnings.warn("Found more than one .bin file. "\
|
||||||
"or none. Please specify the path to the model " \
|
"or none. Please specify the path to the model " \
|
||||||
"or setup a huggingface token.")
|
"or setup a huggingface token.")
|
||||||
|
raise FileNotFoundError
|
||||||
|
|
||||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
||||||
|
|
||||||
@@ -252,11 +252,34 @@ class Diariser:
|
|||||||
|
|
||||||
with open(model, 'w') as file:
|
with open(model, 'w') as file:
|
||||||
yaml.dump(config, file)
|
yaml.dump(config, file)
|
||||||
|
elif isinstance(model, tuple):
|
||||||
|
try:
|
||||||
|
_model = model[0]
|
||||||
|
HfApi().model_info(_model)
|
||||||
|
model = _model
|
||||||
|
use_auth_token = None
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
print(f'{model[0]} not found on Huggingface, \
|
||||||
|
trying {model[1]}')
|
||||||
|
_model = model[1]
|
||||||
|
HfApi().model_info(_model)
|
||||||
|
model = _model
|
||||||
|
if cache_token and use_auth_token is not None:
|
||||||
|
cls._save_token(use_auth_token)
|
||||||
|
|
||||||
|
if use_auth_token is None:
|
||||||
|
use_auth_token = cls._get_token()
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
hparams_file=hparams_file,)
|
hparams_file=hparams_file,)
|
||||||
|
if _model is None:
|
||||||
|
raise ValueError('Unable to load model either from local cache' \
|
||||||
|
'or from huggingface.co models. Please check your token' \
|
||||||
|
'or your local model path')
|
||||||
|
|
||||||
# try to move the model to the device
|
# try to move the model to the device
|
||||||
if device is None:
|
if device is None:
|
||||||
@@ -264,11 +287,6 @@ class Diariser:
|
|||||||
|
|
||||||
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
||||||
|
|
||||||
if _model is None:
|
|
||||||
raise ValueError('Unable to load model either from local cache' \
|
|
||||||
'or from huggingface.co models. Please check your token' \
|
|
||||||
'or your local model path')
|
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
+1
-1
@@ -15,7 +15,7 @@ WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
|||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
else 'pyannote/speaker-diarization-3.1'
|
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
"""Configure diarization pipeline from a YAML file.
|
"""Configure diarization pipeline from a YAML file.
|
||||||
|
|||||||
Reference in New Issue
Block a user