From e7c1a5a2b01263acddb80c781d1c26292fc6210a Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:25:49 +0000 Subject: [PATCH] added SCRAIBE_TORCH_DEVICE to transcriber class --- scraibe/transcriber.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index abf1ace..9c891f6 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -37,7 +37,7 @@ from inspect import signature from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE whisper = TypeVar('whisper') @@ -124,7 +124,7 @@ class Transcriber: model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> None: @@ -206,7 +206,7 @@ class WhisperTranscriber(Transcriber): def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> 'WhisperTranscriber': @@ -305,7 +305,7 @@ class FasterWhisperTranscriber(Transcriber): def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, *args, **kwargs ) -> 'FasterWhisperModel': """ @@ -330,7 +330,7 @@ class FasterWhisperTranscriber(Transcriber): Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. @@ -339,10 +339,10 @@ class FasterWhisperTranscriber(Transcriber): Returns: Transcriber: A Transcriber object initialized with the specified model. """ - if device is None: - device = "cuda" if cuda_is_available() else "cpu" + if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') if device == 'cpu' and compute_type == 'float16': warnings.warn(f'Compute type {compute_type} not compatible with ' @@ -412,7 +412,7 @@ class FasterWhisperTranscriber(Transcriber): def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: @@ -438,7 +438,7 @@ def load_transcriber(model: str = "medium", download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors.