Adjusted compute type for tests without gpu.
This commit is contained in:
@@ -32,8 +32,9 @@ from typing import TypeVar, Union, Optional
|
|||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
from torch.cuda import is_available as cuda_is_available
|
from torch.cuda import is_available as cuda_is_available
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
from inspect import getfullargspec
|
from inspect import signature
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
import warnings
|
||||||
|
|
||||||
from .misc import WHISPER_DEFAULT_PATH
|
from .misc import WHISPER_DEFAULT_PATH
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
@@ -254,9 +255,7 @@ class WhisperTranscriber(Transcriber):
|
|||||||
dict: Keyword arguments for whisper model.
|
dict: Keyword arguments for whisper model.
|
||||||
"""
|
"""
|
||||||
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||||
_args = getfullargspec(Whisper.transcribe).args
|
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
|
||||||
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
|
|
||||||
_possible_kwargs = _args + _kwargs
|
|
||||||
|
|
||||||
whisper_kwargs = {k: v for k,
|
whisper_kwargs = {k: v for k,
|
||||||
v in kwargs.items() if k in _possible_kwargs}
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
@@ -343,6 +342,11 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
device = "cuda" if cuda_is_available() else "cpu"
|
device = "cuda" if cuda_is_available() else "cpu"
|
||||||
if not isinstance(device, str):
|
if not isinstance(device, str):
|
||||||
device = str(device)
|
device = str(device)
|
||||||
|
compute_type = kwargs.get('compute_type', 'float16')
|
||||||
|
if device == 'cpu' and compute_type == 'float16':
|
||||||
|
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||||
|
f'device {device}! Changing compute type to int8.')
|
||||||
|
compute_type = 'int8'
|
||||||
_model = whisperx_load_model(model, download_root=download_root,
|
_model = whisperx_load_model(model, download_root=download_root,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
@@ -357,9 +361,7 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
dict: Keyword arguments for whisper model.
|
dict: Keyword arguments for whisper model.
|
||||||
"""
|
"""
|
||||||
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||||
_args = getfullargspec(WhisperModel.transcribe).args
|
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
|
||||||
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
|
|
||||||
_possible_kwargs = _args + _kwargs
|
|
||||||
|
|
||||||
whisper_kwargs = {k: v for k,
|
whisper_kwargs = {k: v for k,
|
||||||
v in kwargs.items() if k in _possible_kwargs}
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|||||||
Reference in New Issue
Block a user