Adding support for setting number of threads to faster-whisper cpu, reading from cli, yaml or env var.
This commit is contained in:
+3
-4
@@ -9,8 +9,8 @@ import json
|
|||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
from torch.cuda import is_available
|
from torch.cuda import is_available
|
||||||
from torch import set_num_threads
|
|
||||||
from .autotranscript import Scraibe
|
from .autotranscript import Scraibe
|
||||||
|
from .misc import set_threads
|
||||||
|
|
||||||
def cli():
|
def cli():
|
||||||
"""
|
"""
|
||||||
@@ -55,7 +55,7 @@ def cli():
|
|||||||
default="cuda" if is_available() else "cpu",
|
default="cuda" if is_available() else "cpu",
|
||||||
help="Device to use for PyTorch inference.")
|
help="Device to use for PyTorch inference.")
|
||||||
|
|
||||||
parser.add_argument("--num-threads", type=int, default=0,
|
parser.add_argument("--num-threads", type=int, default=None,
|
||||||
help="Number of threads used by torch for CPU inference; '\
|
help="Number of threads used by torch for CPU inference; '\
|
||||||
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
|
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
|
||||||
|
|
||||||
@@ -94,8 +94,7 @@ def cli():
|
|||||||
|
|
||||||
task = arg_dict.pop("task")
|
task = arg_dict.pop("task")
|
||||||
|
|
||||||
if args.num_threads > 0:
|
set_threads(arg_dict.pop("num_threads"))
|
||||||
set_num_threads(arg_dict.pop("num_threads"))
|
|
||||||
|
|
||||||
class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
|
class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
|
||||||
'whisper_type':arg_dict.pop("whisper_type"),
|
'whisper_type':arg_dict.pop("whisper_type"),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import yaml
|
|||||||
from argparse import Action
|
from argparse import Action
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
from torch.cuda import is_available
|
from torch.cuda import is_available
|
||||||
|
from torch import get_num_threads, set_num_threads
|
||||||
|
|
||||||
CACHE_DIR = os.getenv(
|
CACHE_DIR = os.getenv(
|
||||||
"AUTOT_CACHE",
|
"AUTOT_CACHE",
|
||||||
@@ -21,6 +22,8 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
|||||||
|
|
||||||
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
|
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
|
||||||
|
|
||||||
|
SCRAIBE_NUM_THREADS = os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads()))
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
"""Configure diarization pipeline from a YAML file.
|
"""Configure diarization pipeline from a YAML file.
|
||||||
|
|
||||||
@@ -49,6 +52,28 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
|
|||||||
yaml.dump(yml, stream)
|
yaml.dump(yml, stream)
|
||||||
|
|
||||||
|
|
||||||
|
def set_threads(parse_threads=None,
|
||||||
|
yaml_threads=None,
|
||||||
|
env_var_threads=None):
|
||||||
|
global SCRAIBE_NUM_THREADS
|
||||||
|
if parse_threads is not None:
|
||||||
|
if not isinstance(parse_threads, int):
|
||||||
|
# probably covered with int type of parser arg
|
||||||
|
raise ValueError(f"Type of --num-threads must be int, but the type is {type(parse_threads)}")
|
||||||
|
elif parse_threads < 1:
|
||||||
|
raise ValueError(f"Number of threads must be a positive integer, {parse_threads} was given")
|
||||||
|
else:
|
||||||
|
set_num_threads(parse_threads)
|
||||||
|
SCRAIBE_NUM_THREADS = parse_threads
|
||||||
|
elif yaml_threads is not None:
|
||||||
|
if not isinstance(yaml_threads, int):
|
||||||
|
raise ValueError(f"Type of num_threads must be int, but the type is {type(yaml_threads)}")
|
||||||
|
elif yaml_threads < 1:
|
||||||
|
raise ValueError(f"Number of threads must be a positive integer, {yaml_threads} was given")
|
||||||
|
else:
|
||||||
|
set_num_threads(yaml_threads)
|
||||||
|
SCRAIBE_NUM_THREADS = yaml_threads
|
||||||
|
|
||||||
class ParseKwargs(Action):
|
class ParseKwargs(Action):
|
||||||
"""
|
"""
|
||||||
Custom argparse action to parse keyword arguments.
|
Custom argparse action to parse keyword arguments.
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from inspect import signature
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
|
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
|
|
||||||
|
|
||||||
@@ -348,7 +348,8 @@ class FasterWhisperTranscriber(Transcriber):
|
|||||||
f'device {device}! Changing compute type to int8.')
|
f'device {device}! Changing compute type to int8.')
|
||||||
compute_type = 'int8'
|
compute_type = 'int8'
|
||||||
_model = FasterWhisperModel(model, download_root=download_root,
|
_model = FasterWhisperModel(model, download_root=download_root,
|
||||||
device=device, compute_type=compute_type)
|
device=device, compute_type=compute_type,
|
||||||
|
cpu_threads=SCRAIBE_NUM_THREADS)
|
||||||
|
|
||||||
return cls(_model, model_name=model)
|
return cls(_model, model_name=model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user