From 6fadf3d851c06ffc130bfd4d6e758d7da5850830 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:01:36 +0000 Subject: [PATCH 1/6] removed torch device from AudioProcessor class --- scraibe/audio.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/scraibe/audio.py b/scraibe/audio.py index 7fbc6fb..4e5dd0f 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -41,26 +41,20 @@ class AudioProcessor: The sample rate of the audio. """ - def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, - *args, **kwargs) -> None: + def __init__(self, waveform: torch.Tensor, + sr: int = SAMPLE_RATE) -> None: """ Initialize the AudioProcessor object. Args: waveform (torch.Tensor): The audio waveform tensor. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. - args: Additional arguments. - kwargs: Additional keyword arguments, e.g., device to use for processing. - If CUDA is available, it defaults to CUDA. Raises: ValueError: If the provided sample rate is not of type int. """ - device = kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu") - - self.waveform = waveform.to(device) + self.waveform = waveform self.sr = sr if not isinstance(self.sr, int): @@ -147,6 +141,6 @@ class AudioProcessor: np.float32) / NORMALIZATION_FACTOR return out, sr - + def __repr__(self) -> str: return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' From 8813662d4df9cbea51940a82530c2782c8f22f28 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:02:08 +0000 Subject: [PATCH 2/6] added SCRAIBE_TORCH_DEVICE to Scraibe Class to handle torch device setting --- scraibe/autotranscript.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 43dedc2..9023107 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -40,6 +40,7 @@ from .audio import AudioProcessor from .diarisation import Diariser from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript +from .misc import SCRAIBE_TORCH_DEVICE DiarisationType = TypeVar('DiarisationType') @@ -115,6 +116,9 @@ class Scraibe: **kwargs) else: self.params = {} + + self.device = kwargs.get( + "device", SCRAIBE_TORCH_DEVICE) def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], remove_original: bool = False, @@ -141,10 +145,10 @@ class Scraibe: # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } - + if self.verbose: print("Starting diarisation.") @@ -165,8 +169,6 @@ class Scraibe: if self.verbose: print("Diarisation finished. Starting transcription.") - audio_file.sr = torch.Tensor([audio_file.sr]).to( - audio_file.waveform.device) # Transcribe each segment and store the results final_transcript = dict() @@ -213,7 +215,7 @@ class Scraibe: # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } @@ -323,8 +325,7 @@ class Scraibe: print(f"Audiofile {audio_file} removed.") @staticmethod - def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: From 44ff678e06aa99b0fdced7dd2b5675ec2165e495 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 8 Oct 2024 12:02:30 +0000 Subject: [PATCH 3/6] added SCRAIBE_TORCH_DEVICE Variable --- scraibe/misc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scraibe/misc.py b/scraibe/misc.py index 106b9e1..4a3de57 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -2,6 +2,7 @@ import os import yaml from argparse import Action from ast import literal_eval +from torch.cuda import is_available CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') +SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. From af99a655e593093494600eb25353b82f4a44dcd6 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:22:34 +0000 Subject: [PATCH 4/6] added SCRAIBE_TORCH_DEVICE to Diariser class --- scraibe/diarisation.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index d70df99..6e6d6b9 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -41,7 +41,7 @@ from torch.cuda import is_available from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -190,8 +190,7 @@ class Diariser: cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs + device: str = SCRAIBE_TORCH_DEVICE, ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, @@ -283,10 +282,6 @@ class Diariser: 'or from huggingface.co models. Please check your token' 'or your local model path') - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - # torch_device is renamed from torch.device to avoid name conflict _model = _model.to(torch_device(device)) From e7c1a5a2b01263acddb80c781d1c26292fc6210a Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:25:49 +0000 Subject: [PATCH 5/6] added SCRAIBE_TORCH_DEVICE to transcriber class --- scraibe/transcriber.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index abf1ace..9c891f6 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -37,7 +37,7 @@ from inspect import signature from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE whisper = TypeVar('whisper') @@ -124,7 +124,7 @@ class Transcriber: model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> None: @@ -206,7 +206,7 @@ class WhisperTranscriber(Transcriber): def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> 'WhisperTranscriber': @@ -305,7 +305,7 @@ class FasterWhisperTranscriber(Transcriber): def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, *args, **kwargs ) -> 'FasterWhisperModel': """ @@ -330,7 +330,7 @@ class FasterWhisperTranscriber(Transcriber): Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. @@ -339,10 +339,10 @@ class FasterWhisperTranscriber(Transcriber): Returns: Transcriber: A Transcriber object initialized with the specified model. """ - if device is None: - device = "cuda" if cuda_is_available() else "cpu" + if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') if device == 'cpu' and compute_type == 'float16': warnings.warn(f'Compute type {compute_type} not compatible with ' @@ -412,7 +412,7 @@ class FasterWhisperTranscriber(Transcriber): def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: @@ -438,7 +438,7 @@ def load_transcriber(model: str = "medium", download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. From 101e913f849ce450dea400a4681095d6c39d455f Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Thu, 10 Oct 2024 09:29:48 +0000 Subject: [PATCH 6/6] make ruff happy --- scraibe/diarisation.py | 2 +- scraibe/transcriber.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 6e6d6b9..eeef135 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -37,7 +37,7 @@ from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device -from torch.cuda import is_available + from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 9c891f6..040b79d 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -31,7 +31,6 @@ from faster_whisper import WhisperModel as FasterWhisperModel from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device -from torch.cuda import is_available as cuda_is_available from numpy import ndarray from inspect import signature from abc import abstractmethod