added SCRAIBE_TORCH_DEVICE to Scraibe Class to handle torch device setting

This commit is contained in:
Schmieder, Jacob
2024-10-08 12:02:08 +00:00
parent 6fadf3d851
commit 8813662d4d
+8 -7
View File
@@ -40,6 +40,7 @@ from .audio import AudioProcessor
from .diarisation import Diariser
from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript
from .misc import SCRAIBE_TORCH_DEVICE
DiarisationType = TypeVar('DiarisationType')
@@ -115,6 +116,9 @@ class Scraibe:
**kwargs)
else:
self.params = {}
self.device = kwargs.get(
"device", SCRAIBE_TORCH_DEVICE)
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False,
@@ -141,10 +145,10 @@ class Scraibe:
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr
}
if self.verbose:
print("Starting diarisation.")
@@ -165,8 +169,6 @@ class Scraibe:
if self.verbose:
print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results
final_transcript = dict()
@@ -213,7 +215,7 @@ class Scraibe:
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr
}
@@ -323,8 +325,7 @@ class Scraibe:
print(f"Audiofile {audio_file} removed.")
@staticmethod
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor:
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor.
Args: