added SCRAIBE_TORCH_DEVICE to Diariser class
This commit is contained in:
@@ -41,7 +41,7 @@ from torch.cuda import is_available
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
@@ -190,8 +190,7 @@ class Diariser:
|
|||||||
cache_token: bool = False,
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None,
|
hparams_file: Union[str, Path] = None,
|
||||||
device: str = None,
|
device: str = SCRAIBE_TORCH_DEVICE,
|
||||||
*args, **kwargs
|
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
Loads a pretrained model from pyannote.audio,
|
Loads a pretrained model from pyannote.audio,
|
||||||
@@ -283,10 +282,6 @@ class Diariser:
|
|||||||
'or from huggingface.co models. Please check your token'
|
'or from huggingface.co models. Please check your token'
|
||||||
'or your local model path')
|
'or your local model path')
|
||||||
|
|
||||||
# try to move the model to the device
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if is_available() else "cpu"
|
|
||||||
|
|
||||||
# torch_device is renamed from torch.device to avoid name conflict
|
# torch_device is renamed from torch.device to avoid name conflict
|
||||||
_model = _model.to(torch_device(device))
|
_model = _model.to(torch_device(device))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user