From 9c0766fc41a3ede97fcc580a817db16ac7779f84 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 9 Feb 2024 11:35:38 +0100 Subject: [PATCH] updated dependencies now scraibe works with torch 2 --- scraibe/diarisation.py | 16 ++++++++++++---- scraibe/misc.py | 4 +++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index f90bcdb..1a33817 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -34,6 +34,8 @@ from typing import TypeVar, Union from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor +from torch import device as torch_device +from torch.cuda import is_available, current_device from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG Annotation = TypeVar('Annotation') @@ -184,6 +186,7 @@ class Diariser: cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, + device: str = None, *args, **kwargs ) -> Pipeline: @@ -198,6 +201,7 @@ class Diariser: cache_token: Whether to cache the token locally for future use. cache_dir: Directory for caching models. hparams_file: Path to a YAML file containing hyperparameters. + device: Device to load the model on. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. @@ -205,20 +209,24 @@ class Diariser: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ + if cache_token and use_auth_token is not None: cls._save_token(use_auth_token) if not os.path.exists(model) and use_auth_token is None: use_auth_token = cls._get_token() - model = 'pyannote/speaker-diarization' - elif not os.path.exists(model) and use_auth_token is not None: - model = 'pyannote/speaker-diarization' - + _model = Pipeline.from_pretrained(model, use_auth_token = use_auth_token, cache_dir = cache_dir, hparams_file = hparams_file,) + # try to move the model to the device + if device is None: + device = "cuda" if is_available() else "cpu" + + _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + if _model is None: raise ValueError('Unable to load model either from local cache' \ 'or from huggingface.co models. Please check your token' \ diff --git a/scraibe/misc.py b/scraibe/misc.py index b1afeea..c912478 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -12,7 +12,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR: WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") +PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ + if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ + else 'pyannote/speaker-diarization-3.1' def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file.