added update functions for transcriber and diariser + adding some type hints
This commit is contained in:
@@ -95,7 +95,7 @@ class Scraibe:
|
|||||||
elif isinstance(dia_model, str):
|
elif isinstance(dia_model, str):
|
||||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.diariser = dia_model
|
self.diariser : Diariser = dia_model
|
||||||
|
|
||||||
if kwargs.get("verbose"):
|
if kwargs.get("verbose"):
|
||||||
print("Scraibe initialized all models successfully loaded.")
|
print("Scraibe initialized all models successfully loaded.")
|
||||||
@@ -133,7 +133,7 @@ class Scraibe:
|
|||||||
if kwargs.get("verbose"):
|
if kwargs.get("verbose"):
|
||||||
self.verbose = kwargs.get("verbose")
|
self.verbose = kwargs.get("verbose")
|
||||||
# Get audio file as an AudioProcessor object
|
# Get audio file as an AudioProcessor object
|
||||||
audio_file = self.get_audio_file(audio_file)
|
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
# Prepare waveform and sample rate for diarization
|
||||||
dia_audio = {
|
dia_audio = {
|
||||||
@@ -203,7 +203,7 @@ class Scraibe:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Get audio file as an AudioProcessor object
|
# Get audio file as an AudioProcessor object
|
||||||
audio_file = self.get_audio_file(audio_file)
|
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
# Prepare waveform and sample rate for diarization
|
||||||
dia_audio = {
|
dia_audio = {
|
||||||
@@ -232,9 +232,56 @@ class Scraibe:
|
|||||||
str:
|
str:
|
||||||
The transcribed text from the audio source.
|
The transcribed text from the audio source.
|
||||||
"""
|
"""
|
||||||
audio_file = self.get_audio_file(audio_file)
|
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
||||||
|
|
||||||
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||||
|
|
||||||
|
def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Update the transcriber model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
whisper_model (Union[str, whisper]):
|
||||||
|
The new whisper model to use for transcription.
|
||||||
|
**kwargs:
|
||||||
|
Additional keyword arguments for the transcriber model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
_old_model = self.transcriber.model_name
|
||||||
|
|
||||||
|
if isinstance(whisper_model, str):
|
||||||
|
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
||||||
|
elif isinstance(whisper_model, Transcriber):
|
||||||
|
self.transcriber = whisper_model
|
||||||
|
else:
|
||||||
|
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None:
|
||||||
|
"""
|
||||||
|
Update the diariser model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dia_model (Union[str, DiarisationType]):
|
||||||
|
The new diariser model to use for diarization.
|
||||||
|
**kwargs:
|
||||||
|
Additional keyword arguments for the diariser model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
if isinstance(dia_model, str):
|
||||||
|
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||||
|
elif isinstance(dia_model, Diariser):
|
||||||
|
self.diariser = dia_model
|
||||||
|
else:
|
||||||
|
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_audio_file(audio_file : str,
|
def remove_audio_file(audio_file : str,
|
||||||
shred : bool = False) -> None:
|
shred : bool = False) -> None:
|
||||||
@@ -269,7 +316,6 @@ class Scraibe:
|
|||||||
print(f"Audiofile {audio_file} removed.")
|
print(f"Audiofile {audio_file} removed.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
|
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
|
||||||
*args, **kwargs) -> AudioProcessor:
|
*args, **kwargs) -> AudioProcessor:
|
||||||
@@ -298,6 +344,7 @@ class Scraibe:
|
|||||||
if not isinstance(audio_file, AudioProcessor):
|
if not isinstance(audio_file, AudioProcessor):
|
||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||||
f'not {type(audio_file)}')
|
f'not {type(audio_file)}')
|
||||||
|
|
||||||
return audio_file
|
return audio_file
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user