diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 8802cf6..f7765b6 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -30,6 +30,7 @@ from whisperx.asr import WhisperModel from whisperx import load_model as whisperx_load_model from typing import TypeVar, Union, Optional from torch import Tensor, device +from torch.cuda import is_available as cuda_is_available from numpy import ndarray from inspect import getfullargspec from abc import abstractmethod @@ -115,15 +116,14 @@ class Transcriber: print(f'Transcript saved to {save_path}') - @classmethod - def load_model(cls, - model: str = "medium", + @staticmethod + def load_model(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, in_memory: bool = False, *args, **kwargs - ) -> 'Transcriber': + ) -> 'Union[WhisperTranscriber, WhisperXTranscriber]': """ Load whisper model. @@ -278,6 +278,9 @@ class WhisperTranscriber(Transcriber): return whisper_kwargs + def __repr__(self) -> str: + return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" + class WhisperXTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: @@ -345,6 +348,8 @@ class WhisperXTranscriber(Transcriber): Returns: Transcriber: A Transcriber object initialized with the specified model. """ + if device is None: + device = "cuda" if cuda_is_available() else "cpu" if not isinstance(device, str): device = str(device) _model = whisperx_load_model(model, download_root=download_root, @@ -375,3 +380,6 @@ class WhisperXTranscriber(Transcriber): whisper_kwargs["language"] = language return whisper_kwargs + + def __repr__(self) -> str: + return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" diff --git a/test/test_transcriber.py b/test/test_transcriber.py index fee3aff..7ecb1be 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -1,5 +1,5 @@ import pytest -from scraibe import Transcriber +from scraibe import Transcriber, WhisperTranscriber, WhisperXTranscriber import torch @@ -19,7 +19,7 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): expected_transcription (_type_): _description_ mock_model = mock_load_model.return_value - mock_model.transcribe.return_value ={"text": expected_transcription} + mock_model.transcribe.return_value ={"text": expected_transcription} transcriber = Transcriber.load_model(model="medium") @@ -29,12 +29,34 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): @pytest.fixture -def transcriber_instance(): - return Transcriber.load_model('medium') +def whisper_instance(): + return Transcriber.load_model('medium', whisper_type='whisper') -def test_transcriber_initialization(transcriber_instance): - assert isinstance(transcriber_instance, Transcriber) +@pytest.fixture +def whisperx_instance(): + return Transcriber.load_model('medium', whisper_type='whisperx') + + +def test_whisper_base_initialization(whisper_instance): + assert isinstance(whisper_instance, Transcriber) + + +def test_whisperx_base_initialization(whisperx_instance): + assert isinstance(whisperx_instance, Transcriber) + + +def test_whisper_transcriber_initialization(whisper_instance): + assert isinstance(whisper_instance, WhisperTranscriber) + + +def test_whisperx_transcriber_initialization(whisperx_instance): + assert isinstance(whisperx_instance, WhisperXTranscriber) + + +def test_wrong_transcriber_initialization(): + with pytest.raises(ValueError): + Transcriber.load_model('medium', whisper_type='wrong_whisper') def test_get_whisper_kwargs(): @@ -43,8 +65,15 @@ def test_get_whisper_kwargs(): assert not valid_kwargs == {"arg1": 1, "arg3": 3} -def test_transcribe(transcriber_instance): - model = transcriber_instance +def test_whisper_transcribe(whisper_instance): + model = whisper_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = model.transcribe('test/audio_test_2.mp4') + assert isinstance(transcript, str) + + +def test_whisperx_transcribe(whisperx_instance): + model = whisperx_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) transcript = model.transcribe('test/audio_test_2.mp4') assert isinstance(transcript, str)