From 6326d0f15677e00909cfb76a737cec8efeec9e22 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Mon, 30 Sep 2024 16:03:27 +0000 Subject: [PATCH] added tests folder --- {test => tests}/audio_test_1.mp4 | Bin {test => tests}/audio_test_2.mp4 | Bin tests/test_audio.py | 96 +++++++++++++++++++++++++ {test => tests}/test_autotranscript.py | 6 +- tests/test_diarisation.py | 32 +++++++++ tests/test_transcriber.py | 80 +++++++++++++++++++++ 6 files changed, 211 insertions(+), 3 deletions(-) rename {test => tests}/audio_test_1.mp4 (100%) rename {test => tests}/audio_test_2.mp4 (100%) create mode 100644 tests/test_audio.py rename {test => tests}/test_autotranscript.py (88%) create mode 100644 tests/test_diarisation.py create mode 100644 tests/test_transcriber.py diff --git a/test/audio_test_1.mp4 b/tests/audio_test_1.mp4 similarity index 100% rename from test/audio_test_1.mp4 rename to tests/audio_test_1.mp4 diff --git a/test/audio_test_2.mp4 b/tests/audio_test_2.mp4 similarity index 100% rename from test/audio_test_2.mp4 rename to tests/audio_test_2.mp4 diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 0000000..aee6cb3 --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,96 @@ +import pytest +from scraibe.audio import AudioProcessor +import torch + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) +TEST_SR = 16000 +SAMPLE_RATE = 16000 +NORMALIZATION_FACTOR = 32768 + + +@pytest.fixture +def probe_audio_processor(): + """Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate. + + This fixture is used to create an instance of the AudioProcessor class with a predfined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor , which can bes used as a + dependency in other test functions. + + + Returns: + AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate. + """ + return AudioProcessor(TEST_WAVEFORM, TEST_SR) + + +def test_AudioProcessor_init(probe_audio_processor): + """ + Test the initialization of the AudioProcessor class. + + This test verifies that the AUdioProcessor class is correctly initialized with the provided waveform and sample rate. It checks whether the instantiated AhdioProcessor object has the correct attributes + and whether the waveform and sample rate match the expected values. + + Args: + probe_audio_processor (obj): An instance of the AudioProcessor class to be tested. + + + Returns: + None + + + + + """ + assert isinstance(probe_audio_processor, AudioProcessor) + assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device + assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM) + assert probe_audio_processor.sr == TEST_SR + + +def test_cut(probe_audio_processor): + """Test the cut function of the AudioProcessor class. + + This test verifies that the cut function correctly extracts a segment of audio data from + the waveform, given start and end indices. It checks whether the size of the extracted segment matches + the expected size based on the provided start and end indices and the sample rate. + + Returns: + None + + + """ + + start = 4 + end = 7 + trimmed_waveform = probe_audio_processor.cut(start, end) + expected_size = int((end - start) * TEST_SR) + real_size = trimmed_waveform.size(0) + assert real_size == expected_size + # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) + + +def test_audio_processor_invalid_sr(): + """Test the behavior of AudioProcessor when an invalid smaple rate is provided. + + This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an + AudioProcessor object with an invalid sample rate. + + Returns: + None + """ + with pytest.raises(ValueError): + AudioProcessor(TEST_WAVEFORM, [44100, 48000]) + + +def test_audio_processor_SAMPLE_RATE(): + """Test the default sample rate of the AudioProcessor class. + + This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform + and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE. + + Returns: + None + """ + probe_audio_processor = AudioProcessor(TEST_WAVEFORM) + assert probe_audio_processor.sr == SAMPLE_RATE diff --git a/test/test_autotranscript.py b/tests/test_autotranscript.py similarity index 88% rename from test/test_autotranscript.py rename to tests/test_autotranscript.py index 865f507..fbf18ab 100644 --- a/test/test_autotranscript.py +++ b/tests/test_autotranscript.py @@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance): def test_scraibe_autotranscribe(create_scraibe_instance): model = create_scraibe_instance - transcript = model.autotranscribe('./audio_test_2.mp4') + transcript = model.autotranscribe('tests/audio_test_2.mp4') assert isinstance(transcript, Transcript) def test_scraibe_diarization(create_scraibe_instance): model = create_scraibe_instance - diarisation_result = model.diarization('./audio_test_2.mp4') + diarisation_result = model.diarization('tests/audio_test_2.mp4') assert isinstance(diarisation_result, dict) def test_scraibe_transcribe(create_scraibe_instance): model = create_scraibe_instance - transcription_result = model.transcribe('./audio_test_2.mp4') + transcription_result = model.transcribe('tests/audio_test_2.mp4') assert isinstance(transcription_result, str) diff --git a/tests/test_diarisation.py b/tests/test_diarisation.py new file mode 100644 index 0000000..01431be --- /dev/null +++ b/tests/test_diarisation.py @@ -0,0 +1,32 @@ +import pytest +from scraibe import Diariser + + +@pytest.fixture +def diariser_instance(): + """Fixture for creating an instance of the Diariser class with mocked token. + + This fixture is used to create an instance of the the Diariser class with a mocked token returned by the _get_token method. It patches the _get_token method of the Diariser class + using unit.test.mock.patch.object, ensuring that it returns a predetrmined value ('personal Hugging-Face token'). The mocked Diariser object is retunrned and can be used as a dependency in otehr tests. + + Returns: + Diariser(Obj): An instance of the Diariser class with a mocked token. + """ + # with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): + return Diariser('pyannote') + + +def test_Diariser_init(diariser_instance): + """Test the initialization of the Diariser class. + + This test verifies that the Diariser class is correctly initialized with the specified model. + It checks whether the 'model' attribute of the instantiated Diariser object equals 'pyannote'. + + + Args: + diariser_instance (obj): instance of the Diariser class + + Returns: + None + """ + assert diariser_instance.model == 'pyannote' diff --git a/tests/test_transcriber.py b/tests/test_transcriber.py new file mode 100644 index 0000000..ba9d99a --- /dev/null +++ b/tests/test_transcriber.py @@ -0,0 +1,80 @@ +import pytest +from scraibe import (Transcriber, WhisperTranscriber, + FasterWhisperTranscriber, load_transcriber) +import torch + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = "Hello World" + +""" +@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) +@patch("scraibe.Transcriber.load_model") + +def test_transcriber(mock_load_model, audio_file, expected_transcription): + + + Args: + mock_load_model (_type_): _description_ + audio_file (_type_): _description_ + expected_transcription (_type_): _description_ + + mock_model = mock_load_model.return_value + mock_model.transcribe.return_value ={"text": expected_transcription} + + transcriber = Transcriber.load_model(model="medium") + + transcription_result = transcriber.transcribe(audio=audio_file) + + assert transcription_result == expected_transcription """ + + +@pytest.fixture +def whisper_instance(): + return load_transcriber('tiny', whisper_type='whisper') + + +@pytest.fixture +def faster_whisper_instance(): + return load_transcriber('tiny', whisper_type='faster-whisper') + + +def test_whisper_base_initialization(whisper_instance): + assert isinstance(whisper_instance, Transcriber) + + +def test_faster_whisper_base_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, Transcriber) + + +def test_whisper_transcriber_initialization(whisper_instance): + assert isinstance(whisper_instance, WhisperTranscriber) + + +def test_faster_whisper_transcriber_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, FasterWhisperTranscriber) + + +def test_wrong_transcriber_initialization(): + with pytest.raises(ValueError): + load_transcriber('tiny', whisper_type='wrong_whisper') + + +def test_get_whisper_kwargs(): + kwargs = {"arg1": 1, "arg3": 3} + valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs) + assert not valid_kwargs == {"arg1": 1, "arg3": 3} + + +def test_whisper_transcribe(whisper_instance): + model = whisper_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = model.transcribe('tests/audio_test_2.mp4') + assert isinstance(transcript, str) + + +def test_faster_whisper_transcribe(faster_whisper_instance): + model = faster_whisper_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = model.transcribe('tests/audio_test_2.mp4') + assert isinstance(transcript, str)