added SCRAIBE_TORCH_DEVICE to Scraibe Class to handle torch device setting
This commit is contained in:
@@ -40,6 +40,7 @@ from .audio import AudioProcessor
|
|||||||
from .diarisation import Diariser
|
from .diarisation import Diariser
|
||||||
from .transcriber import Transcriber, load_transcriber, whisper
|
from .transcriber import Transcriber, load_transcriber, whisper
|
||||||
from .transcript_exporter import Transcript
|
from .transcript_exporter import Transcript
|
||||||
|
from .misc import SCRAIBE_TORCH_DEVICE
|
||||||
|
|
||||||
|
|
||||||
DiarisationType = TypeVar('DiarisationType')
|
DiarisationType = TypeVar('DiarisationType')
|
||||||
@@ -115,6 +116,9 @@ class Scraibe:
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
else:
|
else:
|
||||||
self.params = {}
|
self.params = {}
|
||||||
|
|
||||||
|
self.device = kwargs.get(
|
||||||
|
"device", SCRAIBE_TORCH_DEVICE)
|
||||||
|
|
||||||
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
remove_original: bool = False,
|
remove_original: bool = False,
|
||||||
@@ -141,10 +145,10 @@ class Scraibe:
|
|||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
# Prepare waveform and sample rate for diarization
|
||||||
dia_audio = {
|
dia_audio = {
|
||||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
|
||||||
"sample_rate": audio_file.sr
|
"sample_rate": audio_file.sr
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Starting diarisation.")
|
print("Starting diarisation.")
|
||||||
|
|
||||||
@@ -165,8 +169,6 @@ class Scraibe:
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Diarisation finished. Starting transcription.")
|
print("Diarisation finished. Starting transcription.")
|
||||||
|
|
||||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(
|
|
||||||
audio_file.waveform.device)
|
|
||||||
|
|
||||||
# Transcribe each segment and store the results
|
# Transcribe each segment and store the results
|
||||||
final_transcript = dict()
|
final_transcript = dict()
|
||||||
@@ -213,7 +215,7 @@ class Scraibe:
|
|||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
# Prepare waveform and sample rate for diarization
|
||||||
dia_audio = {
|
dia_audio = {
|
||||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
|
||||||
"sample_rate": audio_file.sr
|
"sample_rate": audio_file.sr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,8 +325,7 @@ class Scraibe:
|
|||||||
print(f"Audiofile {audio_file} removed.")
|
print(f"Audiofile {audio_file} removed.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
|
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
|
||||||
*args, **kwargs) -> AudioProcessor:
|
|
||||||
"""Gets an audio file as TorchAudioProcessor.
|
"""Gets an audio file as TorchAudioProcessor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Reference in New Issue
Block a user