Auto fixes from PEP8, fixes from flake8.

This commit is contained in:
Marko Henning
2024-05-15 15:18:17 +02:00
parent 9f526a8f3b
commit 4bcd28d0ea
15 changed files with 391 additions and 417 deletions
+15 -46
View File
@@ -3,7 +3,6 @@ from scraibe.audio import AudioProcessor
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
TEST_SR = 16000
@@ -14,21 +13,17 @@ NORMALIZATION_FACTOR = 32768
@pytest.fixture
def probe_audio_processor():
    """Provide an AudioProcessor built from the module-level test signal.

    The processor is constructed from the predefined TEST_WAVEFORM and
    TEST_SR constants, so test functions can request it as a dependency
    instead of building their own instance.

    Returns:
        AudioProcessor (obj): An instance created with the test waveform
        and the test sample rate.
    """
    processor = AudioProcessor(TEST_WAVEFORM, TEST_SR)
    return processor
def test_AudioProcessor_init(probe_audio_processor):
"""
Test the initialization of the AudioProcessor class.
@@ -43,20 +38,19 @@ def test_AudioProcessor_init(probe_audio_processor):
Returns:
None
"""
"""
assert isinstance(probe_audio_processor, AudioProcessor)
assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device
assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM)
assert probe_audio_processor.sr == TEST_SR
def test_cut(probe_audio_processor):
"""Test the cut function of the AudioProcessor class.
This test verifies that the cut function correctly extracts a segment of audio data from
the waveform, given start and end indices. It checks whether the size of the extracted segment matches
the expected size based on the provided start and end indices and the sample rate.
@@ -65,63 +59,38 @@ def test_cut(probe_audio_processor):
None
"""
"""
start = 4
end = 7
trimmed_waveform = probe_audio_processor.cut(start, end)
expected_size = int((end - start) * TEST_SR)
real_size = trimmed_waveform.size(0)
assert real_size == expected_size
#assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
# assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
def test_audio_processor_invalid_sr():
    """Test that AudioProcessor rejects an invalid sample rate.

    A list of two rates is passed as the sample rate argument. The
    constructor is expected to raise a ValueError, which is checked with
    the pytest.raises context manager.

    Returns:
        None
    """
    invalid_sr = [44100, 48000]
    # Only one call belongs inside the context manager: anything after the
    # raising statement would never execute.
    with pytest.raises(ValueError):
        AudioProcessor(TEST_WAVEFORM, invalid_sr)
def test_audio_processor_SAMPLE_RATE():
    """Test the default sample rate of the AudioProcessor class.

    An AudioProcessor is created from the test waveform without an explicit
    sample rate, and its ``sr`` attribute must equal the SAMPLE_RATE
    constant.

    Returns:
        None
    """
    # Named differently from the ``probe_audio_processor`` fixture so the
    # local does not shadow it.
    default_processor = AudioProcessor(TEST_WAVEFORM)
    assert default_processor.sr == SAMPLE_RATE
+2 -8
View File
@@ -1,20 +1,14 @@
import pytest
from scraibe import Scraibe, Diariser, Transcriber, Transcript
from unittest.mock import MagicMock, patch
import os
@pytest.fixture
def create_scraibe_instance():
    """Create a Scraibe instance, using HF_TOKEN for authentication if set.

    Returns:
        Scraibe: Constructed with ``use_auth_token`` taken from the HF_TOKEN
        environment variable when that variable is defined, otherwise
        constructed with no arguments.
    """
    # Single lookup instead of a membership test followed by indexing.
    token = os.environ.get("HF_TOKEN")
    if token is None:
        return Scraibe()
    return Scraibe(use_auth_token=token)
def test_scraibe_init(create_scraibe_instance):
@@ -47,7 +41,7 @@ def test_scraibe_transcribe(create_scraibe_instance):
model.remove_audio_file("non_existing_audio_file")
model.remove_audio_file("audio_test_2.mp4")
assert not os.path.exists("audio_test_2.mp4") """
assert not os.path.exists("audio_test_2.mp4") """
""" def test_get_audio_file(create_scraibe_instance):
+4 -19
View File
@@ -1,8 +1,5 @@
import pytest
import os
from unittest import mock
from scraibe import diarisation, Diariser
from scraibe import Diariser
@pytest.fixture
@@ -15,11 +12,10 @@ def diariser_instance():
Returns:
Diariser(Obj): An instance of the Diariser class with a mocked token.
"""
#with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
# with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
return Diariser('pyannote')
def test_Diariser_init(diariser_instance):
"""Test the initialization of the Diariser class.
@@ -30,18 +26,7 @@ def test_Diariser_init(diariser_instance):
Args:
diariser_instance (obj): instance of the Diariser class
Returns:
Returns:
None
"""
"""
assert diariser_instance.model == 'pyannote'
+9 -11
View File
@@ -1,25 +1,23 @@
import pytest
from unittest.mock import patch
from scraibe import Transcriber
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = "Hello World"
"""
"""
@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] )
@patch("scraibe.Transcriber.load_model")
def test_transcriber(mock_load_model, audio_file, expected_transcription):
Args:
mock_load_model (_type_): _description_
audio_file (_type_): _description_
expected_transcription (_type_): _description_
mock_model = mock_load_model.return_value
mock_model.transcribe.return_value ={"text": expected_transcription}
@@ -29,24 +27,24 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
assert transcription_result == expected_transcription """
@pytest.fixture
def transcriber_instance():
    """Provide the object returned by ``Transcriber.load_model('medium')``."""
    loaded = Transcriber.load_model('medium')
    return loaded
def test_transcriber_initialization(transcriber_instance):
    """Check that the fixture yields a Transcriber instance."""
    is_transcriber = isinstance(transcriber_instance, Transcriber)
    assert is_transcriber
def test_get_whisper_kwargs():
    """Check that ``_get_whisper_kwargs`` does not echo arbitrary kwargs.

    The keys "arg1" and "arg3" are passed in, and the returned value must
    differ from that input mapping (presumably because they are not valid
    Whisper options -- confirm against ``Transcriber._get_whisper_kwargs``).
    """
    kwargs = {"arg1": 1, "arg3": 3}
    valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs)
    assert valid_kwargs != {"arg1": 1, "arg3": 3}
def test_transcribe(transcriber_instance):
    """Check that transcribing the sample audio file yields a string.

    No mocking is applied here: ``transcribe`` is called directly on the
    fixture object with the path 'test/audio_test_2.mp4'.
    """
    result = transcriber_instance.transcribe('test/audio_test_2.mp4')
    assert isinstance(result, str)