diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 43dedc2..9023107 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -40,6 +40,7 @@ from .audio import AudioProcessor from .diarisation import Diariser from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript +from .misc import SCRAIBE_TORCH_DEVICE DiarisationType = TypeVar('DiarisationType') @@ -115,6 +116,9 @@ class Scraibe: **kwargs) else: self.params = {} + + self.device = kwargs.get( + "device", SCRAIBE_TORCH_DEVICE) def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], remove_original: bool = False, @@ -141,10 +145,10 @@ class Scraibe: # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } - + if self.verbose: print("Starting diarisation.") @@ -165,8 +169,6 @@ class Scraibe: if self.verbose: print("Diarisation finished. Starting transcription.") - audio_file.sr = torch.Tensor([audio_file.sr]).to( - audio_file.waveform.device) # Transcribe each segment and store the results final_transcript = dict() @@ -213,7 +215,7 @@ class Scraibe: # Prepare waveform and sample rate for diarization dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), "sample_rate": audio_file.sr } @@ -323,8 +325,7 @@ class Scraibe: print(f"Audiofile {audio_file} removed.") @staticmethod - def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: