added SCRAIBE_TORCH_DEVICE to Diariser class

This commit is contained in:
Schmieder, Jacob
2024-10-10 09:22:34 +00:00
parent 44ff678e06
commit af99a655e5
+2 -7
View File
@@ -41,7 +41,7 @@ from torch.cuda import is_available
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
@@ -190,8 +190,7 @@ class Diariser:
cache_token: bool = False, cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None, hparams_file: Union[str, Path] = None,
device: str = None, device: str = SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> Pipeline: ) -> Pipeline:
""" """
Loads a pretrained model from pyannote.audio, Loads a pretrained model from pyannote.audio,
@@ -283,10 +282,6 @@ class Diariser:
'or from huggingface.co models. Please check your token' 'or from huggingface.co models. Please check your token'
'or your local model path') 'or your local model path')
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
# torch_device is renamed from torch.device to avoid name conflict # torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device)) _model = _model.to(torch_device(device))