diff --git a/scraibe/cli.py b/scraibe/cli.py index df73d1b..e4eeaad 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -9,8 +9,8 @@ import json from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from torch.cuda import is_available -from torch import set_num_threads from .autotranscript import Scraibe +from .misc import set_threads def cli(): """ @@ -55,7 +55,7 @@ def cli(): default="cuda" if is_available() else "cpu", help="Device to use for PyTorch inference.") - parser.add_argument("--num-threads", type=int, default=0, + parser.add_argument("--num-threads", type=int, default=None, help="Number of threads used by torch for CPU inference; '\ 'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") @@ -94,8 +94,7 @@ def cli(): task = arg_dict.pop("task") - if args.num_threads > 0: - set_num_threads(arg_dict.pop("num_threads")) + set_threads(arg_dict.pop("num_threads")) class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), 'whisper_type':arg_dict.pop("whisper_type"), diff --git a/scraibe/misc.py b/scraibe/misc.py index 4f5ab1a..f5d2bfe 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -3,6 +3,7 @@ import yaml from argparse import Action from ast import literal_eval from torch.cuda import is_available +from torch import get_num_threads, set_num_threads CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -21,6 +22,8 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") +SCRAIBE_NUM_THREADS = os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads())) + def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. @@ -49,6 +52,27 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> yaml.dump(yml, stream) +def set_threads(parse_threads=None, + yaml_threads=None): + global SCRAIBE_NUM_THREADS + if parse_threads is not None: + if not isinstance(parse_threads, int): + # probably covered with int type of parser arg + raise ValueError(f"Type of --num-threads must be int, but the type is {type(parse_threads)}") + elif parse_threads < 1: + raise ValueError(f"Number of threads must be a positive integer, {parse_threads} was given") + else: + set_num_threads(parse_threads) + SCRAIBE_NUM_THREADS = parse_threads + elif yaml_threads is not None: + if not isinstance(yaml_threads, int): + raise ValueError(f"Type of num_threads must be int, but the type is {type(yaml_threads)}") + elif yaml_threads < 1: + raise ValueError(f"Number of threads must be a positive integer, {yaml_threads} was given") + else: + set_num_threads(yaml_threads) + SCRAIBE_NUM_THREADS = yaml_threads + class ParseKwargs(Action): """ Custom argparse action to parse keyword arguments. diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 040b79d..bc341dc 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -36,7 +36,7 @@ from inspect import signature from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS whisper = TypeVar('whisper') @@ -348,7 +348,8 @@ class FasterWhisperTranscriber(Transcriber): f'device {device}! Changing compute type to int8.') compute_type = 'int8' _model = FasterWhisperModel(model, download_root=download_root, - device=device, compute_type=compute_type) + device=device, compute_type=compute_type, + cpu_threads=SCRAIBE_NUM_THREADS) return cls(_model, model_name=model)