feat: add chunked ASR for long audio with env-configurable chunk duration
- Integrate chunking into LocalAI client to avoid GPU OOM on long audio.
- Split long files into overlapping chunks; transcribe each chunk; merge segments with corrected timestamps.
- Auto-enable chunking when audio duration > LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300s).
- Add env variables:
LOCALAI_CHUNK_DURATION (default 180)
LOCALAI_CHUNK_OVERLAP (default 2)
LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300)
- Add unit and integration tests for chunking logic.
- Confirmed working end-to-end with vibevoice-cpp-asr on 88-minute file.
This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
from scraibe.audio import (
|
||||
get_audio_duration,
|
||||
split_audio_into_chunks,
|
||||
)
|
||||
|
||||
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||
TEST_AUDIO_2 = "tests/audio_test_2.mp4"
|
||||
|
||||
|
||||
@pytest.fixture(params=[TEST_AUDIO_1, TEST_AUDIO_2])
|
||||
def test_audio_path(request):
|
||||
return request.param
|
||||
|
||||
|
||||
def test_get_audio_duration(test_audio_path):
|
||||
dur = get_audio_duration(test_audio_path)
|
||||
assert isinstance(dur, float)
|
||||
assert dur > 0
|
||||
|
||||
|
||||
def test_split_audio_into_chunks_no_split_short(test_audio_path):
|
||||
# For short files, should return the same file with no extra chunks
|
||||
chunks = split_audio_into_chunks(
|
||||
input_path=test_audio_path,
|
||||
max_duration=600.0, # larger than both test files
|
||||
overlap=2.0,
|
||||
)
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0]["path"] == test_audio_path
|
||||
assert chunks[0]["start"] == 0.0
|
||||
dur = get_audio_duration(test_audio_path)
|
||||
assert abs(chunks[0]["end"] - dur) < 0.05
|
||||
|
||||
|
||||
def test_split_audio_into_chunks_creates_chunks(tmp_path):
|
||||
# Use a small chunk duration to force splitting
|
||||
chunks = split_audio_into_chunks(
|
||||
input_path=TEST_AUDIO_1,
|
||||
max_duration=2.0,
|
||||
overlap=0.5,
|
||||
)
|
||||
assert len(chunks) > 1
|
||||
|
||||
# Check that each chunk file exists and is non-empty
|
||||
for c in chunks:
|
||||
assert os.path.exists(c["path"])
|
||||
assert os.path.getsize(c["path"]) > 0
|
||||
|
||||
# Check time ordering and overlap
|
||||
for i in range(1, len(chunks)):
|
||||
prev = chunks[i - 1]
|
||||
curr = chunks[i]
|
||||
assert curr["start"] >= prev["start"]
|
||||
assert curr["start"] < prev["end"] # overlap
|
||||
|
||||
# Cleanup
|
||||
for c in chunks:
|
||||
if os.path.exists(c["path"]):
|
||||
os.remove(c["path"])
|
||||
|
||||
|
||||
def test_split_audio_into_chunks_total_coverage(test_audio_path):
|
||||
dur = get_audio_duration(test_audio_path)
|
||||
|
||||
# Use small chunks to ensure coverage
|
||||
chunks = split_audio_into_chunks(
|
||||
input_path=test_audio_path,
|
||||
max_duration=2.0,
|
||||
overlap=0.5,
|
||||
)
|
||||
|
||||
# First chunk starts at 0
|
||||
assert chunks[0]["start"] == 0.0
|
||||
|
||||
# Last chunk end should cover the duration
|
||||
assert chunks[-1]["end"] >= dur - 0.05
|
||||
|
||||
# Cleanup
|
||||
for c in chunks:
|
||||
if os.path.exists(c["path"]):
|
||||
os.remove(c["path"])
|
||||
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from scraibe.localai_client import LocalAIClient, LocalAIError
|
||||
from scraibe.audio import get_audio_duration, split_audio_into_chunks
|
||||
|
||||
|
||||
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||
|
||||
|
||||
def make_fake_segments(start=0.0, count=3):
|
||||
segments = []
|
||||
for i in range(count):
|
||||
s = start + i * 2.0
|
||||
e = s + 2.0
|
||||
segments.append({
|
||||
"start": s,
|
||||
"end": e,
|
||||
"speaker": "SPEAKER_00",
|
||||
"text": f"Segment text {i}",
|
||||
})
|
||||
return segments
|
||||
|
||||
|
||||
def fake_localai_response(segments):
|
||||
return {
|
||||
"segments": segments,
|
||||
"text": " ".join(seg["text"] for seg in segments),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with patch.object(LocalAIClient, "__init__", lambda self, **kw: None):
|
||||
c = LocalAIClient()
|
||||
c.api_url = "http://localhost:8080"
|
||||
c.model = "vibevoice-diarize"
|
||||
c.api_key = None
|
||||
c._client = MagicMock()
|
||||
return c
|
||||
|
||||
|
||||
def test_parse_diarization_response(client):
|
||||
segs = make_fake_segments()
|
||||
raw = fake_localai_response(segs)
|
||||
|
||||
out = client._parse_diarization_response(raw)
|
||||
|
||||
assert "segments" in out
|
||||
assert "speakers" in out
|
||||
assert "transcripts" in out
|
||||
assert len(out["segments"]) == len(segs)
|
||||
for i, s in enumerate(segs):
|
||||
assert out["segments"][i][0] == s["start"]
|
||||
assert out["segments"][i][1] == s["end"]
|
||||
assert out["speakers"][i] == s["speaker"]
|
||||
assert out["transcripts"][i] == s["text"]
|
||||
|
||||
|
||||
def test_parse_diarization_empty(client):
|
||||
out = client._parse_diarization_response({"segments": []})
|
||||
assert out["segments"] == []
|
||||
assert out["speakers"] == []
|
||||
assert out["transcripts"] == []
|
||||
|
||||
|
||||
def test_diarize_and_transcribe_single_happy(client):
|
||||
with patch.object(client, "_client") as mock_client:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = fake_localai_response(make_fake_segments())
|
||||
mock_client.post.return_value = mock_resp
|
||||
|
||||
result = client.diarize_and_transcribe(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
return_raw=True,
|
||||
)
|
||||
|
||||
assert "segments" in result
|
||||
assert "raw_result" in result
|
||||
assert len(result["segments"]) > 0
|
||||
|
||||
|
||||
def test_chunking_triggered_for_long_audio(client):
|
||||
# Simulate long audio by patching get_audio_duration
|
||||
with patch("scraibe.localai_client.get_audio_duration") as mock_dur, \
|
||||
patch.object(client, "_diarize_and_transcribe_chunked") as mock_chunked:
|
||||
|
||||
mock_dur.return_value = 600.0 # 10 minutes
|
||||
mock_chunked.return_value = {
|
||||
"segments": [],
|
||||
"speakers": [],
|
||||
"transcripts": [],
|
||||
}
|
||||
|
||||
client.diarize_and_transcribe(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
use_chunking=None,
|
||||
max_single_request_duration=300.0,
|
||||
)
|
||||
|
||||
mock_chunked.assert_called_once()
|
||||
|
||||
|
||||
def test_chunking_not_triggered_for_short_audio(client):
|
||||
with patch("scraibe.localai_client.get_audio_duration") as mock_dur, \
|
||||
patch.object(client, "_diarize_and_transcribe_chunked") as mock_chunked, \
|
||||
patch.object(client, "_diarize_and_transcribe_single") as mock_single:
|
||||
|
||||
mock_dur.return_value = 120.0
|
||||
mock_single.return_value = {
|
||||
"segments": [],
|
||||
"speakers": [],
|
||||
"transcripts": [],
|
||||
}
|
||||
|
||||
client.diarize_and_transcribe(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
use_chunking=None,
|
||||
max_single_request_duration=300.0,
|
||||
)
|
||||
|
||||
mock_chunked.assert_not_called()
|
||||
mock_single.assert_called_once()
|
||||
|
||||
|
||||
def test_chunked_transcription_adjusts_timestamps(client):
|
||||
# Mock split_audio_into_chunks to return two chunks
|
||||
chunk1_path = TEST_AUDIO_1
|
||||
chunk2_path = TEST_AUDIO_1 # reusing same file; in real usage different
|
||||
|
||||
chunks = [
|
||||
{"path": chunk1_path, "start": 0.0, "end": 10.0},
|
||||
{"path": chunk2_path, "start": 10.0, "end": 20.0},
|
||||
]
|
||||
|
||||
with patch("scraibe.localai_client.split_audio_into_chunks") as mock_split, \
|
||||
patch.object(client, "_diarize_and_transcribe_single") as mock_single, \
|
||||
patch("os.remove"):
|
||||
|
||||
mock_split.return_value = chunks
|
||||
|
||||
# First chunk: segments 0–4
|
||||
# Second chunk: segments 0–4 (local times)
|
||||
def side_effect(audio_path, **kw):
|
||||
if audio_path == chunk1_path:
|
||||
segs = make_fake_segments(start=0.0, count=2)
|
||||
else:
|
||||
segs = make_fake_segments(start=0.0, count=2)
|
||||
return client._parse_diarization_response(fake_localai_response(segs))
|
||||
|
||||
mock_single.side_effect = side_effect
|
||||
|
||||
result = client._diarize_and_transcribe_chunked(
|
||||
audio_path=TEST_AUDIO_1,
|
||||
verbose=False,
|
||||
return_raw=False,
|
||||
chunk_duration=10.0,
|
||||
chunk_overlap=2.0,
|
||||
)
|
||||
|
||||
# Check we got 4 segments total
|
||||
assert len(result["segments"]) == 4
|
||||
|
||||
# First two segments should be in [0, 4]
|
||||
assert result["segments"][0][0] == 0.0
|
||||
assert result["segments"][1][0] == 2.0
|
||||
|
||||
# Next two segments should be shifted by 10
|
||||
assert result["segments"][2][0] == 10.0
|
||||
assert result["segments"][3][0] == 12.0
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_integration_chunked_transcription_with_localai():
|
||||
"""
|
||||
Integration test: run chunked transcription against a live LocalAI instance.
|
||||
Only runs if LOCALAI_API_URL is set and an audio file is provided.
|
||||
This test is skipped by default unless run with:
|
||||
pytest -m integration
|
||||
"""
|
||||
api_url = os.getenv("LOCALAI_API_URL")
|
||||
if not api_url:
|
||||
pytest.skip("LOCALAI_API_URL not set; skipping integration test")
|
||||
|
||||
# Use one of the bundled test audio files
|
||||
audio_path = TEST_AUDIO_1
|
||||
if not os.path.exists(audio_path):
|
||||
pytest.skip(f"Test audio not found: {audio_path}")
|
||||
|
||||
# Force chunking with a very small max_single_request_duration
|
||||
# Use environment-configured model or a sensible default
|
||||
model = os.getenv("LOCALAI_MODEL") or "vibevoice-cpp-asr"
|
||||
|
||||
client = LocalAIClient(api_url=api_url, model=model)
|
||||
try:
|
||||
result = client.diarize_and_transcribe(
|
||||
audio_path=audio_path,
|
||||
verbose=True,
|
||||
return_raw=True,
|
||||
use_chunking=True,
|
||||
chunk_duration=3.0,
|
||||
chunk_overlap=0.5,
|
||||
max_single_request_duration=1.0,
|
||||
)
|
||||
|
||||
assert "segments" in result
|
||||
assert len(result["segments"]) > 0
|
||||
|
||||
# Basic sanity: segments are time-ordered
|
||||
for i in range(1, len(result["segments"])):
|
||||
prev_end = result["segments"][i - 1][1]
|
||||
curr_start = result["segments"][i][0]
|
||||
assert curr_start >= result["segments"][i - 1][0]
|
||||
|
||||
# If raw_result indicates chunked, ensure structure is sensible
|
||||
raw = result.get("raw_result")
|
||||
if raw and raw.get("chunked"):
|
||||
assert "chunks" in raw
|
||||
assert len(raw["chunks"]) > 1
|
||||
|
||||
finally:
|
||||
client.close()
|
||||
Reference in New Issue
Block a user