From b6bed3ebd8b6bd00f984ef5af4baea9e07a246a8 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Fri, 17 May 2024 11:23:49 +0200 Subject: [PATCH] Adjusted compute type for tests without gpu. --- scraibe/transcriber.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index c9b9b52..526ff8b 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -32,8 +32,9 @@ from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available from numpy import ndarray -from inspect import getfullargspec +from inspect import signature from abc import abstractmethod +import warnings from .misc import WHISPER_DEFAULT_PATH whisper = TypeVar('whisper') @@ -254,9 +255,7 @@ class WhisperTranscriber(Transcriber): dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _args = getfullargspec(Whisper.transcribe).args - _kwargs = getfullargspec(Whisper.transcribe).kwonlyargs - _possible_kwargs = _args + _kwargs + _possible_kwargs = signature(Whisper.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} @@ -343,6 +342,11 @@ class WhisperXTranscriber(Transcriber): device = "cuda" if cuda_is_available() else "cpu" if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') + if device == 'cpu' and compute_type == 'float16': + warnings.warn(f'Compute type {compute_type} not compatible with ' + f'device {device}! Changing compute type to int8.') + compute_type = 'int8' _model = whisperx_load_model(model, download_root=download_root, device=device) @@ -357,9 +361,7 @@ class WhisperXTranscriber(Transcriber): dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _args = getfullargspec(WhisperModel.transcribe).args - _kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs - _possible_kwargs = _args + _kwargs + _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}