diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml new file mode 100644 index 0000000..8959789 --- /dev/null +++ b/.github/workflows/pytest.yaml @@ -0,0 +1,43 @@ +name: Run tests + +on: + #push: + + pull_request: + branches: ['main', 'develop'] + + workflow_dispatch: + +jobs: + pytest: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v3 + with: + python-version: 3.9 + + - name: Install Dependencies + run: | + + sudo apt update && sudo apt upgrade + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install . + sudo apt-get install libsndfile1-dev + sudo apt-get install ffmpeg + pip install pytest + + + + + - name: Run pytest + env: + HF_TOKEN : ${{ secrets.HF_TOKEN }} + run: | + pytest + \ No newline at end of file diff --git a/test/audio_test_1.mp4 b/test/audio_test_1.mp4 new file mode 100644 index 0000000..d7b0440 Binary files /dev/null and b/test/audio_test_1.mp4 differ diff --git a/test/audio_test_2.mp4 b/test/audio_test_2.mp4 new file mode 100644 index 0000000..c1307e5 Binary files /dev/null and b/test/audio_test_2.mp4 differ diff --git a/test/test_audio.py b/test/test_audio.py new file mode 100644 index 0000000..311a472 --- /dev/null +++ b/test/test_audio.py @@ -0,0 +1,127 @@ +import pytest +from scraibe.audio import AudioProcessor +import torch + + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) +TEST_SR = 16000 +SAMPLE_RATE = 16000 +NORMALIZATION_FACTOR = 32768 + + +@pytest.fixture +def probe_audio_processor(): + """Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate. + + This fixture is used to create an instance of the AudioProcessor class with a predefined test waveform and sample rate (TEST_SR). 
It returns the instantiated AudioProcessor, which can be used as a + dependency in other test functions. + + + Returns: + AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate. + """ + return AudioProcessor(TEST_WAVEFORM, TEST_SR) + + + + + + +def test_AudioProcessor_init(probe_audio_processor): + """ + Test the initialization of the AudioProcessor class. + + This test verifies that the AudioProcessor class is correctly initialized with the provided waveform and sample rate. It checks whether the instantiated AudioProcessor object has the correct attributes + and whether the waveform and sample rate match the expected values. + + Args: + probe_audio_processor (obj): An instance of the AudioProcessor class to be tested. + + + Returns: + None + + + + + """ + assert isinstance(probe_audio_processor, AudioProcessor) + assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device + assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM) + assert probe_audio_processor.sr == TEST_SR + + + +def test_cut(probe_audio_processor): + """Test the cut function of the AudioProcessor class. + + This test verifies that the cut function correctly extracts a segment of audio data from + the waveform, given start and end indices. It checks whether the size of the extracted segment matches + the expected size based on the provided start and end indices and the sample rate. + + Returns: + None + + + """ + + start = 4 + end = 7 + trimmed_waveform = probe_audio_processor.cut(start, end) + expected_size = int((end - start) * TEST_SR) + real_size = trimmed_waveform.size(0) + assert real_size == expected_size + #assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) + + + + + + + + + + +def test_audio_processor_invalid_sr(): + """Test the behavior of AudioProcessor when an invalid sample rate is provided. 
+ + This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an + AudioProcessor object with an invalid sample rate. + + Returns: + None + """ + with pytest.raises(ValueError): + AudioProcessor(TEST_WAVEFORM, [44100,48000]) + + +def test_audio_processor_SAMPLE_RATE(): + """Test the default sample rate of the AudioProcessor class. + + This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform + and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE. + + Returns: + None + """ + probe_audio_processor = AudioProcessor(TEST_WAVEFORM) + assert probe_audio_processor.sr == SAMPLE_RATE + + + + + + + + + + + + + + + + + diff --git a/test/test_autotranscript.py b/test/test_autotranscript.py new file mode 100644 index 0000000..edbe0f7 --- /dev/null +++ b/test/test_autotranscript.py @@ -0,0 +1,58 @@ +import pytest +from scraibe import Scraibe, Diariser, Transcriber, Transcript +from unittest.mock import MagicMock, patch +import os + + + + + +@pytest.fixture +def create_scraibe_instance(): + if "HF_TOKEN" in os.environ: + return Scraibe(use_auth_token=os.environ["HF_TOKEN"] ) + else: + return Scraibe() + + + + +def test_scraibe_init(create_scraibe_instance): + model = create_scraibe_instance + assert isinstance(model.transcriber, Transcriber) + assert isinstance(model.diariser, Diariser) + + +def test_scraibe_autotranscribe(create_scraibe_instance): + model = create_scraibe_instance + transcript = model.autotranscribe('test/audio_test_2.mp4') + assert isinstance(transcript, Transcript) + + +def test_scraibe_diarization(create_scraibe_instance): + model = create_scraibe_instance + diarisation_result = 
model.diarization('test/audio_test_2.mp4') + assert isinstance(diarisation_result, dict) + + +def test_scraibe_transcribe(create_scraibe_instance): + model = create_scraibe_instance + transcription_result = model.transcribe('test/audio_test_2.mp4') + assert isinstance(transcription_result, str) + + +""" def test_remove_audio_file(create_scraibe_instance): + model = create_scraibe_instance + with pytest.raises(ValueError): + model.remove_audio_file("non_existing_audio_file") + + model.remove_audio_file("audio_test_2.mp4") + assert not os.path.exists("audio_test_2.mp4") """ + + +""" def test_get_audio_file(create_scraibe_instance): + model = create_scraibe_instance + audio_file = os.path.exist("audio_test_2.mp4") + assert isinstance(audio_file, AudioProcessor) + assert isinstance(audio_file.waveform, torch.Tensor) + assert isinstance(audio_file.sr, torch.Tensor) """ diff --git a/test/test_diarisation.py b/test/test_diarisation.py new file mode 100644 index 0000000..d1d26f3 --- /dev/null +++ b/test/test_diarisation.py @@ -0,0 +1,47 @@ +import pytest +import os +from unittest import mock +from scraibe import diarisation, Diariser + + + +@pytest.fixture +def diariser_instance(): + """Fixture for creating an instance of the Diariser class with mocked token. + + This fixture is used to create an instance of the Diariser class with a mocked token returned by the _get_token method. It patches the _get_token method of the Diariser class + using unittest.mock.patch.object, ensuring that it returns a predetermined value ('personal Hugging-Face token'). The mocked Diariser object is returned and can be used as a dependency in other tests. NOTE: the mock.patch.object call below is currently commented out, so no patching actually takes place and an unpatched Diariser is returned. + + Returns: + Diariser(Obj): An instance of the Diariser class with a mocked token. + """ + #with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): + return Diariser('pyannote') + + + +def test_Diariser_init(diariser_instance): + """Test the initialization of the Diariser class. 
+ + This test verifies that the Diariser class is correctly initialized with the specified model. + It checks whether the 'model' attribute of the instantiated Diariser object equals 'pyannote'. + + + Args: + diariser_instance (obj): instance of the Diariser class + + Returns: + None + """ + assert diariser_instance.model == 'pyannote' + + + + + + + + + + + diff --git a/test/test_transcriber.py b/test/test_transcriber.py new file mode 100644 index 0000000..3a4a0dc --- /dev/null +++ b/test/test_transcriber.py @@ -0,0 +1,52 @@ +import pytest +from unittest.mock import patch +from scraibe import Transcriber +import torch + + + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = "Hello World" + +""" +@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) +@patch("scraibe.Transcriber.load_model") + +def test_transcriber(mock_load_model, audio_file, expected_transcription): + + + Args: + mock_load_model (_type_): _description_ + audio_file (_type_): _description_ + expected_transcription (_type_): _description_ + + mock_model = mock_load_model.return_value + mock_model.transcribe.return_value ={"text": expected_transcription} + + transcriber = Transcriber.load_model(model="medium") + + transcription_result = transcriber.transcribe(audio=audio_file) + + assert transcription_result == expected_transcription """ + +@pytest.fixture +def transcriber_instance(): + return Transcriber.load_model('medium') + +def test_transcriber_initialization(transcriber_instance): + assert isinstance(transcriber_instance, Transcriber) + +def test_get_whisper_kwargs(): + kwargs = {"arg1": 1, "arg3": 3} + valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs) + assert not valid_kwargs == {"arg1": 1, "arg3": 3} + + +def test_transcribe(transcriber_instance): + model = transcriber_instance + #mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = 
model.transcribe('test/audio_test_2.mp4') + assert isinstance(transcript, str) + + + diff --git a/tests/test_autotranscript.py b/tests/test_autotranscript.py deleted file mode 100644 index 475f4de..0000000 --- a/tests/test_autotranscript.py +++ /dev/null @@ -1,120 +0,0 @@ -import pytest -from scraibe import Transcriber -from unittest.mock import patch, mock_open -import os - -def test_load_pyannote_model(): - """ - Test load_pyannote_test - """ - from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization - from pyannote.audio import Pipeline - - pipeline = Pipeline.from_pretrained("models/pyannote/speaker_diarization/config.yaml") - assert isinstance(pipeline, SpeakerDiarization) - -# Test Transcribtion class - - -@pytest.fixture -def transcriber(): - """ - Prepare Transcriber for testing - Returns: Transcriber Object - """ - - return Transcriber.load_model("medium", local=True) - - -def test_Transcriber_init(transcriber): - """ - Test Transcriber initialization with a whisper model - """ - - assert isinstance(transcriber, Transcriber) - -def test_transcription(transcriber): - """ - Test transcription - """ - - transcript = transcriber.transcribe("tests/test.wav") - assert isinstance(transcript, str) - -def test_save_transcript_to_file(transcriber): - """ - Test save_transcript_to_file - """ - transcript = transcriber.transcribe("tests/test.wav") - - Transcriber.save_transcript(transcript, "tests/output.txt") - - assert os.path.exists("tests/output.txt") - - os.remove("tests/output.txt") - -# Test Diaraization class - -from scraibe import Diariser - -@pytest.fixture -def diarisation(): - """ - Prepare Diarisation for testing - Returns: Diarisation Object - """ - - return Diariser.load_model("models/pyannote/speaker_diarization/config.yaml", local=True) - -def test_Diarisation_init(diarisation): - """ - Test Diarisation initialization with a pyannote model - """ - - assert isinstance(diarisation, Diariser) - -def test_diarisation(diarisation): - """ 
- Test diarisation - """ - - diarisation = diarisation.diarization("tests/test.wav") - assert isinstance(diarisation, dict) - -# Test AudioProcessor - -from scraibe import AudioProcessor , TorchAudioProcessor - - -def test_AudioProcessor_init(): - """ - Test AudioProcessor initialization - """ - audio = AudioProcessor("tests/test.wav") - assert isinstance(audio, AudioProcessor) - -def test_AudioProcessor_convert(): - """ - Test AudioProcessor convert - """ - audio = AudioProcessor("tests/test.wav") - audio.convert_audio("tests/test.mp3", format="mp3") - assert os.path.exists("tests/test.mp3") - -def test_TorchAudioProcessor_from_file(): - """ - Test TorchAudioProcessor initialization - """ - audio = TorchAudioProcessor.from_file("tests/test.wav") - - assert isinstance(audio, TorchAudioProcessor) - - os.remove("tests/test.mp3") - - -def test_TorchAudioProcessor_from_ffmpeg(): - """ - Test TorchAudioProcessor initialization - """ - audio = TorchAudioProcessor.from_ffmpeg("tests/test.wav") - assert isinstance(audio, TorchAudioProcessor)