Merge pull request #134 from JSchmie/fix-audio-torch-device-setting
Improve Torch Device Configuration for Greater User Control
This commit is contained in:
+4
-10
@@ -41,26 +41,20 @@ class AudioProcessor:
|
||||
The sample rate of the audio.
|
||||
"""
|
||||
|
||||
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
||||
*args, **kwargs) -> None:
|
||||
def __init__(self, waveform: torch.Tensor,
|
||||
sr: int = SAMPLE_RATE) -> None:
|
||||
"""
|
||||
Initialize the AudioProcessor object.
|
||||
|
||||
Args:
|
||||
waveform (torch.Tensor): The audio waveform tensor.
|
||||
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
|
||||
args: Additional arguments.
|
||||
kwargs: Additional keyword arguments, e.g., device to use for processing.
|
||||
If CUDA is available, it defaults to CUDA.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided sample rate is not of type int.
|
||||
"""
|
||||
|
||||
device = kwargs.get(
|
||||
"device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.waveform = waveform.to(device)
|
||||
self.waveform = waveform
|
||||
self.sr = sr
|
||||
|
||||
if not isinstance(self.sr, int):
|
||||
@@ -147,6 +141,6 @@ class AudioProcessor:
|
||||
np.float32) / NORMALIZATION_FACTOR
|
||||
|
||||
return out, sr
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||
|
||||
@@ -40,6 +40,7 @@ from .audio import AudioProcessor
|
||||
from .diarisation import Diariser
|
||||
from .transcriber import Transcriber, load_transcriber, whisper
|
||||
from .transcript_exporter import Transcript
|
||||
from .misc import SCRAIBE_TORCH_DEVICE
|
||||
|
||||
|
||||
DiarisationType = TypeVar('DiarisationType')
|
||||
@@ -115,6 +116,9 @@ class Scraibe:
|
||||
**kwargs)
|
||||
else:
|
||||
self.params = {}
|
||||
|
||||
self.device = kwargs.get(
|
||||
"device", SCRAIBE_TORCH_DEVICE)
|
||||
|
||||
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
||||
remove_original: bool = False,
|
||||
@@ -141,10 +145,10 @@ class Scraibe:
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
|
||||
"sample_rate": audio_file.sr
|
||||
}
|
||||
|
||||
|
||||
if self.verbose:
|
||||
print("Starting diarisation.")
|
||||
|
||||
@@ -165,8 +169,6 @@ class Scraibe:
|
||||
if self.verbose:
|
||||
print("Diarisation finished. Starting transcription.")
|
||||
|
||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(
|
||||
audio_file.waveform.device)
|
||||
|
||||
# Transcribe each segment and store the results
|
||||
final_transcript = dict()
|
||||
@@ -213,7 +215,7 @@ class Scraibe:
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
|
||||
"sample_rate": audio_file.sr
|
||||
}
|
||||
|
||||
@@ -323,8 +325,7 @@ class Scraibe:
|
||||
print(f"Audiofile {audio_file} removed.")
|
||||
|
||||
@staticmethod
|
||||
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
|
||||
*args, **kwargs) -> AudioProcessor:
|
||||
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
|
||||
"""Gets an audio file as TorchAudioProcessor.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -37,11 +37,11 @@ from pyannote.audio import Pipeline
|
||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
from torch import device as torch_device
|
||||
from torch.cuda import is_available
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
@@ -190,8 +190,7 @@ class Diariser:
|
||||
cache_token: bool = False,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None,
|
||||
device: str = None,
|
||||
*args, **kwargs
|
||||
device: str = SCRAIBE_TORCH_DEVICE,
|
||||
) -> Pipeline:
|
||||
"""
|
||||
Loads a pretrained model from pyannote.audio,
|
||||
@@ -283,10 +282,6 @@ class Diariser:
|
||||
'or from huggingface.co models. Please check your token'
|
||||
'or your local model path')
|
||||
|
||||
# try to move the model to the device
|
||||
if device is None:
|
||||
device = "cuda" if is_available() else "cpu"
|
||||
|
||||
# torch_device is renamed from torch.device to avoid name conflict
|
||||
_model = _model.to(torch_device(device))
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import yaml
|
||||
from argparse import Action
|
||||
from ast import literal_eval
|
||||
from torch.cuda import is_available
|
||||
|
||||
CACHE_DIR = os.getenv(
|
||||
"AUTOT_CACHE",
|
||||
@@ -18,6 +19,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
|
||||
|
||||
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
|
||||
|
||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||
"""Configure diarization pipeline from a YAML file.
|
||||
|
||||
+9
-10
@@ -31,13 +31,12 @@ from faster_whisper import WhisperModel as FasterWhisperModel
|
||||
from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
|
||||
from typing import TypeVar, Union, Optional
|
||||
from torch import Tensor, device
|
||||
from torch.cuda import is_available as cuda_is_available
|
||||
from numpy import ndarray
|
||||
from inspect import signature
|
||||
from abc import abstractmethod
|
||||
import warnings
|
||||
|
||||
from .misc import WHISPER_DEFAULT_PATH
|
||||
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
|
||||
whisper = TypeVar('whisper')
|
||||
|
||||
|
||||
@@ -124,7 +123,7 @@ class Transcriber:
|
||||
model: str = "medium",
|
||||
whisper_type: str = 'whisper',
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> None:
|
||||
@@ -206,7 +205,7 @@ class WhisperTranscriber(Transcriber):
|
||||
def load_model(cls,
|
||||
model: str = "medium",
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> 'WhisperTranscriber':
|
||||
@@ -305,7 +304,7 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
def load_model(cls,
|
||||
model: str = "medium",
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
*args, **kwargs
|
||||
) -> 'FasterWhisperModel':
|
||||
"""
|
||||
@@ -330,7 +329,7 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
Defaults to WHISPER_DEFAULT_PATH.
|
||||
|
||||
device (Optional[Union[str, torch.device]], optional):
|
||||
Device to load model on. Defaults to None.
|
||||
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||
in_memory (bool, optional): Whether to load model in memory.
|
||||
Defaults to False.
|
||||
args: Additional arguments only to avoid errors.
|
||||
@@ -339,10 +338,10 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
Returns:
|
||||
Transcriber: A Transcriber object initialized with the specified model.
|
||||
"""
|
||||
if device is None:
|
||||
device = "cuda" if cuda_is_available() else "cpu"
|
||||
|
||||
if not isinstance(device, str):
|
||||
device = str(device)
|
||||
|
||||
compute_type = kwargs.get('compute_type', 'float16')
|
||||
if device == 'cpu' and compute_type == 'float16':
|
||||
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||
@@ -412,7 +411,7 @@ class FasterWhisperTranscriber(Transcriber):
|
||||
def load_transcriber(model: str = "medium",
|
||||
whisper_type: str = 'whisper',
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||
@@ -438,7 +437,7 @@ def load_transcriber(model: str = "medium",
|
||||
download_root (str, optional): Path to download the model.
|
||||
Defaults to WHISPER_DEFAULT_PATH.
|
||||
device (Optional[Union[str, torch.device]], optional):
|
||||
Device to load model on. Defaults to None.
|
||||
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||
in_memory (bool, optional): Whether to load model in memory.
|
||||
Defaults to False.
|
||||
args: Additional arguments only to avoid errors.
|
||||
|
||||
Reference in New Issue
Block a user