From 44ff678e06aa99b0fdced7dd2b5675ec2165e495 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:02:30 +0000 Subject: [PATCH] added SCRAIBE_TORCH_DEVICE Variable --- scraibe/misc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scraibe/misc.py b/scraibe/misc.py index 106b9e1..4a3de57 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -2,6 +2,7 @@ import os import yaml from argparse import Action from ast import literal_eval +from torch.cuda import is_available CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') +SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file.