diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 14d2451..7391f1a 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -38,7 +38,7 @@ from tqdm import trange # Application-Specific Imports from .audio import AudioProcessor from .diarisation import Diariser -from .transcriber import Transcriber, whisper +from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript @@ -87,10 +87,10 @@ class Scraibe: """ if whisper_model is None: - self.transcriber = Transcriber.load_model( + self.transcriber = load_transcriber( "medium", whisper_type, **kwargs) elif isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model( + self.transcriber = load_transcriber( whisper_model, whisper_type, **kwargs) else: self.transcriber = whisper_model @@ -258,7 +258,7 @@ class Scraibe: _old_model = self.transcriber.model_name if isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model(whisper_model, **kwargs) + self.transcriber = load_transcriber(whisper_model, **kwargs) elif isinstance(whisper_model, Transcriber): self.transcriber = whisper_model else: diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index f7765b6..c9b9b52 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -116,14 +116,16 @@ class Transcriber: print(f'Transcript saved to {save_path}') - @staticmethod - def load_model(model: str = "medium", + @classmethod + @abstractmethod + def load_model(cls, + model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, in_memory: bool = False, *args, **kwargs - ) -> 'Union[WhisperTranscriber, WhisperXTranscriber]': + ) -> None: """ Load whisper model. @@ -153,19 +155,8 @@ class Transcriber: kwargs: Additional keyword arguments only to avoid errors. Returns: - Transcriber: A Transcriber object initialized with the specified model. + None: abstract method. 
""" - if whisper_type.lower() == 'whisper': - _model = WhisperTranscriber.load_model( - model, download_root, device, in_memory, *args, **kwargs) - return _model - elif whisper_type.lower() == 'whisperx': - _model = WhisperXTranscriber.load_model( - model, download_root, device, *args, **kwargs) - return _model - else: - raise ValueError(f'Model type not recognized, exptected "whisper" ' - f'or "whisperx", got {whisper_type}.') pass @staticmethod @@ -216,7 +207,7 @@ class WhisperTranscriber(Transcriber): device: Optional[Union[str, device]] = None, in_memory: bool = False, *args, **kwargs - ) -> 'Transcriber': + ) -> 'WhisperTranscriber': """ Load whisper model. @@ -316,7 +307,7 @@ class WhisperXTranscriber(Transcriber): download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, *args, **kwargs - ) -> 'Transcriber': + ) -> 'WhisperXTranscriber': """ Load whisper model. @@ -383,3 +374,55 @@ class WhisperXTranscriber(Transcriber): def __repr__(self) -> str: return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" + + +def load_transcriber(model: str = "medium", + whisper_type: str = 'whisper', + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + in_memory: bool = False, + *args, **kwargs + ) -> Union[WhisperTranscriber, WhisperXTranscriber]: + """ + Load whisper model. + + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large-v3' + - 'large' + whisper_type (str): + Type of whisper model to load. "whisper" or "whisperx". + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. 
+ args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. + + Returns: + Union[WhisperTranscriber, WhisperXTranscriber]: + One of the Whisper variants as a Transcriber object initialized with the specified model. + """ + if whisper_type.lower() == 'whisper': + _model = WhisperTranscriber.load_model( + model, download_root, device, in_memory, *args, **kwargs) + return _model + elif whisper_type.lower() == 'whisperx': + _model = WhisperXTranscriber.load_model( + model, download_root, device, *args, **kwargs) + return _model + else: + raise ValueError(f'Model type not recognized, exptected "whisper" ' + f'or "whisperx", got {whisper_type}.') diff --git a/test/test_transcriber.py b/test/test_transcriber.py index 7ecb1be..31765f6 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -1,5 +1,6 @@ import pytest -from scraibe import Transcriber, WhisperTranscriber, WhisperXTranscriber +from scraibe import (Transcriber, WhisperTranscriber, + WhisperXTranscriber, load_transcriber) import torch @@ -30,12 +31,12 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): @pytest.fixture def whisper_instance(): - return Transcriber.load_model('medium', whisper_type='whisper') + return load_transcriber('medium', whisper_type='whisper') @pytest.fixture def whisperx_instance(): - return Transcriber.load_model('medium', whisper_type='whisperx') + return load_transcriber('medium', whisper_type='whisperx') def test_whisper_base_initialization(whisper_instance): @@ -56,7 +57,7 @@ def test_wrong_transcriber_initialization(): with pytest.raises(ValueError): - Transcriber.load_model('medium', whisper_type='wrong_whisper') + load_transcriber('medium', whisper_type='wrong_whisper') def test_get_whisper_kwargs():