diff --git a/scraibe/audio.py b/scraibe/audio.py index 7fbc6fb..4e5dd0f 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -41,26 +41,20 @@ class AudioProcessor: The sample rate of the audio. """ - def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, - *args, **kwargs) -> None: + def __init__(self, waveform: torch.Tensor, + sr: int = SAMPLE_RATE) -> None: """ Initialize the AudioProcessor object. Args: waveform (torch.Tensor): The audio waveform tensor. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. - args: Additional arguments. - kwargs: Additional keyword arguments, e.g., device to use for processing. - If CUDA is available, it defaults to CUDA. Raises: ValueError: If the provided sample rate is not of type int. """ - device = kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu") - - self.waveform = waveform.to(device) + self.waveform = waveform self.sr = sr if not isinstance(self.sr, int): @@ -147,6 +141,6 @@ class AudioProcessor: np.float32) / NORMALIZATION_FACTOR return out, sr - + def __repr__(self) -> str: return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 43dedc2..9023107 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -40,6 +40,7 @@ from .audio import AudioProcessor from .diarisation import Diariser from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript +from .misc import SCRAIBE_TORCH_DEVICE DiarisationType = TypeVar('DiarisationType') @@ -115,6 +116,9 @@ class Scraibe: **kwargs) else: self.params = {} + + self.device = kwargs.get( + "device", SCRAIBE_TORCH_DEVICE) def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], remove_original: bool = False, @@ -141,10 +145,10 @@ class Scraibe: # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } - + if self.verbose: print("Starting diarisation.") @@ -165,8 +169,6 @@ class Scraibe: if self.verbose: print("Diarisation finished. Starting transcription.") - audio_file.sr = torch.Tensor([audio_file.sr]).to( - audio_file.waveform.device) # Transcribe each segment and store the results final_transcript = dict() @@ -213,7 +215,7 @@ class Scraibe: # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } @@ -323,8 +325,7 @@ class Scraibe: print(f"Audiofile {audio_file} removed.") @staticmethod - def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index d70df99..eeef135 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -37,11 +37,11 @@ from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device -from torch.cuda import is_available + from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError -from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG +from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( @@ -190,8 +190,7 @@ class Diariser: cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs + device: str = SCRAIBE_TORCH_DEVICE, ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, @@ -283,10 +282,6 @@ class Diariser: 'or from huggingface.co models. Please check your token' 'or your local model path') - # try to move the model to the device - if device is None: - device = "cuda" if is_available() else "cpu" - # torch_device is renamed from torch.device to avoid name conflict _model = _model.to(torch_device(device)) diff --git a/scraibe/misc.py b/scraibe/misc.py index e865f52..4f5ab1a 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -2,6 +2,7 @@ import os import yaml from argparse import Action from ast import literal_eval +from torch.cuda import is_available CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') +SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index abf1ace..040b79d 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -31,13 +31,12 @@ from faster_whisper import WhisperModel as FasterWhisperModel from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device -from torch.cuda import is_available as cuda_is_available from numpy import ndarray from inspect import signature from abc import abstractmethod import warnings -from .misc import WHISPER_DEFAULT_PATH +from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE whisper = TypeVar('whisper') @@ -124,7 +123,7 @@ class Transcriber: model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> None: @@ -206,7 +205,7 @@ class WhisperTranscriber(Transcriber): def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> 'WhisperTranscriber': @@ -305,7 +304,7 @@ class FasterWhisperTranscriber(Transcriber): def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, *args, **kwargs ) -> 'FasterWhisperModel': """ @@ -330,7 +329,7 @@ class FasterWhisperTranscriber(Transcriber): Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. @@ -339,10 +338,10 @@ class FasterWhisperTranscriber(Transcriber): Returns: Transcriber: A Transcriber object initialized with the specified model. """ - if device is None: - device = "cuda" if cuda_is_available() else "cpu" + if not isinstance(device, str): device = str(device) + compute_type = kwargs.get('compute_type', 'float16') if device == 'cpu' and compute_type == 'float16': warnings.warn(f'Compute type {compute_type} not compatible with ' @@ -412,7 +411,7 @@ class FasterWhisperTranscriber(Transcriber): def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, + device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE, in_memory: bool = False, *args, **kwargs ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: @@ -438,7 +437,7 @@ def load_transcriber(model: str = "medium", download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): - Device to load model on. Defaults to None. + Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE. in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors.