Adjusted compute type for tests without gpu.

This commit is contained in:
Marko Henning
2024-05-17 11:23:49 +02:00
parent 07b9939446
commit b6bed3ebd8
+9 -7
View File
@@ -32,8 +32,9 @@ from typing import TypeVar, Union, Optional
from torch import Tensor, device from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available from torch.cuda import is_available as cuda_is_available
from numpy import ndarray from numpy import ndarray
from inspect import getfullargspec from inspect import signature
from abc import abstractmethod from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
@@ -254,9 +255,7 @@ class WhisperTranscriber(Transcriber):
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(Whisper.transcribe).args _possible_kwargs = signature(Whisper.transcribe).parameters.keys()
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs} v in kwargs.items() if k in _possible_kwargs}
@@ -343,6 +342,11 @@ class WhisperXTranscriber(Transcriber):
device = "cuda" if cuda_is_available() else "cpu" device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str): if not isinstance(device, str):
device = str(device) device = str(device)
compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with '
f'device {device}! Changing compute type to int8.')
compute_type = 'int8'
_model = whisperx_load_model(model, download_root=download_root, _model = whisperx_load_model(model, download_root=download_root,
device=device) device=device)
@@ -357,9 +361,7 @@ class WhisperXTranscriber(Transcriber):
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(WhisperModel.transcribe).args _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs} v in kwargs.items() if k in _possible_kwargs}