Files
scribe/tests/test_localai_chunking.py
T
admin 6640bc050d
Mirror and run GitLab CI / build (push) Waiting to run
Ruff / ruff (push) Waiting to run
feat: add chunked ASR for long audio with env-configurable chunk duration
- Integrate chunking into LocalAI client to avoid GPU OOM on long audio.
- Split long files into overlapping chunks; transcribe each chunk; merge segments with corrected timestamps.
- Auto-enable chunking when audio duration > LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300s).
- Add env variables:
    LOCALAI_CHUNK_DURATION (default 180)
    LOCALAI_CHUNK_OVERLAP (default 2)
    LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300)
- Add unit and integration tests for chunking logic.
- Confirmed working end-to-end with vibevoice-cpp-asr on 88-minute file.
2026-06-18 17:46:29 +00:00

231 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import tempfile
from unittest.mock import patch, MagicMock
import pytest
from scraibe.localai_client import LocalAIClient, LocalAIError
from scraibe.audio import get_audio_duration, split_audio_into_chunks
# Sample audio file passed as ``audio_path`` throughout these tests.
# The path is relative, so it resolves against the directory pytest is run from.
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
def make_fake_segments(start=0.0, count=3, duration=2.0, speaker="SPEAKER_00"):
    """Build ``count`` back-to-back fake diarization segments.

    Args:
        start: Start time of the first segment.
        count: Number of segments to generate; ``0`` yields an empty list.
        duration: Length of every segment. Segment ``i`` starts at
            ``start + i * duration``.
        speaker: Speaker label written into every segment.

    Returns:
        A list of dicts with the keys ``"start"``, ``"end"``, ``"speaker"``
        and ``"text"`` (``"Segment text <i>"``).
    """
    segments = []
    for index in range(count):
        seg_start = start + index * duration
        segments.append({
            "start": seg_start,
            "end": seg_start + duration,
            "speaker": speaker,
            "text": f"Segment text {index}",
        })
    return segments
def fake_localai_response(segments):
    """Wrap segment dicts in a response-shaped payload for these tests.

    Args:
        segments: Iterable of segment dicts, each with at least a ``"text"``
            key. One-shot iterables (e.g. generators) are accepted.

    Returns:
        A dict with ``"segments"`` (a list of the given segments) and
        ``"text"`` (all segment texts joined by single spaces).
    """
    # Materialise once: the input is read for both fields, and a generator
    # would otherwise be left exhausted inside the returned payload.
    segments = list(segments)
    return {
        "segments": segments,
        "text": " ".join(seg["text"] for seg in segments),
    }
@pytest.fixture
def client():
    """Provide a ``LocalAIClient`` built without running its real ``__init__``.

    The constructor is swapped for a no-op while the instance is created;
    the attributes below are then set by hand, with ``_client`` replaced by
    a ``MagicMock``.
    """
    def _skip_init(self, **kw):
        return None

    with patch.object(LocalAIClient, "__init__", _skip_init):
        stub = LocalAIClient()

    stub.api_url = "http://localhost:8080"
    stub.model = "vibevoice-diarize"
    stub.api_key = None
    stub._client = MagicMock()
    return stub
def test_parse_diarization_response(client):
    """Parsed output exposes parallel segments/speakers/transcripts lists.

    Each parsed segment is checked to carry the source start and end as its
    first two items, and the speaker and text must line up by index.
    """
    expected = make_fake_segments()
    parsed = client._parse_diarization_response(fake_localai_response(expected))

    for key in ("segments", "speakers", "transcripts"):
        assert key in parsed
    assert len(parsed["segments"]) == len(expected)

    for idx, seg in enumerate(expected):
        span = parsed["segments"][idx]
        assert span[0] == seg["start"]
        assert span[1] == seg["end"]
        assert parsed["speakers"][idx] == seg["speaker"]
        assert parsed["transcripts"][idx] == seg["text"]
def test_parse_diarization_empty(client):
    """A response with no segments parses to three empty lists."""
    parsed = client._parse_diarization_response({"segments": []})
    for key in ("segments", "speakers", "transcripts"):
        assert parsed[key] == []
def test_diarize_and_transcribe_single_happy(client):
    """A 200 response from the mocked HTTP client yields a non-empty result.

    With ``return_raw=True`` the result is expected to contain both
    ``"segments"`` and ``"raw_result"``.
    """
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = fake_localai_response(make_fake_segments())

    with patch.object(client, "_client") as http:
        http.post.return_value = response
        result = client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            return_raw=True,
        )

    for key in ("segments", "raw_result"):
        assert key in result
    assert len(result["segments"]) > 0
def test_chunking_triggered_for_long_audio(client):
    """Audio longer than the single-request limit goes to the chunked path.

    The duration lookup is patched to report 600 s (10 minutes) against a
    300 s limit, with ``use_chunking=None`` so the client decides by itself.
    """
    empty_result = {
        "segments": [],
        "speakers": [],
        "transcripts": [],
    }
    duration_patch = patch(
        "scraibe.localai_client.get_audio_duration", return_value=600.0
    )
    chunked_patch = patch.object(
        client, "_diarize_and_transcribe_chunked", return_value=empty_result
    )

    with duration_patch, chunked_patch as chunked:
        client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            use_chunking=None,
            max_single_request_duration=300.0,
        )

    chunked.assert_called_once()
def test_chunking_not_triggered_for_short_audio(client):
    """Audio shorter than the single-request limit uses the single path.

    The duration lookup is patched to report 120 s against a 300 s limit,
    with ``use_chunking=None`` so the client decides by itself.
    """
    empty_result = {
        "segments": [],
        "speakers": [],
        "transcripts": [],
    }
    duration_patch = patch(
        "scraibe.localai_client.get_audio_duration", return_value=120.0
    )
    chunked_patch = patch.object(client, "_diarize_and_transcribe_chunked")
    single_patch = patch.object(
        client, "_diarize_and_transcribe_single", return_value=empty_result
    )

    with duration_patch, chunked_patch as chunked, single_patch as single:
        client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            use_chunking=None,
            max_single_request_duration=300.0,
        )

    chunked.assert_not_called()
    single.assert_called_once()
def test_chunked_transcription_adjusts_timestamps(client):
    """Chunk-local timestamps are shifted by each chunk's start offset.

    The splitter is mocked to return two chunks starting at 0 s and 10 s.
    Each chunk "transcribes" to the same two chunk-local segments
    (0-2 s and 2-4 s), so the merged result must contain four segments
    starting at 0, 2, 10 and 12 s.
    """
    # Both chunks reuse the one sample file; in real usage each chunk would
    # be a separate file.
    chunks = [
        {"path": TEST_AUDIO_1, "start": 0.0, "end": 10.0},
        {"path": TEST_AUDIO_1, "start": 10.0, "end": 20.0},
    ]

    def transcribe_chunk(audio_path, **kw):
        # Every chunk reports identical chunk-local times, so any offset in
        # the merged output must come from the chunked code path itself.
        segs = make_fake_segments(start=0.0, count=2)
        return client._parse_diarization_response(fake_localai_response(segs))

    # os.remove is patched so nothing is deleted from disk during the test;
    # the chunk paths point at the sample file (presumably the chunked path
    # cleans up chunk files afterwards - confirm against the implementation).
    with patch("scraibe.localai_client.split_audio_into_chunks") as mock_split, \
            patch.object(client, "_diarize_and_transcribe_single") as mock_single, \
            patch("os.remove"):
        mock_split.return_value = chunks
        mock_single.side_effect = transcribe_chunk
        result = client._diarize_and_transcribe_chunked(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            return_raw=False,
            chunk_duration=10.0,
            chunk_overlap=2.0,
        )

    # Two segments per chunk, four in total.
    assert len(result["segments"]) == 4
    # The first chunk starts at 0 s, so its segments keep their local times.
    assert result["segments"][0][0] == 0.0
    assert result["segments"][1][0] == 2.0
    # The second chunk starts at 10 s, so its segments are shifted by 10.
    assert result["segments"][2][0] == 10.0
    assert result["segments"][3][0] == 12.0
@pytest.mark.integration
def test_integration_chunked_transcription_with_localai():
    """Run chunked transcription against a live LocalAI instance.

    Skips itself when ``LOCALAI_API_URL`` is unset or the sample audio file
    is missing. The model name comes from ``LOCALAI_MODEL`` and falls back
    to ``"vibevoice-cpp-asr"``.

    Carries the ``integration`` marker, so it can be selected with::

        pytest -m integration
    """
    api_url = os.getenv("LOCALAI_API_URL")
    if not api_url:
        pytest.skip("LOCALAI_API_URL not set; skipping integration test")

    # Use one of the bundled test audio files.
    audio_path = TEST_AUDIO_1
    if not os.path.exists(audio_path):
        pytest.skip(f"Test audio not found: {audio_path}")

    # Use the environment-configured model or a sensible default.
    model = os.getenv("LOCALAI_MODEL") or "vibevoice-cpp-asr"
    client = LocalAIClient(api_url=api_url, model=model)
    try:
        # Force chunking: request it explicitly and use a very small
        # max_single_request_duration with short chunks.
        result = client.diarize_and_transcribe(
            audio_path=audio_path,
            verbose=True,
            return_raw=True,
            use_chunking=True,
            chunk_duration=3.0,
            chunk_overlap=0.5,
            max_single_request_duration=1.0,
        )
        assert "segments" in result
        assert len(result["segments"]) > 0

        # Basic sanity: segment start times never go backwards.
        segments = result["segments"]
        for i in range(1, len(segments)):
            assert segments[i][0] >= segments[i - 1][0]

        # If raw_result indicates chunking, ensure its structure is sensible.
        raw = result.get("raw_result")
        if raw and raw.get("chunked"):
            assert "chunks" in raw
            assert len(raw["chunks"]) > 1
    finally:
        # Always release the client, even when an assertion fails.
        client.close()