diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 2664e3f..7d54ba8 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -95,7 +95,7 @@ class Scraibe: elif isinstance(dia_model, str): self.diariser = Diariser.load_model(dia_model, **kwargs) else: - self.diariser = dia_model + self.diariser : Diariser = dia_model if kwargs.get("verbose"): print("Scraibe initialized all models successfully loaded.") @@ -133,7 +133,7 @@ class Scraibe: if kwargs.get("verbose"): self.verbose = kwargs.get("verbose") # Get audio file as an AudioProcessor object - audio_file = self.get_audio_file(audio_file) + audio_file : AudioProcessor = self.get_audio_file(audio_file) # Prepare waveform and sample rate for diarization dia_audio = { @@ -203,7 +203,7 @@ class Scraibe: """ # Get audio file as an AudioProcessor object - audio_file = self.get_audio_file(audio_file) + audio_file : AudioProcessor = self.get_audio_file(audio_file) # Prepare waveform and sample rate for diarization dia_audio = { @@ -232,9 +232,56 @@ class Scraibe: str: The transcribed text from the audio source. """ - audio_file = self.get_audio_file(audio_file) + audio_file : AudioProcessor = self.get_audio_file(audio_file) return self.transcriber.transcribe(audio_file.waveform, **kwargs) + + def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None: + """ + Update the transcriber model. + + Args: + whisper_model (Union[str, whisper]): + The new whisper model to use for transcription. + **kwargs: + Additional keyword arguments for the transcriber model. + + Returns: + None + """ + _old_model = self.transcriber.model_name + + if isinstance(whisper_model, str): + self.transcriber = Transcriber.load_model(whisper_model, **kwargs) + elif isinstance(whisper_model, Transcriber): + self.transcriber = whisper_model + else: + warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) + + return None + + def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None: + """ + Update the diariser model. + + Args: + dia_model (Union[str, DiarisationType]): + The new diariser model to use for diarization. + **kwargs: + Additional keyword arguments for the diariser model. + + Returns: + None + """ + if isinstance(dia_model, str): + self.diariser = Diariser.load_model(dia_model, **kwargs) + elif isinstance(dia_model, Diariser): + self.diariser = dia_model + else: + warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) + + return None + @staticmethod def remove_audio_file(audio_file : str, shred : bool = False) -> None: @@ -269,7 +316,6 @@ class Scraibe: print(f"Audiofile {audio_file} removed.") - @staticmethod def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray], *args, **kwargs) -> AudioProcessor: @@ -298,6 +344,7 @@ class Scraibe: if not isinstance(audio_file, AudioProcessor): raise ValueError(f'Audiofile must be of type AudioProcessor,' \ f'not {type(audio_file)}') + return audio_file def __repr__(self):