diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index d70df99..6e6d6b9 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -41,7 +41,7 @@ from torch.cuda import is_available from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -190,8 +190,7 @@ class Diariser: cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs + device: str = SCRAIBE_TORCH_DEVICE, ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, @@ -283,10 +282,6 @@ class Diariser: 'or from huggingface.co models. Please check your token' 'or your local model path') - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - # torch_device is renamed from torch.device to avoid name conflict _model = _model.to(torch_device(device))