Merge pull request #45 from JSchmie/update-ml-deps

Updated dependencies; scraibe now works with torch 2.
This commit is contained in:
Jacob Schmieder
2024-02-09 12:24:50 +01:00
committed by GitHub
3 changed files with 22 additions and 10 deletions
+7 -5
View File
@@ -1,10 +1,12 @@
openai-whisper==20230314 torch~=2.2.0
openai-whisper~=20231117
numpy~=1.23.5 numpy~=1.23.5
pyannote.audio~=2.1.1 pyannote.audio~=3.1.1
pyannote.core~=4.5 pyannote.core~=5.0.0
pyannote.database~=4.1.3 pyannote.database~=5.0.1
pyannote.metrics~=3.2.1 pyannote.metrics~=3.2.1
pyannote.pipeline~=2.3 pyannote.pipeline~=3.0.1
setuptools~=65.6.3 setuptools~=65.6.3
setuptools-rust~=1.5.2 setuptools-rust~=1.5.2
+11 -3
View File
@@ -34,6 +34,8 @@ from typing import TypeVar, Union
from pyannote.audio import Pipeline from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available, current_device
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
@@ -184,6 +186,7 @@ class Diariser:
cache_token: bool = True, cache_token: bool = True,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None, hparams_file: Union[str, Path] = None,
device: str = None,
*args, **kwargs *args, **kwargs
) -> Pipeline: ) -> Pipeline:
@@ -198,6 +201,7 @@ class Diariser:
cache_token: Whether to cache the token locally for future use. cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models. cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters. hparams_file: Path to a YAML file containing hyperparameters.
device: Device to load the model on.
args: Additional arguments only to avoid errors. args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors.
@@ -205,20 +209,24 @@ class Diariser:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
""" """
if cache_token and use_auth_token is not None: if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token) cls._save_token(use_auth_token)
if not os.path.exists(model) and use_auth_token is None: if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token() use_auth_token = cls._get_token()
model = 'pyannote/speaker-diarization'
elif not os.path.exists(model) and use_auth_token is not None:
model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token, use_auth_token = use_auth_token,
cache_dir = cache_dir, cache_dir = cache_dir,
hparams_file = hparams_file,) hparams_file = hparams_file,)
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
if _model is None: if _model is None:
raise ValueError('Unable to load model either from local cache' \ raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \ 'or from huggingface.co models. Please check your token' \
+3 -1
View File
@@ -12,7 +12,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else 'pyannote/speaker-diarization-3.1'
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.