Fixed cache default value, moved ValuError t othe right place, added to docstring.
This commit is contained in:
+13
-13
@@ -187,7 +187,7 @@ class Diariser:
|
|||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||||
use_auth_token: str = None,
|
use_auth_token: str = None,
|
||||||
cache_token: bool = True,
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None,
|
hparams_file: Union[str, Path] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
@@ -196,11 +196,12 @@ class Diariser:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Loads a pretrained model from pyannote.audio,
|
Loads a pretrained model from pyannote.audio,
|
||||||
either from a local cache or online repository.
|
either from a local cache or some online repository.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Path or identifier for the pyannote model.
|
model: Path or identifier for the pyannote model.
|
||||||
default: /models/pyannote/speaker_diarization/config.yaml
|
default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
|
||||||
|
or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
|
||||||
token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
||||||
cache_token: Whether to cache the token locally for future use.
|
cache_token: Whether to cache the token locally for future use.
|
||||||
cache_dir: Directory for caching models.
|
cache_dir: Directory for caching models.
|
||||||
@@ -261,8 +262,8 @@ class Diariser:
|
|||||||
model = _model
|
model = _model
|
||||||
if cache_token and use_auth_token is not None:
|
if cache_token and use_auth_token is not None:
|
||||||
cls._save_token(use_auth_token)
|
cls._save_token(use_auth_token)
|
||||||
|
|
||||||
if not os.path.exists(model) and use_auth_token is None:
|
if use_auth_token is None:
|
||||||
use_auth_token = cls._get_token()
|
use_auth_token = cls._get_token()
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
||||||
@@ -271,18 +272,17 @@ class Diariser:
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
hparams_file=hparams_file,)
|
hparams_file=hparams_file,)
|
||||||
|
|
||||||
# try to move the model to the device
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if is_available() else "cpu"
|
|
||||||
|
|
||||||
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
|
||||||
|
|
||||||
if _model is None:
|
if _model is None:
|
||||||
raise ValueError('Unable to load model either from local cache' \
|
raise ValueError('Unable to load model either from local cache' \
|
||||||
'or from huggingface.co models. Please check your token' \
|
'or from huggingface.co models. Please check your token' \
|
||||||
'or your local model path')
|
'or your local model path')
|
||||||
|
|
||||||
|
# try to move the model to the device
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if is_available() else "cpu"
|
||||||
|
|
||||||
|
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user