added SCRAIBE_TORCH_DEVICE to transcriber class
This commit is contained in:
@@ -37,7 +37,7 @@ from inspect import signature
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .misc import WHISPER_DEFAULT_PATH
|
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
|
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ class Transcriber:
|
|||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
whisper_type: str = 'whisper',
|
whisper_type: str = 'whisper',
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -206,7 +206,7 @@ class WhisperTranscriber(Transcriber):
|
|||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> 'WhisperTranscriber':
|
) -> 'WhisperTranscriber':
|
||||||
@@ -305,7 +305,7 @@ class FasterWhisperTranscriber(Transcriber):
|
|||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> 'FasterWhisperModel':
|
) -> 'FasterWhisperModel':
|
||||||
"""
|
"""
|
||||||
@@ -330,7 +330,7 @@ class FasterWhisperTranscriber(Transcriber):
|
|||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
|
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
Device to load model on. Defaults to None.
|
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||||
in_memory (bool, optional): Whether to load model in memory.
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
args: Additional arguments only to avoid errors.
|
args: Additional arguments only to avoid errors.
|
||||||
@@ -339,10 +339,10 @@ class FasterWhisperTranscriber(Transcriber):
|
|||||||
Returns:
|
Returns:
|
||||||
Transcriber: A Transcriber object initialized with the specified model.
|
Transcriber: A Transcriber object initialized with the specified model.
|
||||||
"""
|
"""
|
||||||
if device is None:
|
|
||||||
device = "cuda" if cuda_is_available() else "cpu"
|
|
||||||
if not isinstance(device, str):
|
if not isinstance(device, str):
|
||||||
device = str(device)
|
device = str(device)
|
||||||
|
|
||||||
compute_type = kwargs.get('compute_type', 'float16')
|
compute_type = kwargs.get('compute_type', 'float16')
|
||||||
if device == 'cpu' and compute_type == 'float16':
|
if device == 'cpu' and compute_type == 'float16':
|
||||||
warnings.warn(f'Compute type {compute_type} not compatible with '
|
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||||
@@ -412,7 +412,7 @@ class FasterWhisperTranscriber(Transcriber):
|
|||||||
def load_transcriber(model: str = "medium",
|
def load_transcriber(model: str = "medium",
|
||||||
whisper_type: str = 'whisper',
|
whisper_type: str = 'whisper',
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||||
@@ -438,7 +438,7 @@ def load_transcriber(model: str = "medium",
|
|||||||
download_root (str, optional): Path to download the model.
|
download_root (str, optional): Path to download the model.
|
||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
Device to load model on. Defaults to None.
|
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||||
in_memory (bool, optional): Whether to load model in memory.
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
args: Additional arguments only to avoid errors.
|
args: Additional arguments only to avoid errors.
|
||||||
|
|||||||
Reference in New Issue
Block a user