Merge pull request #33 from JSchmie/tests

Good Job :)
This commit is contained in:
Jacob Schmieder
2024-04-29 23:05:20 +02:00
committed by GitHub
8 changed files with 327 additions and 120 deletions
+43
View File
@@ -0,0 +1,43 @@
name: Run tests
on:
#push:
pull_request:
branches: ['main', 'develop']
workflow_dispatch:
jobs:
pytest:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v3
with:
python-version: 3.9
- name: Install Dependencies
run: |
sudo apt update && sudo apt upgrade
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install .
sudo apt-get install libsndfile1-dev
sudo apt-get install ffmpeg
pip install pytest
- name: Run pytest
env:
HF_TOKEN : ${{ secrets.HF_TOKEN }}
run: |
pytest
Binary file not shown.
Binary file not shown.
+127
View File
@@ -0,0 +1,127 @@
import pytest
from scraibe.audio import AudioProcessor
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
TEST_SR = 16000
SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768
@pytest.fixture
def probe_audio_processor():
"""Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate.
This fixture is used to create an instance of the AudioProcessor class with a predfined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor , which can bes used as a
dependency in other test functions.
Returns:
AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate.
"""
return AudioProcessor(TEST_WAVEFORM, TEST_SR)
def test_AudioProcessor_init(probe_audio_processor):
"""
Test the initialization of the AudioProcessor class.
This test verifies that the AUdioProcessor class is correctly initialized with the provided waveform and sample rate. It checks whether the instantiated AhdioProcessor object has the correct attributes
and whether the waveform and sample rate match the expected values.
Args:
probe_audio_processor (obj): An instance of the AudioProcessor class to be tested.
Returns:
None
"""
assert isinstance(probe_audio_processor, AudioProcessor)
assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device
assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM)
assert probe_audio_processor.sr == TEST_SR
def test_cut(probe_audio_processor):
"""Test the cut function of the AudioProcessor class.
This test verifies that the cut function correctly extracts a segment of audio data from
the waveform, given start and end indices. It checks whether the size of the extracted segment matches
the expected size based on the provided start and end indices and the sample rate.
Returns:
None
"""
start = 4
end = 7
trimmed_waveform = probe_audio_processor.cut(start, end)
expected_size = int((end - start) * TEST_SR)
real_size = trimmed_waveform.size(0)
assert real_size == expected_size
#assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
def test_audio_processor_invalid_sr():
"""Test the behavior of AudioProcessor when an invalid smaple rate is provided.
This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an
AudioProcessor object with an invalid sample rate.
Returns:
None
"""
with pytest.raises(ValueError):
AudioProcessor(TEST_WAVEFORM, [44100,48000])
def test_audio_processor_SAMPLE_RATE():
"""Test the default sample rate of the AudioProcessor class.
This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform
and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE.
Returns:
None
"""
probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
assert probe_audio_processor.sr == SAMPLE_RATE
+58
View File
@@ -0,0 +1,58 @@
import pytest
from scraibe import Scraibe, Diariser, Transcriber, Transcript
from unittest.mock import MagicMock, patch
import os
@pytest.fixture
def create_scraibe_instance():
if "HF_TOKEN" in os.environ:
return Scraibe(use_auth_token=os.environ["HF_TOKEN"] )
else:
return Scraibe()
def test_scraibe_init(create_scraibe_instance):
model = create_scraibe_instance
assert isinstance(model.transcriber, Transcriber)
assert isinstance(model.diariser, Diariser)
def test_scraibe_autotranscribe(create_scraibe_instance):
model = create_scraibe_instance
transcript = model.autotranscribe('test/audio_test_2.mp4')
assert isinstance(transcript, Transcript)
def test_scraibe_diarization(create_scraibe_instance):
model = create_scraibe_instance
diarisation_result = model.diarization('test/audio_test_2.mp4')
assert isinstance(diarisation_result, dict)
def test_scraibe_transcribe(create_scraibe_instance):
model = create_scraibe_instance
transcription_result = model.transcribe('test/audio_test_2.mp4')
assert isinstance(transcription_result, str)
""" def test_remove_audio_file(create_scraibe_instance):
model = create_scraibe_instance
with pytest.raises(ValueError):
model.remove_audio_file("non_existing_audio_file")
model.remove_audio_file("audio_test_2.mp4")
assert not os.path.exists("audio_test_2.mp4") """
""" def test_get_audio_file(create_scraibe_instance):
model = create_scraibe_instance
audio_file = os.path.exist("audio_test_2.mp4")
assert isinstance(audio_file, AudioProcessor)
assert isinstance(audio_file.waveform, torch.Tensor)
assert isinstance(audio_file.sr, torch.Tensor) """
+47
View File
@@ -0,0 +1,47 @@
import pytest
import os
from unittest import mock
from scraibe import diarisation, Diariser
@pytest.fixture
def diariser_instance():
"""Fixture for creating an instance of the Diariser class with mocked token.
This fixture is used to create an instance of the the Diariser class with a mocked token returned by the _get_token method. It patches the _get_token method of the Diariser class
using unit.test.mock.patch.object, ensuring that it returns a predetrmined value ('personal Hugging-Face token'). The mocked Diariser object is retunrned and can be used as a dependency in otehr tests.
Returns:
Diariser(Obj): An instance of the Diariser class with a mocked token.
"""
#with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
return Diariser('pyannote')
def test_Diariser_init(diariser_instance):
"""Test the initialization of the Diariser class.
This test verifies that the Diariser class is correctly initialized with the specified model.
It checks whether the 'model' attribute of the instantiated Diariser object equals 'pyannote'.
Args:
diariser_instance (obj): instance of the Diariser class
Returns:
None
"""
assert diariser_instance.model == 'pyannote'
+52
View File
@@ -0,0 +1,52 @@
import pytest
from unittest.mock import patch
from scraibe import Transcriber
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = "Hello World"
"""
@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] )
@patch("scraibe.Transcriber.load_model")
def test_transcriber(mock_load_model, audio_file, expected_transcription):
Args:
mock_load_model (_type_): _description_
audio_file (_type_): _description_
expected_transcription (_type_): _description_
mock_model = mock_load_model.return_value
mock_model.transcribe.return_value ={"text": expected_transcription}
transcriber = Transcriber.load_model(model="medium")
transcription_result = transcriber.transcribe(audio=audio_file)
assert transcription_result == expected_transcription """
@pytest.fixture
def transcriber_instance():
return Transcriber.load_model('medium')
def test_transcriber_initialization(transcriber_instance):
assert isinstance(transcriber_instance, Transcriber)
def test_get_whisper_kwargs():
kwargs = {"arg1": 1, "arg3": 3}
valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs)
assert not valid_kwargs == {"arg1": 1, "arg3": 3}
def test_transcribe(transcriber_instance):
model = transcriber_instance
#mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4')
assert isinstance(transcript, str)
-120
View File
@@ -1,120 +0,0 @@
import pytest
from scraibe import Transcriber
from unittest.mock import patch, mock_open
import os
def test_load_pyannote_model():
"""
Test load_pyannote_test
"""
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("models/pyannote/speaker_diarization/config.yaml")
assert isinstance(pipeline, SpeakerDiarization)
# Test Transcribtion class
@pytest.fixture
def transcriber():
"""
Prepare Transcriber for testing
Returns: Transcriber Object
"""
return Transcriber.load_model("medium", local=True)
def test_Transcriber_init(transcriber):
"""
Test Transcriber initialization with a whisper model
"""
assert isinstance(transcriber, Transcriber)
def test_transcription(transcriber):
"""
Test transcription
"""
transcript = transcriber.transcribe("tests/test.wav")
assert isinstance(transcript, str)
def test_save_transcript_to_file(transcriber):
"""
Test save_transcript_to_file
"""
transcript = transcriber.transcribe("tests/test.wav")
Transcriber.save_transcript(transcript, "tests/output.txt")
assert os.path.exists("tests/output.txt")
os.remove("tests/output.txt")
# Test Diaraization class
from scraibe import Diariser
@pytest.fixture
def diarisation():
"""
Prepare Diarisation for testing
Returns: Diarisation Object
"""
return Diariser.load_model("models/pyannote/speaker_diarization/config.yaml", local=True)
def test_Diarisation_init(diarisation):
"""
Test Diarisation initialization with a pyannote model
"""
assert isinstance(diarisation, Diariser)
def test_diarisation(diarisation):
"""
Test diarisation
"""
diarisation = diarisation.diarization("tests/test.wav")
assert isinstance(diarisation, dict)
# Test AudioProcessor
from scraibe import AudioProcessor , TorchAudioProcessor
def test_AudioProcessor_init():
"""
Test AudioProcessor initialization
"""
audio = AudioProcessor("tests/test.wav")
assert isinstance(audio, AudioProcessor)
def test_AudioProcessor_convert():
"""
Test AudioProcessor convert
"""
audio = AudioProcessor("tests/test.wav")
audio.convert_audio("tests/test.mp3", format="mp3")
assert os.path.exists("tests/test.mp3")
def test_TorchAudioProcessor_from_file():
"""
Test TorchAudioProcessor initialization
"""
audio = TorchAudioProcessor.from_file("tests/test.wav")
assert isinstance(audio, TorchAudioProcessor)
os.remove("tests/test.mp3")
def test_TorchAudioProcessor_from_ffmpeg():
"""
Test TorchAudioProcessor initialization
"""
audio = TorchAudioProcessor.from_ffmpeg("tests/test.wav")
assert isinstance(audio, TorchAudioProcessor)