Adjusted compute type for tests without gpu.
This commit is contained in:
@@ -32,8 +32,9 @@ from typing import TypeVar, Union, Optional
|
||||
from torch import Tensor, device
|
||||
from torch.cuda import is_available as cuda_is_available
|
||||
from numpy import ndarray
|
||||
from inspect import getfullargspec
|
||||
from inspect import signature
|
||||
from abc import abstractmethod
|
||||
import warnings
|
||||
|
||||
from .misc import WHISPER_DEFAULT_PATH
|
||||
whisper = TypeVar('whisper')
|
||||
@@ -254,9 +255,7 @@ class WhisperTranscriber(Transcriber):
|
||||
dict: Keyword arguments for whisper model.
|
||||
"""
|
||||
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||
_args = getfullargspec(Whisper.transcribe).args
|
||||
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
|
||||
_possible_kwargs = _args + _kwargs
|
||||
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
|
||||
|
||||
whisper_kwargs = {k: v for k,
|
||||
v in kwargs.items() if k in _possible_kwargs}
|
||||
@@ -343,6 +342,11 @@ class WhisperXTranscriber(Transcriber):
|
||||
device = "cuda" if cuda_is_available() else "cpu"
|
||||
if not isinstance(device, str):
|
||||
device = str(device)
|
||||
compute_type = kwargs.get('compute_type', 'float16')
|
||||
if device == 'cpu' and compute_type == 'float16':
|
||||
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||
f'device {device}! Changing compute type to int8.')
|
||||
compute_type = 'int8'
|
||||
_model = whisperx_load_model(model, download_root=download_root,
|
||||
device=device)
|
||||
|
||||
@@ -357,9 +361,7 @@ class WhisperXTranscriber(Transcriber):
|
||||
dict: Keyword arguments for whisper model.
|
||||
"""
|
||||
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||
_args = getfullargspec(WhisperModel.transcribe).args
|
||||
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
|
||||
_possible_kwargs = _args + _kwargs
|
||||
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
|
||||
|
||||
whisper_kwargs = {k: v for k,
|
||||
v in kwargs.items() if k in _possible_kwargs}
|
||||
|
||||
Reference in New Issue
Block a user