added SCRAIBE_TORCH_DEVICE to transcriber class
This commit is contained in:
@@ -37,7 +37,7 @@ from inspect import signature
|
||||
from abc import abstractmethod
|
||||
import warnings
|
||||
|
||||
from .misc import WHISPER_DEFAULT_PATH
|
||||
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
|
||||
whisper = TypeVar('whisper')
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ class Transcriber:
|
||||
model: str = "medium",
|
||||
whisper_type: str = 'whisper',
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> None:
|
||||
@@ -206,7 +206,7 @@ class WhisperTranscriber(Transcriber):
|
||||
def load_model(cls,
|
||||
model: str = "medium",
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> 'WhisperTranscriber':
|
||||
@@ -305,7 +305,7 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
def load_model(cls,
|
||||
model: str = "medium",
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
*args, **kwargs
|
||||
) -> 'FasterWhisperModel':
|
||||
"""
|
||||
@@ -330,7 +330,7 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
Defaults to WHISPER_DEFAULT_PATH.
|
||||
|
||||
device (Optional[Union[str, torch.device]], optional):
|
||||
Device to load model on. Defaults to None.
|
||||
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||
in_memory (bool, optional): Whether to load model in memory.
|
||||
Defaults to False.
|
||||
args: Additional arguments only to avoid errors.
|
||||
@@ -339,10 +339,10 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
Returns:
|
||||
Transcriber: A Transcriber object initialized with the specified model.
|
||||
"""
|
||||
if device is None:
|
||||
device = "cuda" if cuda_is_available() else "cpu"
|
||||
|
||||
if not isinstance(device, str):
|
||||
device = str(device)
|
||||
|
||||
compute_type = kwargs.get('compute_type', 'float16')
|
||||
if device == 'cpu' and compute_type == 'float16':
|
||||
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||
@@ -412,7 +412,7 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
def load_transcriber(model: str = "medium",
|
||||
whisper_type: str = 'whisper',
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||
@@ -438,7 +438,7 @@ def load_transcriber(model: str = "medium",
|
||||
download_root (str, optional): Path to download the model.
|
||||
Defaults to WHISPER_DEFAULT_PATH.
|
||||
device (Optional[Union[str, torch.device]], optional):
|
||||
Device to load model on. Defaults to None.
|
||||
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||
in_memory (bool, optional): Whether to load model in memory.
|
||||
Defaults to False.
|
||||
args: Additional arguments only to avoid errors.
|
||||
|
||||
Reference in New Issue
Block a user