Updated dependencies; scraibe now works with torch 2
This commit is contained in:
+11
-3
@@ -34,6 +34,8 @@ from typing import TypeVar, Union
|
|||||||
from pyannote.audio import Pipeline
|
from pyannote.audio import Pipeline
|
||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch import device as torch_device
|
||||||
|
from torch.cuda import is_available, current_device
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
@@ -184,6 +186,7 @@ class Diariser:
|
|||||||
cache_token: bool = True,
|
cache_token: bool = True,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None,
|
hparams_file: Union[str, Path] = None,
|
||||||
|
device: str = None,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
|
|
||||||
@@ -198,6 +201,7 @@ class Diariser:
|
|||||||
cache_token: Whether to cache the token locally for future use.
|
cache_token: Whether to cache the token locally for future use.
|
||||||
cache_dir: Directory for caching models.
|
cache_dir: Directory for caching models.
|
||||||
hparams_file: Path to a YAML file containing hyperparameters.
|
hparams_file: Path to a YAML file containing hyperparameters.
|
||||||
|
device: Device to load the model on.
|
||||||
args: Additional arguments only to avoid errors.
|
args: Additional arguments only to avoid errors.
|
||||||
kwargs: Additional keyword arguments only to avoid errors.
|
kwargs: Additional keyword arguments only to avoid errors.
|
||||||
|
|
||||||
@@ -205,20 +209,24 @@ class Diariser:
|
|||||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if cache_token and use_auth_token is not None:
|
if cache_token and use_auth_token is not None:
|
||||||
cls._save_token(use_auth_token)
|
cls._save_token(use_auth_token)
|
||||||
|
|
||||||
if not os.path.exists(model) and use_auth_token is None:
|
if not os.path.exists(model) and use_auth_token is None:
|
||||||
use_auth_token = cls._get_token()
|
use_auth_token = cls._get_token()
|
||||||
model = 'pyannote/speaker-diarization'
|
|
||||||
elif not os.path.exists(model) and use_auth_token is not None:
|
|
||||||
model = 'pyannote/speaker-diarization'
|
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token = use_auth_token,
|
use_auth_token = use_auth_token,
|
||||||
cache_dir = cache_dir,
|
cache_dir = cache_dir,
|
||||||
hparams_file = hparams_file,)
|
hparams_file = hparams_file,)
|
||||||
|
|
||||||
|
# Move the model to the requested device (default: CUDA if available, otherwise CPU)
|
||||||
|
if device is None:
|
||||||
|
device = "cuda" if is_available() else "cpu"
|
||||||
|
|
||||||
|
_model = _model.to(torch_device(device)) # torch.device is imported as torch_device to avoid clashing with the `device` parameter
|
||||||
|
|
||||||
if _model is None:
|
if _model is None:
|
||||||
raise ValueError('Unable to load model either from local cache' \
|
raise ValueError('Unable to load model either from local cache' \
|
||||||
'or from huggingface.co models. Please check your token' \
|
'or from huggingface.co models. Please check your token' \
|
||||||
|
|||||||
+3
-1
@@ -12,7 +12,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
|||||||
|
|
||||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
|
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
||||||
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
|
else 'pyannote/speaker-diarization-3.1'
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
"""Configure diarization pipeline from a YAML file.
|
"""Configure diarization pipeline from a YAML file.
|
||||||
|
|||||||
Reference in New Issue
Block a user