From 9c0766fc41a3ede97fcc580a817db16ac7779f84 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 9 Feb 2024 11:35:38 +0100 Subject: [PATCH 1/2] updated dependencies now scraibe works with torch 2 --- scraibe/diarisation.py | 16 ++++++++++++---- scraibe/misc.py | 4 +++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index f90bcdb..1a33817 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -34,6 +34,8 @@ from typing import TypeVar, Union from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor +from torch import device as torch_device +from torch.cuda import is_available, current_device from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG Annotation = TypeVar('Annotation') @@ -184,6 +186,7 @@ class Diariser: cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, + device: str = None, *args, **kwargs ) -> Pipeline: @@ -198,6 +201,7 @@ class Diariser: cache_token: Whether to cache the token locally for future use. cache_dir: Directory for caching models. hparams_file: Path to a YAML file containing hyperparameters. + device: Device to load the model on. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. @@ -205,20 +209,24 @@ class Diariser: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ + if cache_token and use_auth_token is not None: cls._save_token(use_auth_token) if not os.path.exists(model) and use_auth_token is None: use_auth_token = cls._get_token() - model = 'pyannote/speaker-diarization' - elif not os.path.exists(model) and use_auth_token is not None: - model = 'pyannote/speaker-diarization' - + _model = Pipeline.from_pretrained(model, use_auth_token = use_auth_token, cache_dir = cache_dir, hparams_file = hparams_file,) + # try to move the model to the device + if device is None: + device = "cuda" if is_available() else "cpu" + + _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + if _model is None: raise ValueError('Unable to load model either from local cache' \ 'or from huggingface.co models. Please check your token' \ diff --git a/scraibe/misc.py b/scraibe/misc.py index b1afeea..c912478 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -12,7 +12,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR: WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") +PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ + if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ + else 'pyannote/speaker-diarization-3.1' def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. From df79a78a47bc1aabf3f92f9df9703f9b4261d212 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 9 Feb 2024 12:17:43 +0100 Subject: [PATCH 2/2] updated dependency list --- requirements.txt | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index aed43e8..8cf1782 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ -openai-whisper==20230314 +torch~=2.2.0 + +openai-whisper~=20231117 numpy~=1.23.5 -pyannote.audio~=2.1.1 -pyannote.core~=4.5 -pyannote.database~=4.1.3 +pyannote.audio~=3.1.1 +pyannote.core~=5.0.0 +pyannote.database~=5.0.1 pyannote.metrics~=3.2.1 -pyannote.pipeline~=2.3 +pyannote.pipeline~=3.0.1 setuptools~=65.6.3 setuptools-rust~=1.5.2