Changed loading of transcriber objects to function. Adjusted tests.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
from scraibe import Transcriber, WhisperTranscriber, WhisperXTranscriber
|
||||
from scraibe import (Transcriber, WhisperTranscriber,
|
||||
WhisperXTranscriber, load_transcriber)
|
||||
import torch
|
||||
|
||||
|
||||
@@ -30,12 +31,12 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
||||
|
||||
@pytest.fixture
|
||||
def whisper_instance():
|
||||
return Transcriber.load_model('medium', whisper_type='whisper')
|
||||
return load_transcriber('medium', whisper_type='whisper')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def whisperx_instance():
|
||||
return Transcriber.load_model('medium', whisper_type='whisperx')
|
||||
return load_transcriber('medium', whisper_type='whisperx')
|
||||
|
||||
|
||||
def test_whisper_base_initialization(whisper_instance):
|
||||
@@ -56,7 +57,7 @@ def test_whisperx_transcriber_initialization(whisperx_instance):
|
||||
|
||||
def test_wrong_transcriber_initialization():
|
||||
with pytest.raises(ValueError):
|
||||
Transcriber.load_model('medium', whisper_type='wrong_whisper')
|
||||
load_transcriber('medium', whisper_type='wrong_whisper')
|
||||
|
||||
|
||||
def test_get_whisper_kwargs():
|
||||
|
||||
Reference in New Issue
Block a user