added SCRAIBE_TORCH_DEVICE Variable

This commit is contained in:
Schmieder, Jacob
2024-10-08 12:02:30 +00:00
parent 8813662d4d
commit 44ff678e06
+2
View File
@@ -2,6 +2,7 @@ import os
import yaml import yaml
from argparse import Action from argparse import Action
from ast import literal_eval from ast import literal_eval
from torch.cuda import is_available
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
@@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.