From 328f7d4a0f1b36b3da93777eeafa243085fdcfb5 Mon Sep 17 00:00:00 2001 From: Tryndaron Date: Tue, 9 Apr 2024 10:00:11 +0200 Subject: [PATCH] test transcriber function --- .github/workflows/pytest.yaml | 2 +- test/test_transcriber.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index da3e1a4..7591ae5 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -39,5 +39,5 @@ jobs: env: HF_TOKEN : ${{ secrets.HF_TOKEN }} run: | - pytest test/test_diarisation2.py + pytest test/test_transcriber.py \ No newline at end of file diff --git a/test/test_transcriber.py b/test/test_transcriber.py index ef17951..68ff854 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -1,9 +1,13 @@ import pytest from unittest.mock import patch from scraibe import Transcriber +import torch +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +TEST_WAVEFORM = "Hello World" + """ @pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) @patch("scraibe.Transcriber.load_model") @@ -25,5 +29,23 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): assert transcription_result == expected_transcription """ +@pytest.fixture +def transcriber_instance(): + return Transcriber('medium') + +def test_transcriber_initialization(transcriber_instance): + assert transcriber_instance.model == 'medium' + +""" def test_get_whisper_kwargs(): + kwargs = {"arg1": 1, "arg3": 3} + valid_kwargs = Transcriber._get_diarisation_kwargs(**kwargs) + assert not valid_kwargs == {"arg1": 1, "arg3": 3} """ + + +""" def test_transcribe(transcriber_instance, TEST_WAVEFORM): + mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = transcriber_instance.transcribe("Hello, World") + assert isinstance(transcript, str) """ +