diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index c9b9b52..526ff8b 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -32,8 +32,9 @@ from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available from numpy import ndarray -from inspect import getfullargspec +from inspect import signature from abc import abstractmethod +import warnings from .misc import WHISPER_DEFAULT_PATH whisper = TypeVar('whisper') @@ -254,9 +255,7 @@ class WhisperTranscriber(Transcriber): dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _args = getfullargspec(Whisper.transcribe).args - _kwargs = getfullargspec(Whisper.transcribe).kwonlyargs - _possible_kwargs = _args + _kwargs + _possible_kwargs = signature(Whisper.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} @@ -343,6 +342,11 @@ class WhisperXTranscriber(Transcriber): device = "cuda" if cuda_is_available() else "cpu" if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') + if device == 'cpu' and compute_type == 'float16': + warnings.warn(f'Compute type {compute_type} not compatible with ' + f'device {device}! Changing compute type to int8.') + compute_type = 'int8' _model = whisperx_load_model(model, download_root=download_root, device=device) @@ -357,9 +361,7 @@ class WhisperXTranscriber(Transcriber): dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _args = getfullargspec(WhisperModel.transcribe).args - _kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs - _possible_kwargs = _args + _kwargs + _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}