Merge pull request #134 from JSchmie/fix-audio-torch-device-setting

Improve Torch Device Configuration for Greater User Control
This commit is contained in:
Marko Henning
2024-10-24 09:17:05 +02:00
committed by GitHub
5 changed files with 26 additions and 35 deletions
+4 -10
View File
@@ -41,26 +41,20 @@ class AudioProcessor:
The sample rate of the audio. The sample rate of the audio.
""" """
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor,
*args, **kwargs) -> None: sr: int = SAMPLE_RATE) -> None:
""" """
Initialize the AudioProcessor object. Initialize the AudioProcessor object.
Args: Args:
waveform (torch.Tensor): The audio waveform tensor. waveform (torch.Tensor): The audio waveform tensor.
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
args: Additional arguments.
kwargs: Additional keyword arguments, e.g., device to use for processing.
If CUDA is available, it defaults to CUDA.
Raises: Raises:
ValueError: If the provided sample rate is not of type int. ValueError: If the provided sample rate is not of type int.
""" """
device = kwargs.get( self.waveform = waveform
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
if not isinstance(self.sr, int): if not isinstance(self.sr, int):
@@ -147,6 +141,6 @@ class AudioProcessor:
np.float32) / NORMALIZATION_FACTOR np.float32) / NORMALIZATION_FACTOR
return out, sr return out, sr
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
+8 -7
View File
@@ -40,6 +40,7 @@ from .audio import AudioProcessor
from .diarisation import Diariser from .diarisation import Diariser
from .transcriber import Transcriber, load_transcriber, whisper from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript from .transcript_exporter import Transcript
from .misc import SCRAIBE_TORCH_DEVICE
DiarisationType = TypeVar('DiarisationType') DiarisationType = TypeVar('DiarisationType')
@@ -115,6 +116,9 @@ class Scraibe:
**kwargs) **kwargs)
else: else:
self.params = {} self.params = {}
self.device = kwargs.get(
"device", SCRAIBE_TORCH_DEVICE)
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False, remove_original: bool = False,
@@ -141,10 +145,10 @@ class Scraibe:
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
if self.verbose: if self.verbose:
print("Starting diarisation.") print("Starting diarisation.")
@@ -165,8 +169,6 @@ class Scraibe:
if self.verbose: if self.verbose:
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results # Transcribe each segment and store the results
final_transcript = dict() final_transcript = dict()
@@ -213,7 +215,7 @@ class Scraibe:
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
@@ -323,8 +325,7 @@ class Scraibe:
print(f"Audiofile {audio_file} removed.") print(f"Audiofile {audio_file} removed.")
@staticmethod @staticmethod
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
*args, **kwargs) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor. """Gets an audio file as TorchAudioProcessor.
Args: Args:
+3 -8
View File
@@ -37,11 +37,11 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from torch import device as torch_device from torch import device as torch_device
from torch.cuda import is_available
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
@@ -190,8 +190,7 @@ class Diariser:
cache_token: bool = False, cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None, hparams_file: Union[str, Path] = None,
device: str = None, device: str = SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> Pipeline: ) -> Pipeline:
""" """
Loads a pretrained model from pyannote.audio, Loads a pretrained model from pyannote.audio,
@@ -283,10 +282,6 @@ class Diariser:
'or from huggingface.co models. Please check your token' 'or from huggingface.co models. Please check your token'
'or your local model path') 'or your local model path')
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
# torch_device is renamed from torch.device to avoid name conflict # torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device)) _model = _model.to(torch_device(device))
+2
View File
@@ -2,6 +2,7 @@ import os
import yaml import yaml
from argparse import Action from argparse import Action
from ast import literal_eval from ast import literal_eval
from torch.cuda import is_available
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
@@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.
+9 -10
View File
@@ -31,13 +31,12 @@ from faster_whisper import WhisperModel as FasterWhisperModel
from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
from typing import TypeVar, Union, Optional from typing import TypeVar, Union, Optional
from torch import Tensor, device from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray from numpy import ndarray
from inspect import signature from inspect import signature
from abc import abstractmethod from abc import abstractmethod
import warnings import warnings
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
@@ -124,7 +123,7 @@ class Transcriber:
model: str = "medium", model: str = "medium",
whisper_type: str = 'whisper', whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> None: ) -> None:
@@ -206,7 +205,7 @@ class WhisperTranscriber(Transcriber):
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> 'WhisperTranscriber': ) -> 'WhisperTranscriber':
@@ -305,7 +304,7 @@ class FasterWhisperTranscriber(Transcriber):
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
*args, **kwargs *args, **kwargs
) -> 'FasterWhisperModel': ) -> 'FasterWhisperModel':
""" """
@@ -330,7 +329,7 @@ class FasterWhisperTranscriber(Transcriber):
Defaults to WHISPER_DEFAULT_PATH. Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional): device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None. Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory. in_memory (bool, optional): Whether to load model in memory.
Defaults to False. Defaults to False.
args: Additional arguments only to avoid errors. args: Additional arguments only to avoid errors.
@@ -339,10 +338,10 @@ class FasterWhisperTranscriber(Transcriber):
Returns: Returns:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.
""" """
if device is None:
device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str): if not isinstance(device, str):
device = str(device) device = str(device)
compute_type = kwargs.get('compute_type', 'float16') compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16': if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with ' warnings.warn(f'Compute type {compute_type} not compatible with '
@@ -412,7 +411,7 @@ class FasterWhisperTranscriber(Transcriber):
def load_transcriber(model: str = "medium", def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper', whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
@@ -438,7 +437,7 @@ def load_transcriber(model: str = "medium",
download_root (str, optional): Path to download the model. download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH. Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional): device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None. Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory. in_memory (bool, optional): Whether to load model in memory.
Defaults to False. Defaults to False.
args: Additional arguments only to avoid errors. args: Additional arguments only to avoid errors.