Adjusted compute type for tests without gpu.

This commit is contained in:
Marko Henning
2024-05-17 11:23:49 +02:00
parent 07b9939446
commit b6bed3ebd8
+9 -7
View File
@@ -32,8 +32,9 @@ from typing import TypeVar, Union, Optional
from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray
from inspect import getfullargspec
from inspect import signature
from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper')
@@ -254,9 +255,7 @@ class WhisperTranscriber(Transcriber):
dict: Keyword arguments for whisper model.
"""
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(Whisper.transcribe).args
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
@@ -343,6 +342,11 @@ class WhisperXTranscriber(Transcriber):
device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str):
device = str(device)
compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with '
f'device {device}! Changing compute type to int8.')
compute_type = 'int8'
_model = whisperx_load_model(model, download_root=download_root,
device=device)
@@ -357,9 +361,7 @@ class WhisperXTranscriber(Transcriber):
dict: Keyword arguments for whisper model.
"""
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(WhisperModel.transcribe).args
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}