65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
import os
|
|
from argparse import Action
|
|
from ast import literal_eval
|
|
|
|
# Root directory for cached model files. The AUTOT_CACHE environment
# variable overrides the per-user default location.
_DEFAULT_CACHE_DIR = os.path.expanduser("~/.cache/torch/models")
CACHE_DIR = os.environ.get("AUTOT_CACHE", _DEFAULT_CACHE_DIR)

# Legacy paths kept for backward compatibility (ignored by LocalAI client)
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
|
|
|
|
|
|
def _validate_thread_count(value, label):
    """Return *value* if it is a positive, non-bool int; raise ValueError otherwise."""
    # bool is a subclass of int, so a YAML `true` would otherwise slip through
    # and end up exported as the string "True".
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError(
            f"Type of {label} must be int, but the type is {type(value)}"
        )
    if value < 1:
        raise ValueError(
            f"Number of threads must be a positive integer, {value} was given"
        )
    return value


def set_threads(parse_threads=None, yaml_threads=None):
    """
    Configure number of threads.

    In LocalAI mode, this is mainly kept for backward compatibility.

    The command-line value (``parse_threads``) takes precedence: when it is
    given, ``yaml_threads`` is ignored and not validated. When neither value
    is given, the environment is left untouched.

    Args:
        parse_threads: Thread count from ``--num-threads``, or None.
        yaml_threads: Thread count from the YAML ``num_threads`` key, or None.

    Raises:
        ValueError: If the selected value is not an int (bools are rejected)
            or is smaller than 1.
    """
    if parse_threads is not None:
        chosen = _validate_thread_count(parse_threads, "--num-threads")
    elif yaml_threads is not None:
        chosen = _validate_thread_count(yaml_threads, "num_threads")
    else:
        return

    thread_count = str(chosen)
    os.environ["OMP_NUM_THREADS"] = thread_count
    os.environ["MKL_NUM_THREADS"] = thread_count
|
|
|
|
|
|
class ParseKwargs(Action):
    """
    Custom argparse action to parse keyword arguments.

    Each received value must look like ``KEY=VALUE``. The parsed pairs are
    collected into a dict stored on the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Parse ``KEY=VALUE`` items and store them as a dict on *namespace*.

        The value part is converted with ``ast.literal_eval`` when it is a
        valid Python literal (``1``, ``True``, ``[1, 2]``, ...); otherwise it
        is kept as the raw string. Only the first ``=`` separates key from
        value, so values may themselves contain ``=``.

        Raises:
            ValueError: If an item contains no ``=`` separator.
        """
        parsed = {}
        for item in values:
            # Split on the first "=" only, so "url=a=b" yields ("url", "a=b").
            key, separator, raw = item.partition("=")
            if not separator:
                raise ValueError(
                    f"Expected KEY=VALUE for {option_string or self.dest}, got {item!r}"
                )
            try:
                parsed[key] = literal_eval(raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not a Python literal: keep the text exactly as typed.
                parsed[key] = raw
        setattr(namespace, self.dest, parsed)
|