added SCRAIBE_TORCH_DEVICE to transcriber class

This commit is contained in:
Schmieder, Jacob
2024-10-10 09:25:49 +00:00
parent af99a655e5
commit e7c1a5a2b0
+9 -9
View File
@@ -37,7 +37,7 @@ from inspect import signature
from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
whisper = TypeVar('whisper')
@@ -124,7 +124,7 @@ class Transcriber:
model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> None:
@@ -206,7 +206,7 @@ class WhisperTranscriber(Transcriber):
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> 'WhisperTranscriber':
@@ -305,7 +305,7 @@ class FasterWhisperTranscriber(Transcriber):
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> 'FasterWhisperModel':
"""
@@ -330,7 +330,7 @@ class FasterWhisperTranscriber(Transcriber):
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
@@ -339,10 +339,10 @@ class FasterWhisperTranscriber(Transcriber):
Returns:
Transcriber: A Transcriber object initialized with the specified model.
"""
if device is None:
device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str):
device = str(device)
compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with '
@@ -412,7 +412,7 @@ class FasterWhisperTranscriber(Transcriber):
def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
@@ -438,7 +438,7 @@ def load_transcriber(model: str = "medium",
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.