Merge branch 'develop' into develop_gradio_app

This commit is contained in:
Jacob Schmieder
2024-02-12 12:42:56 +01:00
committed by GitHub
3 changed files with 41 additions and 5 deletions
+11
View File
@@ -36,6 +36,8 @@ from typing import TypeVar, Union
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available, current_device
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
@@ -186,6 +188,7 @@ class Diariser:
cache_token: bool = True,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None,
device: str = None,
*args, **kwargs
) -> Pipeline:
@@ -200,6 +203,7 @@ class Diariser:
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters.
device: Device to load the model on.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
@@ -207,6 +211,7 @@ class Diariser:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)
@@ -253,6 +258,12 @@ class Diariser:
cache_dir = cache_dir,
hparams_file = hparams_file,)
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
if _model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \