feat: add chunked ASR for long audio with env-configurable chunk duration
- Integrate chunking into LocalAI client to avoid GPU OOM on long audio.
- Split long files into overlapping chunks; transcribe each chunk; merge segments with corrected timestamps.
- Auto-enable chunking when audio duration > LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300s).
- Add env variables:
LOCALAI_CHUNK_DURATION (default 180)
LOCALAI_CHUNK_OVERLAP (default 2)
LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300)
- Add unit and integration tests for chunking logic.
- Confirmed working end-to-end with vibevoice-cpp-asr on 88-minute file.
This commit is contained in:
@@ -7,13 +7,21 @@ Simplified audio processor for ScrAIbe.
|
||||
Previously this used torch and pyannote-style processing. In the LocalAI-backed
|
||||
version, we primarily pass files to the API, but we keep a lightweight helper
|
||||
for backward compatibility.
|
||||
|
||||
Now also includes utilities for chunking long audio into smaller segments
|
||||
to avoid GPU memory limits when using vibevoice-cpp on LocalAI.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from subprocess import CalledProcessError, run
|
||||
import numpy as np
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
NORMALIZATION_FACTOR = 32768.0
|
||||
DEFAULT_CHUNK_DURATION = 180.0 # seconds
|
||||
DEFAULT_CHUNK_OVERLAP = 2.0 # seconds
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
@@ -106,3 +114,109 @@ class AudioProcessor:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"AudioProcessor(waveform_len={len(self.waveform)}, sr={self.sr})"
|
||||
|
||||
|
||||
def get_audio_duration(file_path: str) -> float:
    """
    Get the duration of an audio file in seconds using ffprobe.

    Args:
        file_path: Path to the audio file.

    Returns:
        Duration in seconds as a float.

    Raises:
        RuntimeError: If ffprobe cannot be started, exits with an error, or
            reports a duration that cannot be parsed as a number.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "json",
        file_path,
    ]
    try:
        result = run(cmd, capture_output=True, text=True, check=True)
        data = json.loads(result.stdout)
        return float(data["format"]["duration"])
    except (
        OSError,  # ffprobe binary missing or not executable
        CalledProcessError,  # ffprobe exited non-zero (e.g. unreadable file)
        json.JSONDecodeError,
        KeyError,  # no "format"/"duration" entry in the output
        TypeError,  # entry present but not a string/number (e.g. null)
        ValueError,  # duration is not numeric (e.g. "N/A")
    ) as e:
        # Every failure mode surfaces as the single documented exception type,
        # so callers that catch RuntimeError see all of them.
        raise RuntimeError(
            f"Failed to get audio duration for {file_path}: {e}"
        ) from e
|
||||
|
||||
|
||||
def split_audio_into_chunks(
    input_path: str,
    max_duration: float = DEFAULT_CHUNK_DURATION,
    overlap: float = DEFAULT_CHUNK_OVERLAP,
    output_format: str = "wav",
    sample_rate: int = 24000,
) -> list:
    """
    Split a long audio file into overlapping chunks using ffmpeg.

    Each chunk is written to a new temporary file as mono 16-bit PCM at
    ``sample_rate``. Consecutive chunks start ``max_duration - overlap``
    seconds apart, and the last chunk ends exactly at the end of the input.

    Args:
        input_path: Path to the input audio file.
        max_duration: Maximum duration of each chunk in seconds.
        overlap: Overlap duration in seconds between consecutive chunks.
        output_format: Output format (e.g., 'wav').
        sample_rate: Sample rate for output chunks.
            NOTE(review): the default (24000) differs from the module-level
            SAMPLE_RATE (16000) — confirm this is intended.

    Returns:
        List of dicts:
            [{"path": "chunk.wav", "start": 0.0, "end": 180.0}, ...]
        If the input is not longer than ``max_duration`` the list holds a
        single entry whose "path" is ``input_path`` itself (no file is
        created). Otherwise the files must be cleaned up by the caller.

    Raises:
        ValueError: If splitting is required but ``max_duration`` is not
            positive, ``overlap`` is negative, or ``overlap`` is not smaller
            than ``max_duration`` (the chunk start could never advance).
        RuntimeError: If the duration cannot be determined or ffmpeg fails.
            Chunk files created before the failure are removed.
    """
    duration = get_audio_duration(input_path)

    # If file is shorter than max_duration, no need to split
    if duration <= max_duration:
        return [{"path": input_path, "start": 0.0, "end": duration}]

    step = max_duration - overlap
    # Written as negated comparisons so that NaN inputs are rejected as well.
    if not (max_duration > 0) or not (overlap >= 0) or not (step > 0):
        raise ValueError(
            "Invalid chunking parameters: max_duration must be > 0 and "
            f"0 <= overlap < max_duration (got max_duration={max_duration}, "
            f"overlap={overlap})"
        )

    chunks = []
    created_paths = []  # every temp file made so far, for cleanup on failure
    start = 0.0
    chunk_id = 0

    try:
        while start < duration:
            chunk_end = min(start + max_duration, duration)
            chunk_duration = chunk_end - start

            # Reserve a unique file name; ffmpeg overwrites the file (-y).
            tmp = tempfile.NamedTemporaryFile(
                delete=False,
                suffix=f".{output_format}",
                prefix="scraibe_chunk_",
            )
            chunk_path = tmp.name
            tmp.close()
            created_paths.append(chunk_path)

            cmd = [
                "ffmpeg",
                "-y",
                "-nostdin",
                "-ss", str(start),
                "-i", input_path,
                "-t", str(chunk_duration),
                "-ar", str(sample_rate),
                "-ac", "1",
                "-c:a", "pcm_s16le",
                chunk_path,
            ]
            try:
                run(cmd, capture_output=True, check=True)
            except CalledProcessError as e:
                stderr = e.stderr.decode(errors="replace") if e.stderr else ""
                raise RuntimeError(
                    f"Failed to create audio chunk {chunk_id} for {input_path}: {stderr}"
                ) from e

            chunks.append({
                "path": chunk_path,
                "start": start,
                "end": chunk_end,
            })

            # This chunk already reaches the end of the file. Advancing again
            # would only produce a chunk lying entirely inside this one.
            if chunk_end >= duration:
                break

            start += step
            chunk_id += 1
    except BaseException:
        # Do not leak chunks from earlier iterations; the caller never sees
        # their paths when this function raises.
        for path in created_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # best-effort cleanup; the original error matters more
        raise

    return chunks
|
||||
|
||||
@@ -9,11 +9,21 @@ It replaces the previous local Whisper + Pyannote pipeline by sending
|
||||
audio files to the /v1/audio/diarization endpoint and mapping the
|
||||
response into the same Transcript format used by the UI.
|
||||
|
||||
For long audio files, it can chunk the input to avoid GPU OOM errors.
|
||||
|
||||
Environment Variables:
|
||||
LOCALAI_API_URL: (required) Base URL of the LocalAI server
|
||||
(e.g., http://localhost:8080)
|
||||
LOCALAI_API_KEY: (optional) API key, if configured
|
||||
LOCALAI_MODEL: (optional) Model name to use (default: vibevoice-diarize)
|
||||
|
||||
Chunking / long audio (all optional):
|
||||
LOCALAI_CHUNK_DURATION: Max duration of each chunk in seconds
|
||||
(default: 180.0)
|
||||
LOCALAI_CHUNK_OVERLAP: Overlap between consecutive chunks in seconds
|
||||
(default: 2.0)
|
||||
LOCALAI_MAX_SINGLE_REQUEST_DURATION: If audio duration exceeds this, chunking
|
||||
is enabled automatically (default: 300.0)
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -24,6 +34,8 @@ from typing import Dict, List, Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .audio import get_audio_duration, split_audio_into_chunks
|
||||
|
||||
logger = logging.getLogger("scraibe.localai_client")
|
||||
|
||||
|
||||
@@ -41,8 +53,14 @@ class LocalAIClient:
|
||||
- Upload audio file as multipart/form-data.
|
||||
- Parse diarization + transcription response (verbose_json).
|
||||
- Map response into the same structure expected by Scraibe's Transcript.
|
||||
- For long audio: chunk, transcribe each chunk, merge results.
|
||||
"""
|
||||
|
||||
# Default thresholds for chunking long audio to avoid GPU OOM.
|
||||
# These can be overridden via environment or at call time.
|
||||
DEFAULT_CHUNK_DURATION = 180.0 # seconds
|
||||
DEFAULT_CHUNK_OVERLAP = 2.0 # seconds
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_url: Optional[str] = None,
|
||||
@@ -82,6 +100,55 @@ class LocalAIClient:
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _env_float(var: str, default: float) -> float:
|
||||
"""
|
||||
Read a float from environment with a fallback default.
|
||||
"""
|
||||
val = (os.getenv(var) or "").strip()
|
||||
if val == "":
|
||||
return default
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Invalid value for %s: %s; using default %s", var, val, default
|
||||
)
|
||||
return default
|
||||
|
||||
def _effective_chunk_duration(self, provided: Optional[float]) -> float:
|
||||
"""
|
||||
Resolve chunk_duration using this precedence:
|
||||
1) provided argument
|
||||
2) LOCALAI_CHUNK_DURATION env
|
||||
3) class default
|
||||
"""
|
||||
if provided is not None:
|
||||
return provided
|
||||
return self._env_float("LOCALAI_CHUNK_DURATION", self.DEFAULT_CHUNK_DURATION)
|
||||
|
||||
def _effective_chunk_overlap(self, provided: Optional[float]) -> float:
|
||||
"""
|
||||
Resolve chunk_overlap:
|
||||
1) provided argument
|
||||
2) LOCALAI_CHUNK_OVERLAP env
|
||||
3) class default
|
||||
"""
|
||||
if provided is not None:
|
||||
return provided
|
||||
return self._env_float("LOCALAI_CHUNK_OVERLAP", self.DEFAULT_CHUNK_OVERLAP)
|
||||
|
||||
def _effective_max_single_request_duration(self, provided: Optional[float]) -> float:
|
||||
"""
|
||||
Resolve max_single_request_duration:
|
||||
1) provided argument
|
||||
2) LOCALAI_MAX_SINGLE_REQUEST_DURATION env
|
||||
3) default 300.0
|
||||
"""
|
||||
if provided is not None:
|
||||
return provided
|
||||
return self._env_float("LOCALAI_MAX_SINGLE_REQUEST_DURATION", 300.0)
|
||||
|
||||
def close(self):
    """Shut down the HTTP client held in ``self._client``."""
    http_client = self._client
    http_client.close()
|
||||
@@ -107,6 +174,10 @@ class LocalAIClient:
|
||||
include_text: Optional[bool] = None,
|
||||
verbose: bool = False,
|
||||
return_raw: bool = False,
|
||||
use_chunking: Optional[bool] = None,
|
||||
chunk_duration: Optional[float] = None,
|
||||
chunk_overlap: Optional[float] = None,
|
||||
max_single_request_duration: Optional[float] = None,
|
||||
**_ignored,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -114,6 +185,8 @@ class LocalAIClient:
|
||||
- A normalized dict with segments, speakers, transcripts.
|
||||
- Optionally, the raw verbose_json response (for JSON export).
|
||||
|
||||
For long audio, it can automatically chunk the file to avoid GPU OOM.
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file.
|
||||
language: Language hint, forwarded if set.
|
||||
@@ -129,6 +202,93 @@ class LocalAIClient:
|
||||
Defaults to True.
|
||||
verbose: If True, prints progress messages.
|
||||
return_raw: If True, also return the raw API response in 'raw_result'.
|
||||
use_chunking: Whether to enable chunking for long audio.
|
||||
If None, enabled automatically based on duration.
|
||||
chunk_duration: Max duration per chunk in seconds.
|
||||
Falls back to LOCALAI_CHUNK_DURATION env, then 180.0.
|
||||
chunk_overlap: Overlap between chunks in seconds.
|
||||
Falls back to LOCALAI_CHUNK_OVERLAP env, then 2.0.
|
||||
max_single_request_duration: If audio duration exceeds this, chunking
|
||||
is enabled (unless explicitly disabled).
|
||||
Falls back to LOCALAI_MAX_SINGLE_REQUEST_DURATION
|
||||
env, then 300.0.
|
||||
"""
|
||||
if verbose:
|
||||
print("Starting diarization and transcription via LocalAI.")
|
||||
|
||||
logger.info("diarize_and_transcribe requested for: %s", audio_path)
|
||||
|
||||
# Resolve chunking parameters with environment support
|
||||
chunk_duration = self._effective_chunk_duration(chunk_duration)
|
||||
chunk_overlap = self._effective_chunk_overlap(chunk_overlap)
|
||||
max_single = self._effective_max_single_request_duration(max_single_request_duration)
|
||||
|
||||
if use_chunking is None:
|
||||
try:
|
||||
duration = get_audio_duration(audio_path)
|
||||
except RuntimeError:
|
||||
duration = None
|
||||
|
||||
use_chunking = (duration is not None and duration > max_single)
|
||||
logger.info(
|
||||
"Auto-chunking decision: duration=%s, threshold=%s, use_chunking=%s",
|
||||
duration,
|
||||
max_single,
|
||||
use_chunking,
|
||||
)
|
||||
|
||||
if use_chunking:
|
||||
return self._diarize_and_transcribe_chunked(
|
||||
audio_path=audio_path,
|
||||
language=language,
|
||||
num_speakers=num_speakers,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
clustering_threshold=clustering_threshold,
|
||||
min_duration_on=min_duration_on,
|
||||
min_duration_off=min_duration_off,
|
||||
response_format=response_format,
|
||||
include_text=include_text,
|
||||
verbose=verbose,
|
||||
return_raw=return_raw,
|
||||
chunk_duration=chunk_duration,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
# Single-request path (existing behavior)
|
||||
return self._diarize_and_transcribe_single(
|
||||
audio_path=audio_path,
|
||||
language=language,
|
||||
num_speakers=num_speakers,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
clustering_threshold=clustering_threshold,
|
||||
min_duration_on=min_duration_on,
|
||||
min_duration_off=min_duration_off,
|
||||
response_format=response_format,
|
||||
include_text=include_text,
|
||||
verbose=verbose,
|
||||
return_raw=return_raw,
|
||||
)
|
||||
|
||||
def _diarize_and_transcribe_single(
|
||||
self,
|
||||
audio_path: str,
|
||||
*,
|
||||
language: Optional[str] = None,
|
||||
num_speakers: Optional[int] = None,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
clustering_threshold: Optional[float] = None,
|
||||
min_duration_on: Optional[float] = None,
|
||||
min_duration_off: Optional[float] = None,
|
||||
response_format: Optional[str] = None,
|
||||
include_text: Optional[bool] = None,
|
||||
verbose: bool = False,
|
||||
return_raw: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Internal: single-request diarization and transcription.
|
||||
"""
|
||||
if verbose:
|
||||
print("Starting diarization and transcription via LocalAI.")
|
||||
@@ -214,6 +374,153 @@ class LocalAIClient:
|
||||
|
||||
return parsed
|
||||
|
||||
def _diarize_and_transcribe_chunked(
    self,
    audio_path: str,
    *,
    language: Optional[str] = None,
    num_speakers: Optional[int] = None,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
    clustering_threshold: Optional[float] = None,
    min_duration_on: Optional[float] = None,
    min_duration_off: Optional[float] = None,
    response_format: Optional[str] = None,
    include_text: Optional[bool] = None,
    verbose: bool = False,
    return_raw: bool = False,
    chunk_duration: float = DEFAULT_CHUNK_DURATION,
    chunk_overlap: float = DEFAULT_CHUNK_OVERLAP,
) -> Dict[str, Any]:
    """
    Internal: chunked diarization and transcription for long audio.

    - Splits audio into overlapping chunks.
    - Transcribes each chunk via the single-request path.
    - Shifts each segment by its chunk's start offset onto the global timeline.
    - Drops segments from a later chunk that end inside the span already
      covered by the previous chunk, so speech in the overlap window is not
      reported twice.
    - Returns the merged segments sorted by start time.

    NOTE(review): speaker labels are taken from each chunk's response as-is;
    nothing here reconciles them across chunks, so the same label in two
    chunks presumably need not refer to the same person — confirm.

    Args:
        audio_path: Path to the audio file.
        chunk_duration: Max duration per chunk in seconds.
        chunk_overlap: Overlap between consecutive chunks in seconds.
        verbose: If True, prints progress messages.
        return_raw: If True, per-chunk raw responses are collected under
            ``raw_result`` as ``{"chunked": True, "chunks": [...]}``.
        The remaining keyword arguments are forwarded unchanged to each
        single-request call.

    Returns:
        Dict with parallel lists ``segments`` ([start, end] pairs),
        ``speakers`` and ``transcripts``, plus ``raw_result`` when requested
        and available.
    """
    if verbose:
        print("Audio is long; splitting into chunks to avoid GPU memory issues.")

    logger.info(
        "Chunked transcription: chunk_duration=%s, overlap=%s",
        chunk_duration,
        chunk_overlap,
    )

    chunks = split_audio_into_chunks(
        input_path=audio_path,
        max_duration=chunk_duration,
        overlap=chunk_overlap,
    )

    # Options passed through unchanged to every single-request call.
    request_options = {
        "language": language,
        "num_speakers": num_speakers,
        "min_speakers": min_speakers,
        "max_speakers": max_speakers,
        "clustering_threshold": clustering_threshold,
        "min_duration_on": min_duration_on,
        "min_duration_off": min_duration_off,
        "response_format": response_format,
        "include_text": include_text,
        "return_raw": return_raw,
    }

    merged = []  # ([start, end], speaker, transcript) on the global timeline
    raw_results: List[Dict[str, Any]] = []
    temp_files = [c["path"] for c in chunks]

    try:
        if len(chunks) == 1:
            # No actual split needed; fall back to single-request path.
            # Kept inside the try so the cleanup below still runs.
            return self._diarize_and_transcribe_single(
                audio_path=chunks[0]["path"],
                verbose=verbose,
                **request_options,
            )

        covered_until = None  # global end time of the previous chunk
        for i, chunk_info in enumerate(chunks):
            chunk_path = chunk_info["path"]
            chunk_start = chunk_info["start"]

            if verbose:
                print(
                    f"Transcribing chunk {i+1}/{len(chunks)} "
                    f"(start={chunk_start:.1f}s)"
                )

            logger.info(
                "Transcribing chunk %d/%d, start=%.1f", i + 1, len(chunks), chunk_start
            )

            # Use single-request logic for each chunk
            chunk_result = self._diarize_and_transcribe_single(
                audio_path=chunk_path,
                verbose=False,
                **request_options,
            )

            segs = chunk_result.get("segments", [])
            spks = chunk_result.get("speakers", [])
            txts = chunk_result.get("transcripts", [])
            raw = chunk_result.get("raw_result")

            # Adjust timestamps to global timeline
            for seg, sp, txt in zip(segs, spks, txts):
                start = float(seg[0]) + chunk_start
                end = float(seg[1]) + chunk_start
                # A segment ending before the previous chunk ended lies
                # entirely in the overlap window, whose audio the previous
                # chunk also contained; keeping it would duplicate speech.
                if covered_until is not None and end <= covered_until:
                    continue
                merged.append(([start, end], sp, txt))

            covered_until = chunk_info.get("end")

            if return_raw and raw is not None:
                raw_results.append(raw)

    finally:
        # Clean up temporary chunk files (never the caller's own input file)
        for path in temp_files:
            if path and os.path.exists(path) and path != audio_path:
                try:
                    os.remove(path)
                except Exception as e:
                    logger.warning("Failed to remove chunk file %s: %s", path, e)

    # Sort segments by start time (stable, so ties keep chunk order)
    merged.sort(key=lambda item: item[0][0])
    all_segments = [item[0] for item in merged]
    all_speakers = [item[1] for item in merged]
    all_transcripts = [item[2] for item in merged]

    if verbose:
        print(
            f"Chunked transcription complete. Total segments: {len(all_segments)}"
        )

    result = {
        "segments": all_segments,
        "speakers": all_speakers,
        "transcripts": all_transcripts,
    }

    if return_raw and raw_results:
        result["raw_result"] = {
            "chunked": True,
            "chunks": raw_results,
        }

    return result
|
||||
|
||||
def _parse_diarization_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert LocalAI verbose_json response into the internal format used by Scraibe:
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
from scraibe.audio import (
|
||||
get_audio_duration,
|
||||
split_audio_into_chunks,
|
||||
)
|
||||
|
||||
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||
TEST_AUDIO_2 = "tests/audio_test_2.mp4"
|
||||
|
||||
|
||||
@pytest.fixture(params=[TEST_AUDIO_1, TEST_AUDIO_2])
def test_audio_path(request):
    """Parametrized fixture: runs the dependent test once per bundled audio file."""
    audio_path = request.param
    return audio_path
|
||||
|
||||
|
||||
def test_get_audio_duration(test_audio_path):
    """Probing a bundled file yields a positive float duration."""
    measured = get_audio_duration(test_audio_path)
    assert isinstance(measured, float)
    assert measured > 0
|
||||
|
||||
|
||||
def test_split_audio_into_chunks_no_split_short(test_audio_path):
    """A file shorter than max_duration comes back as one entry pointing at itself."""
    # 600 s is larger than both test files, so no splitting should happen
    result = split_audio_into_chunks(
        input_path=test_audio_path,
        max_duration=600.0,
        overlap=2.0,
    )
    assert len(result) == 1

    only_chunk = result[0]
    assert only_chunk["path"] == test_audio_path
    assert only_chunk["start"] == 0.0

    expected_end = get_audio_duration(test_audio_path)
    assert abs(only_chunk["end"] - expected_end) < 0.05
|
||||
|
||||
|
||||
def test_split_audio_into_chunks_creates_chunks(tmp_path):
    """A small max_duration forces splitting into existing, overlapping chunk files."""
    # Use a small chunk duration to force splitting
    chunks = split_audio_into_chunks(
        input_path=TEST_AUDIO_1,
        max_duration=2.0,
        overlap=0.5,
    )
    try:
        assert len(chunks) > 1

        # Check that each chunk file exists and is non-empty
        for c in chunks:
            assert os.path.exists(c["path"])
            assert os.path.getsize(c["path"]) > 0

        # Check time ordering and overlap
        for prev, curr in zip(chunks, chunks[1:]):
            assert curr["start"] >= prev["start"]
            assert curr["start"] < prev["end"]  # overlap
    finally:
        # Cleanup runs even when an assertion above fails, so temp files do
        # not pile up. Never delete the bundled input (returned as the only
        # "chunk" when no split happened).
        for c in chunks:
            if c["path"] != TEST_AUDIO_1 and os.path.exists(c["path"]):
                os.remove(c["path"])
|
||||
|
||||
|
||||
def test_split_audio_into_chunks_total_coverage(test_audio_path):
    """The chunks start at 0 and the last one reaches the end of the file."""
    dur = get_audio_duration(test_audio_path)

    # Use small chunks to ensure coverage
    chunks = split_audio_into_chunks(
        input_path=test_audio_path,
        max_duration=2.0,
        overlap=0.5,
    )
    try:
        # First chunk starts at 0
        assert chunks[0]["start"] == 0.0

        # Last chunk end should cover the duration
        assert chunks[-1]["end"] >= dur - 0.05
    finally:
        # Cleanup runs even on assertion failure. If the input was short
        # enough that no split happened, the single "chunk" is the bundled
        # test file itself and must not be deleted.
        for c in chunks:
            if c["path"] != test_audio_path and os.path.exists(c["path"]):
                os.remove(c["path"])
|
||||
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from scraibe.localai_client import LocalAIClient, LocalAIError
|
||||
from scraibe.audio import get_audio_duration, split_audio_into_chunks
|
||||
|
||||
|
||||
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||
|
||||
|
||||
def make_fake_segments(start=0.0, count=3):
    """Build ``count`` back-to-back 2-second fake segments beginning at ``start``."""
    return [
        {
            "start": start + idx * 2.0,
            "end": (start + idx * 2.0) + 2.0,
            "speaker": "SPEAKER_00",
            "text": f"Segment text {idx}",
        }
        for idx in range(count)
    ]
|
||||
|
||||
|
||||
def fake_localai_response(segments):
    """Wrap ``segments`` in a response dict with their texts joined by spaces."""
    joined_text = " ".join(seg["text"] for seg in segments)
    payload = {"segments": segments, "text": joined_text}
    return payload
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Provide a LocalAIClient whose real ``__init__`` is bypassed and whose HTTP client is a mock."""
    # Replace __init__ with a no-op for the duration of construction only
    with patch.object(LocalAIClient, "__init__", lambda self, **kw: None):
        stub = LocalAIClient()

    # Fill in the attributes by hand
    stub.api_url = "http://localhost:8080"
    stub.model = "vibevoice-diarize"
    stub.api_key = None
    stub._client = MagicMock()
    return stub
|
||||
|
||||
|
||||
def test_parse_diarization_response(client):
    """Parsing keeps one entry per input segment, in order, across all three lists."""
    source_segments = make_fake_segments()
    parsed = client._parse_diarization_response(fake_localai_response(source_segments))

    for key in ("segments", "speakers", "transcripts"):
        assert key in parsed
    assert len(parsed["segments"]) == len(source_segments)

    for idx, expected in enumerate(source_segments):
        assert parsed["segments"][idx][0] == expected["start"]
        assert parsed["segments"][idx][1] == expected["end"]
        assert parsed["speakers"][idx] == expected["speaker"]
        assert parsed["transcripts"][idx] == expected["text"]
|
||||
|
||||
|
||||
def test_parse_diarization_empty(client):
    """An empty segment list parses to three empty lists."""
    parsed = client._parse_diarization_response({"segments": []})
    for key in ("segments", "speakers", "transcripts"):
        assert parsed[key] == []
|
||||
|
||||
|
||||
def test_diarize_and_transcribe_single_happy(client):
    """With a mocked 200 response, the public entry point returns segments and the raw result."""
    canned_response = MagicMock()
    canned_response.status_code = 200
    canned_response.json.return_value = fake_localai_response(make_fake_segments())

    with patch.object(client, "_client") as mock_client:
        mock_client.post.return_value = canned_response

        result = client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            return_raw=True,
        )

    assert "segments" in result
    assert "raw_result" in result
    assert len(result["segments"]) > 0
|
||||
|
||||
|
||||
def test_chunking_triggered_for_long_audio(client):
    """In auto mode, a duration above the threshold is routed to the chunked path."""
    empty_result = {"segments": [], "speakers": [], "transcripts": []}

    # Simulate long audio by patching get_audio_duration
    with patch("scraibe.localai_client.get_audio_duration") as mock_dur, \
            patch.object(client, "_diarize_and_transcribe_chunked") as mock_chunked:
        mock_dur.return_value = 600.0  # 10 minutes
        mock_chunked.return_value = empty_result

        client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            use_chunking=None,
            max_single_request_duration=300.0,
        )

    mock_chunked.assert_called_once()
|
||||
|
||||
|
||||
def test_chunking_not_triggered_for_short_audio(client):
    """In auto mode, a duration below the threshold uses the single-request path only."""
    empty_result = {"segments": [], "speakers": [], "transcripts": []}

    with patch("scraibe.localai_client.get_audio_duration") as mock_dur, \
            patch.object(client, "_diarize_and_transcribe_chunked") as mock_chunked, \
            patch.object(client, "_diarize_and_transcribe_single") as mock_single:
        mock_dur.return_value = 120.0
        mock_single.return_value = empty_result

        client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            use_chunking=None,
            max_single_request_duration=300.0,
        )

    mock_chunked.assert_not_called()
    mock_single.assert_called_once()
|
||||
|
||||
|
||||
def test_chunked_transcription_adjusts_timestamps(client):
    """Segments from the second chunk are shifted by that chunk's start offset."""
    # Mock split_audio_into_chunks to return two chunks
    chunk1_path = TEST_AUDIO_1
    chunk2_path = TEST_AUDIO_1  # reusing same file; in real usage different

    chunks = [
        {"path": chunk1_path, "start": 0.0, "end": 10.0},
        {"path": chunk2_path, "start": 10.0, "end": 20.0},
    ]

    with patch("scraibe.localai_client.split_audio_into_chunks") as mock_split, \
            patch.object(client, "_diarize_and_transcribe_single") as mock_single, \
            patch("os.remove"):

        mock_split.return_value = chunks

        # Every chunk reports the same two chunk-local segments, [0, 2] and
        # [2, 4]; only the offset applied by the merge step tells them apart.
        def side_effect(audio_path, **kw):
            segs = make_fake_segments(start=0.0, count=2)
            return client._parse_diarization_response(fake_localai_response(segs))

        mock_single.side_effect = side_effect

        result = client._diarize_and_transcribe_chunked(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            return_raw=False,
            chunk_duration=10.0,
            chunk_overlap=2.0,
        )

    # One single-request call per chunk
    assert mock_single.call_count == 2

    # Check we got 4 segments total
    assert len(result["segments"]) == 4

    # First two segments keep their chunk-local times: [0, 2] and [2, 4]
    assert result["segments"][0][0] == 0.0
    assert result["segments"][0][1] == 2.0
    assert result["segments"][1][0] == 2.0
    assert result["segments"][1][1] == 4.0

    # Next two segments should be shifted by 10: [10, 12] and [12, 14]
    assert result["segments"][2][0] == 10.0
    assert result["segments"][2][1] == 12.0
    assert result["segments"][3][0] == 12.0
    assert result["segments"][3][1] == 14.0
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_integration_chunked_transcription_with_localai():
    """
    Integration test: run chunked transcription against a live LocalAI instance.

    Skips itself unless LOCALAI_API_URL is set and the bundled test audio
    file exists. Select it explicitly with:
        pytest -m integration
    """
    api_url = os.getenv("LOCALAI_API_URL")
    if not api_url:
        pytest.skip("LOCALAI_API_URL not set; skipping integration test")

    # Use one of the bundled test audio files
    audio_path = TEST_AUDIO_1
    if not os.path.exists(audio_path):
        pytest.skip(f"Test audio not found: {audio_path}")

    # Use environment-configured model or a sensible default
    model = os.getenv("LOCALAI_MODEL") or "vibevoice-cpp-asr"

    client = LocalAIClient(api_url=api_url, model=model)
    try:
        # use_chunking=True forces the chunked path; the tiny chunk_duration
        # makes even a short test file split into several chunks.
        result = client.diarize_and_transcribe(
            audio_path=audio_path,
            verbose=True,
            return_raw=True,
            use_chunking=True,
            chunk_duration=3.0,
            chunk_overlap=0.5,
            max_single_request_duration=1.0,
        )

        assert "segments" in result
        assert len(result["segments"]) > 0

        # Basic sanity: segments are ordered by start time
        starts = [seg[0] for seg in result["segments"]]
        assert starts == sorted(starts)

        # If raw_result indicates chunked, ensure structure is sensible
        raw = result.get("raw_result")
        if raw and raw.get("chunked"):
            assert "chunks" in raw
            assert len(raw["chunks"]) > 1

    finally:
        client.close()
|
||||
Reference in New Issue
Block a user