added SCRAIBE_TORCH_DEVICE to Diariser class
This commit is contained in:
@@ -41,7 +41,7 @@ from torch.cuda import is_available
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
@@ -190,8 +190,7 @@ class Diariser:
|
||||
cache_token: bool = False,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None,
|
||||
device: str = None,
|
||||
*args, **kwargs
|
||||
device: str = SCRAIBE_TORCH_DEVICE,
|
||||
) -> Pipeline:
|
||||
"""
|
||||
Loads a pretrained model from pyannote.audio,
|
||||
@@ -283,10 +282,6 @@ class Diariser:
|
||||
'or from huggingface.co models. Please check your token'
|
||||
'or your local model path')
|
||||
|
||||
# try to move the model to the device
|
||||
if device is None:
|
||||
device = "cuda" if is_available() else "cpu"
|
||||
|
||||
# torch_device is renamed from torch.device to avoid name conflict
|
||||
_model = _model.to(torch_device(device))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user