feat: add chunked ASR for long audio with env-configurable chunk duration
- Integrate chunking into LocalAI client to avoid GPU OOM on long audio.
- Split long files into overlapping chunks; transcribe each chunk; merge segments with corrected timestamps.
- Auto-enable chunking when audio duration > LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300s).
- Add env variables:
LOCALAI_CHUNK_DURATION (default 180)
LOCALAI_CHUNK_OVERLAP (default 2)
LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300)
- Add unit and integration tests for chunking logic.
- Confirmed working end-to-end with vibevoice-cpp-asr on 88-minute file.
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from scraibe.localai_client import LocalAIClient, LocalAIError
|
||||
from scraibe.audio import get_audio_duration, split_audio_into_chunks
|
||||
|
||||
|
||||
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||
|
||||
|
||||
def make_fake_segments(start=0.0, count=3):
|
||||
segments = []
|
||||
for i in range(count):
|
||||
s = start + i * 2.0
|
||||
e = s + 2.0
|
||||
segments.append({
|
||||
"start": s,
|
||||
"end": e,
|
||||
"speaker": "SPEAKER_00",
|
||||
"text": f"Segment text {i}",
|
||||
})
|
||||
return segments
|
||||
|
||||
|
||||
def fake_localai_response(segments):
|
||||
return {
|
||||
"segments": segments,
|
||||
"text": " ".join(seg["text"] for seg in segments),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with patch.object(LocalAIClient, "__init__", lambda self, **kw: None):
|
||||
c = LocalAIClient()
|
||||
c.api_url = "http://localhost:8080"
|
||||
c.model = "vibevoice-diarize"
|
||||
c.api_key = None
|
||||
c._client = MagicMock()
|
||||
return c
|
||||
|
||||
|
||||
def test_parse_diarization_response(client):
|
||||
segs = make_fake_segments()
|
||||
raw = fake_localai_response(segs)
|
||||
|
||||
out = client._parse_diarization_response(raw)
|
||||
|
||||
assert "segments" in out
|
||||
assert "speakers" in out
|
||||
assert "transcripts" in out
|
||||
assert len(out["segments"]) == len(segs)
|
||||
for i, s in enumerate(segs):
|
||||
assert out["segments"][i][0] == s["start"]
|
||||
assert out["segments"][i][1] == s["end"]
|
||||
assert out["speakers"][i] == s["speaker"]
|
||||
assert out["transcripts"][i] == s["text"]
|
||||
|
||||
|
||||
def test_parse_diarization_empty(client):
|
||||
out = client._parse_diarization_response({"segments": []})
|
||||
assert out["segments"] == []
|
||||
assert out["speakers"] == []
|
||||
assert out["transcripts"] == []
|
||||
|
||||
|
||||
def test_diarize_and_transcribe_single_happy(client):
|
||||
with patch.object(client, "_client") as mock_client:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = fake_localai_response(make_fake_segments())
|
||||
mock_client.post.return_value = mock_resp
|
||||
|
||||
result = client.diarize_and_transcribe(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
return_raw=True,
|
||||
)
|
||||
|
||||
assert "segments" in result
|
||||
assert "raw_result" in result
|
||||
assert len(result["segments"]) > 0
|
||||
|
||||
|
||||
def test_chunking_triggered_for_long_audio(client):
|
||||
# Simulate long audio by patching get_audio_duration
|
||||
with patch("scraibe.localai_client.get_audio_duration") as mock_dur, \
|
||||
patch.object(client, "_diarize_and_transcribe_chunked") as mock_chunked:
|
||||
|
||||
mock_dur.return_value = 600.0 # 10 minutes
|
||||
mock_chunked.return_value = {
|
||||
"segments": [],
|
||||
"speakers": [],
|
||||
"transcripts": [],
|
||||
}
|
||||
|
||||
client.diarize_and_transcribe(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
use_chunking=None,
|
||||
max_single_request_duration=300.0,
|
||||
)
|
||||
|
||||
mock_chunked.assert_called_once()
|
||||
|
||||
|
||||
def test_chunking_not_triggered_for_short_audio(client):
|
||||
with patch("scraibe.localai_client.get_audio_duration") as mock_dur, \
|
||||
patch.object(client, "_diarize_and_transcribe_chunked") as mock_chunked, \
|
||||
patch.object(client, "_diarize_and_transcribe_single") as mock_single:
|
||||
|
||||
mock_dur.return_value = 120.0
|
||||
mock_single.return_value = {
|
||||
"segments": [],
|
||||
"speakers": [],
|
||||
"transcripts": [],
|
||||
}
|
||||
|
||||
client.diarize_and_transcribe(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
use_chunking=None,
|
||||
max_single_request_duration=300.0,
|
||||
)
|
||||
|
||||
mock_chunked.assert_not_called()
|
||||
mock_single.assert_called_once()
|
||||
|
||||
|
||||
def test_chunked_transcription_adjusts_timestamps(client):
|
||||
# Mock split_audio_into_chunks to return two chunks
|
||||
chunk1_path = TEST_AUDIO_1
|
||||
chunk2_path = TEST_AUDIO_1 # reusing same file; in real usage different
|
||||
|
||||
chunks = [
|
||||
{"path": chunk1_path, "start": 0.0, "end": 10.0},
|
||||
{"path": chunk2_path, "start": 10.0, "end": 20.0},
|
||||
]
|
||||
|
||||
with patch("scraibe.localai_client.split_audio_into_chunks") as mock_split, \
|
||||
patch.object(client, "_diarize_and_transcribe_single") as mock_single, \
|
||||
patch("os.remove"):
|
||||
|
||||
mock_split.return_value = chunks
|
||||
|
||||
# First chunk: segments 0–4
|
||||
# Second chunk: segments 0–4 (local times)
|
||||
def side_effect(audio_path, **kw):
|
||||
if audio_path == chunk1_path:
|
||||
segs = make_fake_segments(start=0.0, count=2)
|
||||
else:
|
||||
segs = make_fake_segments(start=0.0, count=2)
|
||||
return client._parse_diarization_response(fake_localai_response(segs))
|
||||
|
||||
mock_single.side_effect = side_effect
|
||||
|
||||
result = client._diarize_and_transcribe_chunked(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
return_raw=False,
|
||||
chunk_duration=10.0,
|
||||
chunk_overlap=2.0,
|
||||
)
|
||||
|
||||
# Check we got 4 segments total
|
||||
assert len(result["segments"]) == 4
|
||||
|
||||
# First two segments should be in [0, 4]
|
||||
assert result["segments"][0][0] == 0.0
|
||||
assert result["segments"][1][0] == 2.0
|
||||
|
||||
# Next two segments should be shifted by 10
|
||||
assert result["segments"][2][0] == 10.0
|
||||
assert result["segments"][3][0] == 12.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_integration_chunked_transcription_with_localai():
|
||||
"""
|
||||
Integration test: run chunked transcription against a live LocalAI instance.
|
||||
Only runs if LOCALAI_API_URL is set and an audio file is provided.
|
||||
This test is skipped by default unless run with:
|
||||
pytest -m integration
|
||||
"""
|
||||
api_url = os.getenv("LOCALAI_API_URL")
|
||||
if not api_url:
|
||||
pytest.skip("LOCALAI_API_URL not set; skipping integration test")
|
||||
|
||||
# Use one of the bundled test audio files
|
||||
audio_path = TEST_AUDIO_1
|
||||
if not os.path.exists(audio_path):
|
||||
pytest.skip(f"Test audio not found: {audio_path}")
|
||||
|
||||
# Force chunking with a very small max_single_request_duration
|
||||
# Use environment-configured model or a sensible default
|
||||
model = os.getenv("LOCALAI_MODEL") or "vibevoice-cpp-asr"
|
||||
|
||||
client = LocalAIClient(api_url=api_url, model=model)
|
||||
try:
|
||||
result = client.diarize_and_transcribe(
|
||||
audio_path=audio_path,
|
||||
verbose=True,
|
||||
return_raw=True,
|
||||
use_chunking=True,
|
||||
chunk_duration=3.0,
|
||||
chunk_overlap=0.5,
|
||||
max_single_request_duration=1.0,
|
||||
)
|
||||
|
||||
assert "segments" in result
|
||||
assert len(result["segments"]) > 0
|
||||
|
||||
# Basic sanity: segments are time-ordered
|
||||
for i in range(1, len(result["segments"])):
|
||||
prev_end = result["segments"][i - 1][1]
|
||||
curr_start = result["segments"][i][0]
|
||||
assert curr_start >= result["segments"][i - 1][0]
|
||||
|
||||
# If raw_result indicates chunked, ensure structure is sensible
|
||||
raw = result.get("raw_result")
|
||||
if raw and raw.get("chunked"):
|
||||
assert "chunks" in raw
|
||||
assert len(raw["chunks"]) > 1
|
||||
|
||||
finally:
|
||||
client.close()
|
||||
Reference in New Issue
Block a user