feat: add chunked ASR for long audio with env-configurable chunk duration
- Integrate chunking into LocalAI client to avoid GPU OOM on long audio.
- Split long files into overlapping chunks; transcribe each chunk; merge segments with corrected timestamps.
- Auto-enable chunking when audio duration > LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300s).
- Add env variables:
LOCALAI_CHUNK_DURATION (default 180)
LOCALAI_CHUNK_OVERLAP (default 2)
LOCALAI_MAX_SINGLE_REQUEST_DURATION (default 300)
- Add unit and integration tests for chunking logic.
- Confirmed working end-to-end with vibevoice-cpp-asr on 88-minute file.
This commit is contained in:
@@ -7,13 +7,21 @@ Simplified audio processor for ScrAIbe.
|
|||||||
Previously this used torch and pyannote-style processing. In the LocalAI-backed
|
Previously this used torch and pyannote-style processing. In the LocalAI-backed
|
||||||
version, we primarily pass files to the API, but we keep a lightweight helper
|
version, we primarily pass files to the API, but we keep a lightweight helper
|
||||||
for backward compatibility.
|
for backward compatibility.
|
||||||
|
|
||||||
|
Now also includes utilities for chunking long audio into smaller segments
|
||||||
|
to avoid GPU memory limits when using vibevoice-cpp on LocalAI.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
from subprocess import CalledProcessError, run
|
from subprocess import CalledProcessError, run
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
NORMALIZATION_FACTOR = 32768.0
|
NORMALIZATION_FACTOR = 32768.0
|
||||||
|
DEFAULT_CHUNK_DURATION = 180.0 # seconds
|
||||||
|
DEFAULT_CHUNK_OVERLAP = 2.0 # seconds
|
||||||
|
|
||||||
|
|
||||||
class AudioProcessor:
|
class AudioProcessor:
|
||||||
@@ -106,3 +114,109 @@ class AudioProcessor:
|
|||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"AudioProcessor(waveform_len={len(self.waveform)}, sr={self.sr})"
|
return f"AudioProcessor(waveform_len={len(self.waveform)}, sr={self.sr})"
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_duration(file_path: str) -> float:
    """
    Get the duration of an audio file in seconds using ffprobe.

    Args:
        file_path: Path to the audio file.

    Returns:
        Duration in seconds as a float.

    Raises:
        RuntimeError: If ffprobe cannot be started, exits with an error, or
            its output does not contain a parseable duration.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "json",
        file_path,
    ]
    try:
        result = run(cmd, capture_output=True, text=True, check=True)
        data = json.loads(result.stdout)
        return float(data["format"]["duration"])
    except (
        CalledProcessError,    # ffprobe exited non-zero
        OSError,               # ffprobe binary missing / not executable
        json.JSONDecodeError,  # output is not valid JSON
        KeyError,              # JSON has no format.duration entry
        TypeError,             # JSON has an unexpected structure
        ValueError,            # duration is not numeric (e.g. "N/A")
    ) as e:
        # Callers only handle RuntimeError, so every failure mode is funnelled
        # into it; the original exception stays attached as the cause.
        raise RuntimeError(f"Failed to get audio duration for {file_path}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def split_audio_into_chunks(
    input_path: str,
    max_duration: float = DEFAULT_CHUNK_DURATION,
    overlap: float = DEFAULT_CHUNK_OVERLAP,
    output_format: str = "wav",
    sample_rate: int = 24000,
) -> list:
    """
    Split a long audio file into overlapping chunks using ffmpeg.

    Args:
        input_path: Path to the input audio file.
        max_duration: Maximum duration of each chunk in seconds.
        overlap: Overlap duration in seconds between consecutive chunks.
        output_format: Output format (e.g., 'wav').
        sample_rate: Sample rate for output chunks.

    Returns:
        List of dicts:
            [{"path": "chunk.wav", "start": 0.0, "end": 180.0}, ...]
        If the file is not longer than max_duration, the list holds a single
        entry whose "path" is input_path itself (no temporary file is made).
        Otherwise every "path" is a temporary file that must be cleaned up by
        the caller.

    Raises:
        ValueError: If a split is required but max_duration is not positive,
            or overlap is negative or not smaller than max_duration.
        RuntimeError: If the duration cannot be determined or ffmpeg fails.
            Chunk files created before the failure are removed.
    """
    duration = get_audio_duration(input_path)

    # If file is shorter than max_duration, no need to split
    if duration <= max_duration:
        return [{"path": input_path, "start": 0.0, "end": duration}]

    # Validate only once a split is really needed, so short files keep working
    # with whatever parameters the caller passed.
    if max_duration <= 0:
        raise ValueError(f"max_duration must be positive, got {max_duration}")
    if overlap < 0 or overlap >= max_duration:
        # overlap >= max_duration would never advance the window (endless loop);
        # a negative overlap would leave untranscribed gaps between chunks.
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_duration, got {overlap}"
        )

    step = max_duration - overlap
    chunks = []
    created_paths = []  # every temp file made so far, for cleanup on failure
    start = 0.0
    chunk_id = 0

    try:
        while start < duration:
            chunk_end = min(start + max_duration, duration)
            chunk_duration = chunk_end - start

            # Reserve a unique file name; ffmpeg overwrites the empty file (-y).
            tmp = tempfile.NamedTemporaryFile(
                delete=False,
                suffix=f".{output_format}",
                prefix="scraibe_chunk_",
            )
            chunk_path = tmp.name
            tmp.close()
            created_paths.append(chunk_path)

            cmd = [
                "ffmpeg",
                "-y",
                "-nostdin",
                "-ss", str(start),
                "-i", input_path,
                "-t", str(chunk_duration),
                "-ar", str(sample_rate),
                "-ac", "1",
                "-c:a", "pcm_s16le",
                chunk_path,
            ]
            try:
                run(cmd, capture_output=True, check=True)
            except CalledProcessError as e:
                stderr = e.stderr.decode(errors="replace") if e.stderr else ""
                raise RuntimeError(
                    f"Failed to create audio chunk {chunk_id} for {input_path}: {stderr}"
                ) from e
            except OSError as e:
                # ffmpeg binary missing or not executable.
                raise RuntimeError(
                    f"Failed to create audio chunk {chunk_id} for {input_path}: {e}"
                ) from e

            chunks.append({
                "path": chunk_path,
                "start": start,
                "end": chunk_end,
            })

            # This chunk already reaches the end of the file. Advancing again
            # would only produce a tail no longer than the overlap, fully
            # contained in this chunk, and therefore duplicated transcripts.
            if chunk_end >= duration:
                break

            start += step
            chunk_id += 1
    except BaseException:
        # The caller never receives the list on failure, so it cannot clean up:
        # remove everything created so far, then re-raise unchanged.
        for path in created_paths:
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup; the original error is more important.
                pass
        raise

    return chunks
|
||||||
|
|||||||
@@ -9,11 +9,21 @@ It replaces the previous local Whisper + Pyannote pipeline by sending
|
|||||||
audio files to the /v1/audio/diarization endpoint and mapping the
|
audio files to the /v1/audio/diarization endpoint and mapping the
|
||||||
response into the same Transcript format used by the UI.
|
response into the same Transcript format used by the UI.
|
||||||
|
|
||||||
|
For long audio files, it can chunk the input to avoid GPU OOM errors.
|
||||||
|
|
||||||
Environment Variables:
|
Environment Variables:
|
||||||
LOCALAI_API_URL: (required) Base URL of the LocalAI server
|
LOCALAI_API_URL: (required) Base URL of the LocalAI server
|
||||||
(e.g., http://localhost:8080)
|
(e.g., http://localhost:8080)
|
||||||
LOCALAI_API_KEY: (optional) API key, if configured
|
LOCALAI_API_KEY: (optional) API key, if configured
|
||||||
LOCALAI_MODEL: (optional) Model name to use (default: vibevoice-diarize)
|
LOCALAI_MODEL: (optional) Model name to use (default: vibevoice-diarize)
|
||||||
|
|
||||||
|
Chunking / long audio (all optional):
|
||||||
|
LOCALAI_CHUNK_DURATION: Max duration of each chunk in seconds
|
||||||
|
(default: 180.0)
|
||||||
|
LOCALAI_CHUNK_OVERLAP: Overlap between consecutive chunks in seconds
|
||||||
|
(default: 2.0)
|
||||||
|
LOCALAI_MAX_SINGLE_REQUEST_DURATION: If audio duration exceeds this, chunking
|
||||||
|
is enabled automatically (default: 300.0)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -24,6 +34,8 @@ from typing import Dict, List, Any, Optional
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from .audio import get_audio_duration, split_audio_into_chunks
|
||||||
|
|
||||||
logger = logging.getLogger("scraibe.localai_client")
|
logger = logging.getLogger("scraibe.localai_client")
|
||||||
|
|
||||||
|
|
||||||
@@ -41,8 +53,14 @@ class LocalAIClient:
|
|||||||
- Upload audio file as multipart/form-data.
|
- Upload audio file as multipart/form-data.
|
||||||
- Parse diarization + transcription response (verbose_json).
|
- Parse diarization + transcription response (verbose_json).
|
||||||
- Map response into the same structure expected by Scraibe's Transcript.
|
- Map response into the same structure expected by Scraibe's Transcript.
|
||||||
|
- For long audio: chunk, transcribe each chunk, merge results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Default thresholds for chunking long audio to avoid GPU OOM.
|
||||||
|
# These can be overridden via environment or at call time.
|
||||||
|
DEFAULT_CHUNK_DURATION = 180.0 # seconds
|
||||||
|
DEFAULT_CHUNK_OVERLAP = 2.0 # seconds
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_url: Optional[str] = None,
|
api_url: Optional[str] = None,
|
||||||
@@ -82,6 +100,55 @@ class LocalAIClient:
|
|||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _env_float(var: str, default: float) -> float:
|
||||||
|
"""
|
||||||
|
Read a float from environment with a fallback default.
|
||||||
|
"""
|
||||||
|
val = (os.getenv(var) or "").strip()
|
||||||
|
if val == "":
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return float(val)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
"Invalid value for %s: %s; using default %s", var, val, default
|
||||||
|
)
|
||||||
|
return default
|
||||||
|
|
||||||
|
def _effective_chunk_duration(self, provided: Optional[float]) -> float:
|
||||||
|
"""
|
||||||
|
Resolve chunk_duration using this precedence:
|
||||||
|
1) provided argument
|
||||||
|
2) LOCALAI_CHUNK_DURATION env
|
||||||
|
3) class default
|
||||||
|
"""
|
||||||
|
if provided is not None:
|
||||||
|
return provided
|
||||||
|
return self._env_float("LOCALAI_CHUNK_DURATION", self.DEFAULT_CHUNK_DURATION)
|
||||||
|
|
||||||
|
def _effective_chunk_overlap(self, provided: Optional[float]) -> float:
|
||||||
|
"""
|
||||||
|
Resolve chunk_overlap:
|
||||||
|
1) provided argument
|
||||||
|
2) LOCALAI_CHUNK_OVERLAP env
|
||||||
|
3) class default
|
||||||
|
"""
|
||||||
|
if provided is not None:
|
||||||
|
return provided
|
||||||
|
return self._env_float("LOCALAI_CHUNK_OVERLAP", self.DEFAULT_CHUNK_OVERLAP)
|
||||||
|
|
||||||
|
def _effective_max_single_request_duration(self, provided: Optional[float]) -> float:
|
||||||
|
"""
|
||||||
|
Resolve max_single_request_duration:
|
||||||
|
1) provided argument
|
||||||
|
2) LOCALAI_MAX_SINGLE_REQUEST_DURATION env
|
||||||
|
3) default 300.0
|
||||||
|
"""
|
||||||
|
if provided is not None:
|
||||||
|
return provided
|
||||||
|
return self._env_float("LOCALAI_MAX_SINGLE_REQUEST_DURATION", 300.0)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the underlying HTTP client."""
|
"""Close the underlying HTTP client."""
|
||||||
self._client.close()
|
self._client.close()
|
||||||
@@ -107,6 +174,10 @@ class LocalAIClient:
|
|||||||
include_text: Optional[bool] = None,
|
include_text: Optional[bool] = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
return_raw: bool = False,
|
return_raw: bool = False,
|
||||||
|
use_chunking: Optional[bool] = None,
|
||||||
|
chunk_duration: Optional[float] = None,
|
||||||
|
chunk_overlap: Optional[float] = None,
|
||||||
|
max_single_request_duration: Optional[float] = None,
|
||||||
**_ignored,
|
**_ignored,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -114,6 +185,8 @@ class LocalAIClient:
|
|||||||
- A normalized dict with segments, speakers, transcripts.
|
- A normalized dict with segments, speakers, transcripts.
|
||||||
- Optionally, the raw verbose_json response (for JSON export).
|
- Optionally, the raw verbose_json response (for JSON export).
|
||||||
|
|
||||||
|
For long audio, it can automatically chunk the file to avoid GPU OOM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_path: Path to the audio file.
|
audio_path: Path to the audio file.
|
||||||
language: Language hint, forwarded if set.
|
language: Language hint, forwarded if set.
|
||||||
@@ -129,6 +202,93 @@ class LocalAIClient:
|
|||||||
Defaults to True.
|
Defaults to True.
|
||||||
verbose: If True, prints progress messages.
|
verbose: If True, prints progress messages.
|
||||||
return_raw: If True, also return the raw API response in 'raw_result'.
|
return_raw: If True, also return the raw API response in 'raw_result'.
|
||||||
|
use_chunking: Whether to enable chunking for long audio.
|
||||||
|
If None, enabled automatically based on duration.
|
||||||
|
chunk_duration: Max duration per chunk in seconds.
|
||||||
|
Falls back to LOCALAI_CHUNK_DURATION env, then 180.0.
|
||||||
|
chunk_overlap: Overlap between chunks in seconds.
|
||||||
|
Falls back to LOCALAI_CHUNK_OVERLAP env, then 2.0.
|
||||||
|
max_single_request_duration: If audio duration exceeds this, chunking
|
||||||
|
is enabled (unless explicitly disabled).
|
||||||
|
Falls back to LOCALAI_MAX_SINGLE_REQUEST_DURATION
|
||||||
|
env, then 300.0.
|
||||||
|
"""
|
||||||
|
if verbose:
|
||||||
|
print("Starting diarization and transcription via LocalAI.")
|
||||||
|
|
||||||
|
logger.info("diarize_and_transcribe requested for: %s", audio_path)
|
||||||
|
|
||||||
|
# Resolve chunking parameters with environment support
|
||||||
|
chunk_duration = self._effective_chunk_duration(chunk_duration)
|
||||||
|
chunk_overlap = self._effective_chunk_overlap(chunk_overlap)
|
||||||
|
max_single = self._effective_max_single_request_duration(max_single_request_duration)
|
||||||
|
|
||||||
|
if use_chunking is None:
|
||||||
|
try:
|
||||||
|
duration = get_audio_duration(audio_path)
|
||||||
|
except RuntimeError:
|
||||||
|
duration = None
|
||||||
|
|
||||||
|
use_chunking = (duration is not None and duration > max_single)
|
||||||
|
logger.info(
|
||||||
|
"Auto-chunking decision: duration=%s, threshold=%s, use_chunking=%s",
|
||||||
|
duration,
|
||||||
|
max_single,
|
||||||
|
use_chunking,
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_chunking:
|
||||||
|
return self._diarize_and_transcribe_chunked(
|
||||||
|
audio_path=audio_path,
|
||||||
|
language=language,
|
||||||
|
num_speakers=num_speakers,
|
||||||
|
min_speakers=min_speakers,
|
||||||
|
max_speakers=max_speakers,
|
||||||
|
clustering_threshold=clustering_threshold,
|
||||||
|
min_duration_on=min_duration_on,
|
||||||
|
min_duration_off=min_duration_off,
|
||||||
|
response_format=response_format,
|
||||||
|
include_text=include_text,
|
||||||
|
verbose=verbose,
|
||||||
|
return_raw=return_raw,
|
||||||
|
chunk_duration=chunk_duration,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Single-request path (existing behavior)
|
||||||
|
return self._diarize_and_transcribe_single(
|
||||||
|
audio_path=audio_path,
|
||||||
|
language=language,
|
||||||
|
num_speakers=num_speakers,
|
||||||
|
min_speakers=min_speakers,
|
||||||
|
max_speakers=max_speakers,
|
||||||
|
clustering_threshold=clustering_threshold,
|
||||||
|
min_duration_on=min_duration_on,
|
||||||
|
min_duration_off=min_duration_off,
|
||||||
|
response_format=response_format,
|
||||||
|
include_text=include_text,
|
||||||
|
verbose=verbose,
|
||||||
|
return_raw=return_raw,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _diarize_and_transcribe_single(
|
||||||
|
self,
|
||||||
|
audio_path: str,
|
||||||
|
*,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
num_speakers: Optional[int] = None,
|
||||||
|
min_speakers: Optional[int] = None,
|
||||||
|
max_speakers: Optional[int] = None,
|
||||||
|
clustering_threshold: Optional[float] = None,
|
||||||
|
min_duration_on: Optional[float] = None,
|
||||||
|
min_duration_off: Optional[float] = None,
|
||||||
|
response_format: Optional[str] = None,
|
||||||
|
include_text: Optional[bool] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
return_raw: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Internal: single-request diarization and transcription.
|
||||||
"""
|
"""
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Starting diarization and transcription via LocalAI.")
|
print("Starting diarization and transcription via LocalAI.")
|
||||||
@@ -214,6 +374,153 @@ class LocalAIClient:
|
|||||||
|
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
def _diarize_and_transcribe_chunked(
    self,
    audio_path: str,
    *,
    language: Optional[str] = None,
    num_speakers: Optional[int] = None,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
    clustering_threshold: Optional[float] = None,
    min_duration_on: Optional[float] = None,
    min_duration_off: Optional[float] = None,
    response_format: Optional[str] = None,
    include_text: Optional[bool] = None,
    verbose: bool = False,
    return_raw: bool = False,
    chunk_duration: float = DEFAULT_CHUNK_DURATION,
    chunk_overlap: float = DEFAULT_CHUNK_OVERLAP,
) -> Dict[str, Any]:
    """
    Internal: chunked diarization and transcription for long audio.

    Steps:
    - Split the audio into overlapping chunks with split_audio_into_chunks.
    - Run the single-request path on every chunk.
    - Shift each chunk's segment times by the chunk's start offset and
      merge everything, sorted by segment start time.

    If the split yields a single chunk, the call is delegated to the
    single-request path and its result is returned unchanged.

    NOTE: segments from every chunk are kept as returned; nothing here
    de-duplicates speech that falls inside the overlap between chunks.

    Returns:
        Dict with "segments" ([start, end] pairs), "speakers" and
        "transcripts". When return_raw is set and at least one chunk
        produced a raw response, "raw_result" holds
        {"chunked": True, "chunks": [<raw response per chunk>, ...]}.
    """
    if verbose:
        print("Audio is long; splitting into chunks to avoid GPU memory issues.")

    logger.info(
        "Chunked transcription: chunk_duration=%s, overlap=%s",
        chunk_duration,
        chunk_overlap,
    )

    chunks = split_audio_into_chunks(
        input_path=audio_path,
        max_duration=chunk_duration,
        overlap=chunk_overlap,
    )

    # Options forwarded untouched to every single-request call.
    request_options = dict(
        language=language,
        num_speakers=num_speakers,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
        clustering_threshold=clustering_threshold,
        min_duration_on=min_duration_on,
        min_duration_off=min_duration_off,
        response_format=response_format,
        include_text=include_text,
        return_raw=return_raw,
    )

    total = len(chunks)
    if total == 1:
        # No actual split happened; behave exactly like the single path.
        return self._diarize_and_transcribe_single(
            audio_path=chunks[0]["path"],
            verbose=verbose,
            **request_options,
        )

    merged = []  # (segment, speaker, transcript) triples on the global timeline
    raw_results: List[Dict[str, Any]] = []
    chunk_paths = [info["path"] for info in chunks]

    try:
        for number, info in enumerate(chunks, start=1):
            offset = info["start"]

            if verbose:
                print(f"Transcribing chunk {number}/{total} (start={offset:.1f}s)")

            logger.info(
                "Transcribing chunk %d/%d, start=%.1f", number, total, offset
            )

            # Per-chunk progress printing is suppressed (verbose=False).
            chunk_result = self._diarize_and_transcribe_single(
                audio_path=info["path"],
                verbose=False,
                **request_options,
            )

            triples = zip(
                chunk_result.get("segments", []),
                chunk_result.get("speakers", []),
                chunk_result.get("transcripts", []),
            )
            for seg, speaker, text in triples:
                # Chunk-local times become global by adding the chunk start.
                shifted = [float(seg[0]) + offset, float(seg[1]) + offset]
                merged.append((shifted, speaker, text))

            raw = chunk_result.get("raw_result")
            if return_raw and raw is not None:
                raw_results.append(raw)
    finally:
        # Temporary chunk files are removed even if a request failed; the
        # caller's own input file is never deleted.
        for path in chunk_paths:
            if not path or not os.path.exists(path) or path == audio_path:
                continue
            try:
                os.remove(path)
            except Exception as e:
                logger.warning("Failed to remove chunk file %s: %s", path, e)

    # Stable sort by global start time keeps the three lists aligned.
    merged.sort(key=lambda triple: triple[0][0])
    all_segments = [triple[0] for triple in merged]
    all_speakers = [triple[1] for triple in merged]
    all_transcripts = [triple[2] for triple in merged]

    if verbose:
        print(
            f"Chunked transcription complete. Total segments: {len(all_segments)}"
        )

    result = {
        "segments": all_segments,
        "speakers": all_speakers,
        "transcripts": all_transcripts,
    }

    if return_raw and raw_results:
        result["raw_result"] = {
            "chunked": True,
            "chunks": raw_results,
        }

    return result
|
||||||
|
|
||||||
def _parse_diarization_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
def _parse_diarization_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Convert LocalAI verbose_json response into the internal format used by Scraibe:
|
Convert LocalAI verbose_json response into the internal format used by Scraibe:
|
||||||
|
|||||||
@@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scraibe.audio import (
|
||||||
|
get_audio_duration,
|
||||||
|
split_audio_into_chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||||
|
TEST_AUDIO_2 = "tests/audio_test_2.mp4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[TEST_AUDIO_1, TEST_AUDIO_2])
def test_audio_path(request):
    """Parametrized fixture: yields each bundled test audio path in turn."""
    path = request.param
    return path
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_audio_duration(test_audio_path):
    """The duration of each bundled sample is reported as a positive float."""
    duration = get_audio_duration(test_audio_path)

    assert isinstance(duration, float)
    assert duration > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_audio_into_chunks_no_split_short(test_audio_path):
    """A file shorter than max_duration comes back as itself, unsplit."""
    # 600 s is larger than both bundled test files, so no split may happen.
    result = split_audio_into_chunks(
        input_path=test_audio_path,
        max_duration=600.0,
        overlap=2.0,
    )

    assert len(result) == 1
    only = result[0]
    assert only["path"] == test_audio_path
    assert only["start"] == 0.0

    expected_end = get_audio_duration(test_audio_path)
    assert abs(only["end"] - expected_end) < 0.05
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_audio_into_chunks_creates_chunks(tmp_path):
    """A small max_duration forces a split into ordered, overlapping files."""
    # Use a small chunk duration to force splitting
    chunks = split_audio_into_chunks(
        input_path=TEST_AUDIO_1,
        max_duration=2.0,
        overlap=0.5,
    )
    try:
        assert len(chunks) > 1

        # Check that each chunk file exists and is non-empty
        for c in chunks:
            assert os.path.exists(c["path"])
            assert os.path.getsize(c["path"]) > 0

        # Check time ordering and overlap
        for prev, curr in zip(chunks, chunks[1:]):
            assert curr["start"] >= prev["start"]
            assert curr["start"] < prev["end"]  # overlap
    finally:
        # Cleanup runs even when an assertion fails, so no temp files leak.
        for c in chunks:
            if os.path.exists(c["path"]):
                os.remove(c["path"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_audio_into_chunks_total_coverage(test_audio_path):
    """Chunks start at 0 and the last one reaches the end of the file."""
    dur = get_audio_duration(test_audio_path)

    # Use small chunks to ensure coverage
    chunks = split_audio_into_chunks(
        input_path=test_audio_path,
        max_duration=2.0,
        overlap=0.5,
    )
    try:
        # First chunk starts at 0
        assert chunks[0]["start"] == 0.0

        # Last chunk end should cover the duration
        assert chunks[-1]["end"] >= dur - 0.05
    finally:
        # Cleanup runs even when an assertion fails, so no temp files leak.
        for c in chunks:
            if os.path.exists(c["path"]):
                os.remove(c["path"])
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from scraibe.localai_client import LocalAIClient, LocalAIError
|
||||||
|
from scraibe.audio import get_audio_duration, split_audio_into_chunks
|
||||||
|
|
||||||
|
|
||||||
|
TEST_AUDIO_1 = "tests/audio_test_1.mp4"
|
||||||
|
|
||||||
|
|
||||||
|
def make_fake_segments(start=0.0, count=3):
    """
    Build `count` back-to-back fake API segments of 2 seconds each.

    The first segment begins at `start`; every segment is attributed to
    "SPEAKER_00" and carries the text "Segment text <index>".
    """
    return [
        {
            "start": start + index * 2.0,
            "end": start + index * 2.0 + 2.0,
            "speaker": "SPEAKER_00",
            "text": f"Segment text {index}",
        }
        for index in range(count)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def fake_localai_response(segments):
    """
    Wrap `segments` in a fake API response dict.

    The result has the segments under "segments" and the space-joined
    segment texts under "text".
    """
    texts = [seg["text"] for seg in segments]
    response = {"segments": segments}
    response["text"] = " ".join(texts)
    return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client():
    """
    LocalAIClient with __init__ bypassed and the HTTP client mocked.

    The attributes set here are assigned by hand, so no environment
    variables or network access are needed to build the instance.
    """
    def _skip_init(self, **kw):
        return None

    with patch.object(LocalAIClient, "__init__", _skip_init):
        instance = LocalAIClient()

    instance.api_url = "http://localhost:8080"
    instance.model = "vibevoice-diarize"
    instance.api_key = None
    instance._client = MagicMock()
    return instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_diarization_response(client):
    """Parsed output mirrors the API segments field by field."""
    expected = make_fake_segments()
    parsed = client._parse_diarization_response(fake_localai_response(expected))

    for key in ("segments", "speakers", "transcripts"):
        assert key in parsed
    assert len(parsed["segments"]) == len(expected)

    for idx, want in enumerate(expected):
        got_segment = parsed["segments"][idx]
        assert got_segment[0] == want["start"]
        assert got_segment[1] == want["end"]
        assert parsed["speakers"][idx] == want["speaker"]
        assert parsed["transcripts"][idx] == want["text"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_diarization_empty(client):
    """An empty segment list parses into three empty lists."""
    parsed = client._parse_diarization_response({"segments": []})

    for key in ("segments", "speakers", "transcripts"):
        assert parsed[key] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarize_and_transcribe_single_happy(client):
    """A mocked 200 response yields segments and the raw result."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = fake_localai_response(make_fake_segments())

    with patch.object(client, "_client") as http:
        http.post.return_value = response

        outcome = client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            return_raw=True,
        )

    assert "segments" in outcome
    assert "raw_result" in outcome
    assert len(outcome["segments"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunking_triggered_for_long_audio(client):
    """A duration above the threshold routes the call to the chunked path."""
    empty_result = {"segments": [], "speakers": [], "transcripts": []}

    # Pretend the file is 10 minutes long without touching ffprobe.
    duration_patch = patch(
        "scraibe.localai_client.get_audio_duration", return_value=600.0
    )
    chunked_patch = patch.object(
        client, "_diarize_and_transcribe_chunked", return_value=empty_result
    )

    with duration_patch, chunked_patch as mock_chunked:
        client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            use_chunking=None,
            max_single_request_duration=300.0,
        )

        mock_chunked.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunking_not_triggered_for_short_audio(client):
    """A duration below the threshold stays on the single-request path."""
    empty_result = {"segments": [], "speakers": [], "transcripts": []}

    duration_patch = patch(
        "scraibe.localai_client.get_audio_duration", return_value=120.0
    )
    chunked_patch = patch.object(client, "_diarize_and_transcribe_chunked")
    single_patch = patch.object(
        client, "_diarize_and_transcribe_single", return_value=empty_result
    )

    with duration_patch, chunked_patch as mock_chunked, single_patch as mock_single:
        client.diarize_and_transcribe(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            use_chunking=None,
            max_single_request_duration=300.0,
        )

        mock_chunked.assert_not_called()
        mock_single.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_transcription_adjusts_timestamps(client):
    """Chunk-local segment times are shifted by each chunk's start offset."""
    # Both chunks reuse the bundled file as their path; since it equals
    # audio_path, the cleanup step never tries to delete it.
    chunks = [
        {"path": TEST_AUDIO_1, "start": 0.0, "end": 10.0},
        {"path": TEST_AUDIO_1, "start": 10.0, "end": 20.0},
    ]

    def side_effect(audio_path, **kw):
        # Every chunk reports the same two chunk-local segments:
        # [0, 2] and [2, 4].
        segs = make_fake_segments(start=0.0, count=2)
        return client._parse_diarization_response(fake_localai_response(segs))

    with patch("scraibe.localai_client.split_audio_into_chunks") as mock_split, \
            patch.object(client, "_diarize_and_transcribe_single") as mock_single, \
            patch("os.remove"):

        mock_split.return_value = chunks
        mock_single.side_effect = side_effect

        result = client._diarize_and_transcribe_chunked(
            audio_path=TEST_AUDIO_1,
            verbose=False,
            return_raw=False,
            chunk_duration=10.0,
            chunk_overlap=2.0,
        )

        # One single-request call per chunk.
        assert mock_single.call_count == 2

    # Check we got 4 segments total, with aligned speakers and transcripts
    assert len(result["segments"]) == 4
    assert len(result["speakers"]) == 4
    assert len(result["transcripts"]) == 4

    # First two segments keep their local times (chunk offset 0)
    assert result["segments"][0][0] == 0.0
    assert result["segments"][0][1] == 2.0
    assert result["segments"][1][0] == 2.0
    assert result["segments"][1][1] == 4.0

    # Next two segments should be shifted by 10, both start and end
    assert result["segments"][2][0] == 10.0
    assert result["segments"][2][1] == 12.0
    assert result["segments"][3][0] == 12.0
    assert result["segments"][3][1] == 14.0

    # Texts stay attached to their segments after merging
    assert result["transcripts"] == [
        "Segment text 0",
        "Segment text 1",
        "Segment text 0",
        "Segment text 1",
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
def test_integration_chunked_transcription_with_localai():
    """
    Integration test: run chunked transcription against a live LocalAI instance.

    Skips itself when LOCALAI_API_URL is not set or the bundled test audio
    file is missing. Select it explicitly with:
        pytest -m integration
    """
    api_url = os.getenv("LOCALAI_API_URL")
    if not api_url:
        pytest.skip("LOCALAI_API_URL not set; skipping integration test")

    # Use one of the bundled test audio files
    audio_path = TEST_AUDIO_1
    if not os.path.exists(audio_path):
        pytest.skip(f"Test audio not found: {audio_path}")

    # Use environment-configured model or a sensible default
    model = os.getenv("LOCALAI_MODEL") or "vibevoice-cpp-asr"

    client = LocalAIClient(api_url=api_url, model=model)
    try:
        # use_chunking=True forces the chunked path; the tiny chunk_duration
        # makes even a short sample split into several chunks.
        result = client.diarize_and_transcribe(
            audio_path=audio_path,
            verbose=True,
            return_raw=True,
            use_chunking=True,
            chunk_duration=3.0,
            chunk_overlap=0.5,
            max_single_request_duration=1.0,
        )

        assert "segments" in result
        assert len(result["segments"]) > 0

        # Basic sanity: segments are ordered by start time
        starts = [seg[0] for seg in result["segments"]]
        for earlier, later in zip(starts, starts[1:]):
            assert later >= earlier

        # If raw_result indicates chunked, ensure structure is sensible
        raw = result.get("raw_result")
        if raw and raw.get("chunked"):
            assert "chunks" in raw
            assert len(raw["chunks"]) > 1

    finally:
        client.close()
|
||||||
Reference in New Issue
Block a user