added SCRAIBE_TORCH_DEVICE Variable
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
|||||||
import yaml
|
import yaml
|
||||||
from argparse import Action
|
from argparse import Action
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
|
from torch.cuda import is_available
|
||||||
|
|
||||||
CACHE_DIR = os.getenv(
|
CACHE_DIR = os.getenv(
|
||||||
"AUTOT_CACHE",
|
"AUTOT_CACHE",
|
||||||
@@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
|||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
|
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
|
||||||
|
|
||||||
|
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
"""Configure diarization pipeline from a YAML file.
|
"""Configure diarization pipeline from a YAML file.
|
||||||
|
|||||||
Reference in New Issue
Block a user