Initial commit: LocalAI-backed ScrAIbe with summarization
This commit is contained in:
+224
-301
@@ -1,358 +1,281 @@
|
||||
"""
|
||||
Scraibe Class
|
||||
--------------------
|
||||
Scraibe Class (LocalAI-backed)
|
||||
------------------------------
|
||||
|
||||
This class serves as the core of the transcription system, responsible for handling
|
||||
transcription and diarization of audio files. It leverages pretrained models for
|
||||
speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
|
||||
providing an accessible interface for audio processing tasks such as transcription,
|
||||
speaker separation, and timestamping.
|
||||
Core class for transcription and (optionally) summarization.
|
||||
|
||||
By encapsulating the complexities of underlying models, it allows for straightforward
|
||||
integration into various applications, ranging from transcription services to voice assistants.
|
||||
- Transcription and diarization are delegated to LocalAI (vibevoice.cpp).
|
||||
- Summarization is delegated to a separate LLM via /v1/chat/completions.
|
||||
|
||||
Available Classes:
|
||||
- Scraibe: Main class for performing transcription and diarization.
|
||||
Includes methods for loading models, processing audio files,
|
||||
and formatting the transcription output.
|
||||
Public tasks:
|
||||
- transcribe
|
||||
- transcript_and_summarize (transcribe + generate a detailed summary)
|
||||
|
||||
Usage:
|
||||
from scraibe import Scraibe
|
||||
|
||||
model = Scraibe()
|
||||
transcript = model.autotranscribe("path/to/audiofile.wav")
|
||||
Previous task/whisper/pyannote-specific settings are kept for compatibility
|
||||
but ignored when not relevant.
|
||||
"""
|
||||
|
||||
# Standard Library Imports
|
||||
import os
|
||||
from glob import iglob
|
||||
from subprocess import run
|
||||
from typing import TypeVar, Union
|
||||
from warnings import warn
|
||||
from typing import Union, Optional
|
||||
|
||||
# Third-Party Imports
|
||||
import torch
|
||||
from numpy import ndarray
|
||||
from tqdm import trange
|
||||
|
||||
# Application-Specific Imports
|
||||
from .audio import AudioProcessor
|
||||
from .diarisation import Diariser
|
||||
from .transcriber import Transcriber, load_transcriber, whisper
|
||||
from .localai_client import LocalAIClient, LocalAIError
|
||||
from .summarizer import SummarizerClient, SummarizerError
|
||||
from .transcript_exporter import Transcript
|
||||
from .misc import SCRAIBE_TORCH_DEVICE
|
||||
|
||||
|
||||
# Type variable used in annotations as a stand-in for a diarisation model object.
DiarisationType = TypeVar('DiarisationType')
|
||||
|
||||
|
||||
class Scraibe:
    """LocalAI-backed transcription front end with optional summarization.

    Scraibe:
        - Uses LocalAI for transcription + diarization.
        - Uses a separate LLM for summarization (when requested).

    Attributes:
        verbose (bool): Default verbosity forwarded to the LocalAI client.
        client: The ``LocalAIClient`` used for diarization/transcription.

    Public methods:
        - transcribe(audio_file, ...)
        - transcript_and_summarize(audio_file, ...)
    """

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        whisper_model: Union[bool, str] = None,
        whisper_type: str = "whisper",
        dia_model: Union[bool, str] = None,
        use_auth_token: str = None,
        verbose: bool = False,
        **kwargs,
    ) -> None:
        """Initialize Scraibe with a LocalAI client; the summarizer is created lazily.

        Args:
            api_url: LocalAI server URL for transcription/diarization.
                Passed straight to ``LocalAIClient``; any fallback to the
                LOCALAI_API_URL env var is that client's responsibility.
            api_key: API key for LocalAI (client-side fallback: LOCALAI_API_KEY).
            model: Model name for LocalAI, e.g. ``vibevoice-diarize``
                (client-side fallback: LOCALAI_MODEL).
            whisper_model, whisper_type, dia_model, use_auth_token, **kwargs:
                Accepted for backward compatibility only and ignored.
            verbose: If True, print a short status line after initialization
                and use verbose mode by default for LocalAI calls.

        Raises:
            LocalAIError: If the LocalAI client cannot be constructed.
        """
        # `verbose` is a named parameter, so it can never show up in **kwargs;
        # reading it from there (as the previous code did) was always a no-op.
        self.verbose = bool(verbose)

        try:
            self.client = LocalAIClient(
                api_url=api_url,
                api_key=api_key,
                model=model,
            )
        except LocalAIError as e:
            raise LocalAIError(f"Failed to initialize LocalAI client: {e}") from e

        # Summarizer is lazy-initialized on first use; remember which explicit
        # settings built it so later overrides can trigger a rebuild.
        self._summarizer = None
        self._summarizer_config = (None, None, None)

        if self.verbose:
            print("Scraibe initialized. Using LocalAI for transcription and diarization.")

    def _ensure_summarizer(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
    ) -> "SummarizerClient":
        """Return the summarizer client, creating it on first use.

        The cached client is reused when no explicit setting is passed or when
        the explicit settings match the ones it was built with. Different
        explicit settings rebuild the client instead of being silently ignored.

        Raises:
            SummarizerError: If the summarizer client cannot be constructed.
        """
        requested = (api_url, api_key, model)
        cached = getattr(self, "_summarizer", None)
        cached_config = getattr(self, "_summarizer_config", (None, None, None))
        has_override = any(value is not None for value in requested)

        if cached is not None and (not has_override or requested == cached_config):
            return cached

        try:
            client = SummarizerClient(
                api_url=api_url,
                api_key=api_key,
                model=model,
            )
        except SummarizerError as e:
            raise SummarizerError(f"Failed to initialize Summarizer client: {e}") from e

        self._summarizer = client
        self._summarizer_config = requested
        return client

    @staticmethod
    def _require_audio_path(audio_file) -> None:
        """Validate that ``audio_file`` is an existing filesystem path string."""
        if not isinstance(audio_file, str):
            raise TypeError(
                "In LocalAI mode, audio_file must be a file path (str)."
            )
        if not os.path.exists(audio_file):
            raise FileNotFoundError(f"Audio file not found: {audio_file}")

    def _diarize_and_transcribe(self, audio_file: str, kwargs: dict) -> dict:
        """Run the LocalAI diarize+transcribe call shared by the public methods."""
        options = dict(kwargs)
        # Pop `verbose` so it is not passed twice (explicitly and via **options),
        # which would raise "got multiple values for keyword argument".
        verbose = options.pop("verbose", self.verbose)
        try:
            return self.client.diarize_and_transcribe(
                audio_path=audio_file,
                include_text=True,
                verbose=verbose,
                **options,
            )
        except LocalAIError as e:
            raise LocalAIError(f"Error during LocalAI transcription: {e}") from e

    # -----------------
    # Primary public API
    # -----------------

    def transcribe(
        self,
        audio_file: str,
        **kwargs,
    ) -> str:
        """Transcribe the provided audio file using LocalAI.

        Calls the client's ``diarize_and_transcribe`` and concatenates all
        non-empty segment texts with single spaces.

        Args:
            audio_file (str): Path to the audio file.
            **kwargs: Forwarded to the LocalAI client. ``verbose`` overrides
                the instance default for this call.

        Returns:
            str: The concatenated transcribed text.

        Raises:
            TypeError: If ``audio_file`` is not a string path.
            FileNotFoundError: If the path does not exist.
            LocalAIError: If the LocalAI request fails.
        """
        self._require_audio_path(audio_file)
        result = self._diarize_and_transcribe(audio_file, kwargs)

        transcripts = result.get("transcripts", [])
        # Entries may be None or whitespace-only; skip those.
        return " ".join(t.strip() for t in transcripts if t and t.strip())

    def transcript_and_summarize(
        self,
        audio_file: str,
        *,
        summarizer_api_url: Optional[str] = None,
        summarizer_api_key: Optional[str] = None,
        summarizer_model: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """Transcribe the audio file and generate a detailed summary.

        Steps:
            - Transcribe via LocalAI.
            - Build a plain-text transcript (with speaker labels).
            - Summarize the transcript using the configured LLM.

        Args:
            audio_file (str): Path to the audio file.
            summarizer_api_url, summarizer_api_key, summarizer_model:
                Optional per-call overrides for the summarizer client.
            **kwargs: Forwarded to the LocalAI client. ``verbose`` overrides
                the instance default for this call.

        Returns:
            dict with:
                - transcript: full transcript text (with speaker labels)
                - summary: final detailed summary

        Raises:
            TypeError: If ``audio_file`` is not a string path.
            FileNotFoundError: If the path does not exist.
            LocalAIError: If the LocalAI request fails.
            SummarizerError: If the summarizer cannot be created or fails.
        """
        self._require_audio_path(audio_file)

        # 1) Get diarized + transcribed result
        result = self._diarize_and_transcribe(audio_file, kwargs)
        segments = result.get("segments", [])
        speakers = result.get("speakers", [])
        transcripts = result.get("transcripts", [])

        # 2) Build full transcript text with speaker labels
        lines = [
            f"[{self._format_timestamp(seg[0])}] {speaker}: {(text or '').strip()}"
            for seg, speaker, text in zip(segments, speakers, transcripts)
        ]

        # Nothing to summarize: no segments, or no aligned speaker/text entries.
        if not lines:
            return {
                "transcript": "",
                "summary": "No transcript content to summarize.",
            }

        full_transcript = "\n\n".join(lines)

        # 3) Summarize (_ensure_summarizer already wraps construction errors)
        summarizer = self._ensure_summarizer(
            api_url=summarizer_api_url,
            api_key=summarizer_api_key,
            model=summarizer_model,
        )
        try:
            summary = summarizer.summarize_transcript(full_transcript)
        except SummarizerError as e:
            raise SummarizerError(f"Error during summarization: {e}") from e

        return {
            "transcript": full_transcript,
            "summary": summary,
        }

    # -----------------
    # Helpers
    # -----------------

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """Format seconds into MM:SS or HH:MM:SS (negative input clamps to 00:00)."""
        total = max(int(seconds), 0)
        minutes, secs = divmod(total, 60)
        hours, minutes = divmod(minutes, 60)
        if hours > 0:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    @staticmethod
    def remove_audio_file(audio_file: str, shred: bool = False) -> None:
        """Remove the original audio file.

        Args:
            audio_file (str): Path to the audio file.
            shred (bool, optional): If True, overwrite and unlink the file with
                the external ``shred`` tool instead of a plain delete.

        Raises:
            ValueError: If the path does not exist, or (when shredding) is a
                directory.
            subprocess.CalledProcessError: If ``shred`` exits non-zero.
        """
        if not os.path.exists(audio_file):
            raise ValueError(f"Audiofile {audio_file} does not exist.")

        if not shred:
            os.remove(audio_file)
            print(f"Audiofile {audio_file} removed.")
            return

        import subprocess
        import warnings

        if os.path.isdir(audio_file):
            raise ValueError(f"Audiofile {audio_file} is a directory.")

        warnings.warn("Shredding audiofile can take a long time.", RuntimeWarning)

        # Shred the literal path. The earlier glob expansion could match
        # nothing for names containing glob metacharacters, silently skipping
        # the shred; "--" stops option parsing for names starting with "-".
        print(f"shredding {audio_file} now\n")
        subprocess.run(
            ["shred", "-zvu", "-n", "10", "--", audio_file],
            check=True,
        )

    def __repr__(self):
        return "Scraibe(LocalAI-backed)"
|
||||
|
||||
Reference in New Issue
Block a user