From 574124558bf7cd4e3df36d5b3a3a8f7e8ca068d0 Mon Sep 17 00:00:00 2001 From: admin Date: Sat, 13 Jun 2026 16:38:59 +0000 Subject: [PATCH] Initial commit: LocalAI-backed ScrAIbe with summarization --- Dockerfile | 61 +++-- pyproject.toml | 48 ++-- requirements.txt | 13 +- scraibe/__init__.py | 15 +- scraibe/audio.py | 104 +++----- scraibe/autotranscript.py | 525 ++++++++++++++++---------------------- scraibe/cli.py | 292 ++++++++++++++------- scraibe/localai_client.py | 237 +++++++++++++++++ scraibe/misc.py | 79 ++---- scraibe/summarizer.py | 212 +++++++++++++++ 10 files changed, 992 insertions(+), 594 deletions(-) create mode 100644 scraibe/localai_client.py create mode 100644 scraibe/summarizer.py diff --git a/Dockerfile b/Dockerfile index a9feb8e..2b6c105 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,44 +1,43 @@ -#pytorch Image -FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime +# Lightweight Python base image (no GPU/PyTorch needed) +FROM python:3.11-slim # Labels - LABEL maintainer="Jacob Schmieder" LABEL email="Jacob.Schmieder@dbfz.de" LABEL version="0.1.1.dev" -LABEL description="Scraibe is a tool for automatic speech recognition and speaker diarization. \ - It is based on the Hugging Face Transformers library and the Pyannote library. \ - It is designed to be used with the Whisper model, a lightweight model for automatic \ - speech recognition and speaker diarization." +LABEL description="Scraibe: LocalAI-backed transcription and diarization client with summarization. \ + Sends audio to a LocalAI server running vibevoice.cpp and uses a second LLM for summarization." LABEL url="https://github.com/JSchmie/ScrAIbe" -# Install dependencies -WORKDIR /app -#Enviorment dependencies -ENV TRANSFORMERS_CACHE=/app/models -ENV HF_HOME=/app/models -ENV AUTOT_CACHE=/app/models -ENV PYANNOTE_CACHE=/app/models/pyannote -#Copy all necessary files -COPY requirements.txt /app/requirements.txt -COPY README.md /app/README.md -COPY scraibe /app/scraibe - -#Installing all necessary dependencies and running the application with a personalised Hugging-Face-Token -RUN apt update -y && apt upgrade -y && \ - apt install -y libsm6 libxrender1 libfontconfig1 && \ +# Install system dependencies (ffmpeg required) +RUN apt update -y && \ + apt install -y --no-install-recommends ffmpeg && \ apt clean && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -RUN conda update --all && \ - # conda install -y pip ffmpeg && \ - conda install -c conda-forge libsndfile && \ - conda clean --all -y -# RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html +# Working directory +WORKDIR /app + +# Environment variables for LocalAI (transcription/diarization) +# Set these via docker run -e or docker-compose +ENV LOCALAI_API_URL=http://localhost:8080 +ENV LOCALAI_API_KEY= +ENV LOCALAI_MODEL=vibevoice-diarize + +# Environment variables for Summarizer LLM +ENV SUMMARIZER_API_URL=http://localhost:8080 +ENV SUMMARIZER_API_KEY= +ENV SUMMARIZER_MODEL=llama-3.1-8b-instruct + +# Copy and install Python dependencies +COPY requirements.txt /app/requirements.txt RUN pip install --no-cache-dir -r requirements.txt -# Expose port -EXPOSE 7860 -# Run the application +# Copy application code +COPY scraibe /app/scraibe -ENTRYPOINT ["python3", "-m", "scraibe.cli"] \ No newline at end of file +# Expose port (if UI is served) +EXPOSE 7860 + +# Run the application +ENTRYPOINT ["python3", "-m", "scraibe.cli"] diff --git a/pyproject.toml b/pyproject.toml index c113502..702805a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,38 +5,42 @@ build-backend = "poetry_dynamic_versioning.backend" [tool.poetry] name = "scraibe" version = "0.0.0" -description = "Transcription tool for audio files based on Whisper and Pyannote" +description = "LocalAI-backed transcription and diarization client using vibevoice.cpp" authors = ["Schmieder, Jacob "] license = "GPL-3.0-or-later" readme = ["README.md", "LICENSE"] repository = "https://github.com/JSchmie/ScAIbe" documentation = "https://jschmie.github.io/ScrAIbe/" -keywords = ["transcription", "audio", "whisper", "pyannote", "speech-to-text", "speech-recognition"] +keywords = [ + "transcription", + "audio", + "diarization", + "localai", + "vibevoice", + "speech-to-text", +] classifiers = [ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1', - 'Topic :: Scientific/Engineering :: Artificial Intelligence' - ] + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] packages = [{include = "scraibe"}] -exclude =[ - "__pycache__", - "*.pyc", - "test" - ] +exclude = [ + "__pycache__", + "*.pyc", + "test", +] + [tool.poetry.dependencies] python = "^3.9" tqdm = "^4.66.5" numpy = "^1.26.4" -openai-whisper = ">=20231117,<20240931" -faster-whisper = "^1.0.3" -"pyannote.audio" = "^3.3.1" -torch = "^2.1.2" +httpx = ">=0.28.0" [tool.poetry.group.dev.dependencies] pytest = "^8.1.1" @@ -69,5 +73,5 @@ scraibe = "scraibe.cli:cli" app = ["scraibe-webui"] [tool.ruff.lint.extend-per-file-ignores] -"__init__.py" = ["E402","F403",'F401'] +"__init__.py" = ["E402", "F403", "F401"] "scraibe/misc.py" = ["E722"] diff --git a/requirements.txt b/requirements.txt index 8786d84..d72fea9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,3 @@ tqdm>=4.66.5 numpy>=1.26.4 - -openai-whisper==20231117 -faster-whisper~=1.0.3 - -pyannote.audio~=3.3.1 -pyannote.core~=5.0.0 -pyannote.database~=5.0.1 -pyannote.metrics~=3.2.1 -pyannote.pipeline~=3.0.1 - -torchaudio>=2.1.2 - +httpx>=0.28.0 diff --git a/scraibe/__init__.py b/scraibe/__init__.py index 399023a..e9c4181 100644 --- a/scraibe/__init__.py +++ b/scraibe/__init__.py @@ -1,11 +1,10 @@ -from .autotranscript import * -from .transcriber import * -from .audio import * -from .transcript_exporter import * -from .diarisation import * +from .autotranscript import Scraibe +from .localai_client import LocalAIClient, LocalAIError +from .summarizer import SummarizerClient, SummarizerError +from .audio import AudioProcessor +from .transcript_exporter import Transcript +from .misc import set_threads, ParseKwargs -from .misc import * - -from .cli import * +from .cli import cli from ._version import __version__ diff --git a/scraibe/audio.py b/scraibe/audio.py index 4e5dd0f..621b0b1 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -2,28 +2,15 @@ Audio Processor Module ======================= -This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files. -It includes functionalities to load, cut, and manage audio waveforms, offering efficient and -flexible audio processing. +Simplified audio processor for ScrAIbe. -Available Classes: -- AudioProcessor: Processes audio waveforms and provides methods for loading, - cutting, and handling audio. - -Usage: - from .audio_import AudioProcessor - - processor = AudioProcessor.from_file("path/to/audiofile.wav") - cut_waveform = processor.cut(start=1.0, end=5.0) - -Constants: -- SAMPLE_RATE (int): Default sample rate for processing. -- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform. +Previously this used torch and pyannote-style processing. In the LocalAI-backed +version, we primarily pass files to the API, but we keep a lightweight helper +for backward compatibility. """ from subprocess import CalledProcessError, run import numpy as np -import torch SAMPLE_RATE = 16000 NORMALIZATION_FACTOR = 32768.0 @@ -31,38 +18,25 @@ NORMALIZATION_FACTOR = 32768.0 class AudioProcessor: """ - Audio Processor class that leverages PyTorchaudio to provide functionalities - for loading, cutting, and handling audio waveforms. + Lightweight audio processor for loading and cutting audio. Attributes: - waveform: torch.Tensor - The audio waveform tensor. - sr: int - The sample rate of the audio. + waveform (np.ndarray): The audio waveform as float32. + sr (int): The sample rate of the audio. """ - def __init__(self, waveform: torch.Tensor, - sr: int = SAMPLE_RATE) -> None: - """ - Initialize the AudioProcessor object. - - Args: - waveform (torch.Tensor): The audio waveform tensor. - sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. - - Raises: - ValueError: If the provided sample rate is not of type int. - """ - + def __init__(self, waveform: np.ndarray, sr: int = SAMPLE_RATE): self.waveform = waveform self.sr = sr if not isinstance(self.sr, int): - raise ValueError("Sample rate should be a single value of type int," - f"not {len(self.sr)} and type {type(self.sr)}") + raise ValueError( + "Sample rate should be a single value of type int, " + f"not {len(self.sr)} and type {type(self.sr)}" + ) @classmethod - def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor': + def from_file(cls, file: str, *args, **kwargs): """ Create an AudioProcessor instance from an audio file. @@ -70,55 +44,42 @@ class AudioProcessor: file (str): The audio file path. Returns: - AudioProcessor: An instance of the AudioProcessor class containing the loaded audio. + AudioProcessor: Instance with loaded audio. """ - audio, sr = cls.load_audio(file, *args, **kwargs) - - audio = torch.from_numpy(audio) - return cls(audio, sr) - def cut(self, start: float, end: float) -> torch.Tensor: + def cut(self, start: float, end: float) -> np.ndarray: """ - Cut a segment from the audio waveform between the specified start and end times. + Cut a segment from the audio waveform. Args: start (float): Start time in seconds. end (float): End time in seconds. Returns: - torch.Tensor: The cut waveform segment. + np.ndarray: The cut waveform segment. """ - - start = int(start * self.sr) - if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int): - end = int(np.ceil(end * self.sr)) - else: - end = int(torch.ceil(end * self.sr)) - return self.waveform[start:end] + start_idx = int(start * self.sr) + end_idx = int(np.ceil(end * self.sr)) + return self.waveform[start_idx:end_idx] @staticmethod def load_audio(file: str, sr: int = SAMPLE_RATE): """ - Open an audio file and read it as a mono waveform, resampling if necessary. - This method ensures compatibility with pyannote.audio - and requires the ffmpeg CLI in PATH. + Load an audio file as a mono waveform, resampling if necessary. + Requires ffmpeg in PATH. Args: file (str): The audio file to open. - sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE. + sr (int, optional): The desired sample rate. Returns: - tuple: A NumPy array containing the audio waveform in float32 dtype - and the sample rate. + tuple: (waveform as np.ndarray[float32], sample rate) Raises: RuntimeError: If failed to load audio. """ - # This launches a subprocess to decode audio while down-mixing - # and resampling as necessary. Requires the ffmpeg CLI in PATH. - # fmt: off cmd = [ "ffmpeg", "-nostdin", @@ -128,19 +89,20 @@ class AudioProcessor: "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), - "-" + "-", ] - # fmt: on try: out = run(cmd, capture_output=True, check=True).stdout except CalledProcessError as e: raise RuntimeError( - f"Failed to load audio: {e.stderr.decode()}") from e + f"Failed to load audio: {e.stderr.decode()}" + ) from e - out = np.frombuffer(out, np.int16).flatten().astype( - np.float32) / NORMALIZATION_FACTOR + waveform = np.frombuffer(out, np.int16).flatten().astype( + np.float32 + ) / NORMALIZATION_FACTOR + + return waveform, sr - return out, sr - def __repr__(self) -> str: - return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' + return f"AudioProcessor(waveform_len={len(self.waveform)}, sr={self.sr})" diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 9023107..c895af2 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -1,358 +1,281 @@ """ -Scraibe Class --------------------- +Scraibe Class (LocalAI-backed) +------------------------------ -This class serves as the core of the transcription system, responsible for handling -transcription and diarization of audio files. It leverages pretrained models for -speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio), -providing an accessible interface for audio processing tasks such as transcription, -speaker separation, and timestamping. +Core class for transcription and (optionally) summarization. -By encapsulating the complexities of underlying models, it allows for straightforward -integration into various applications, ranging from transcription services to voice assistants. +- Transcription and diarization are delegated to LocalAI (vibevoice.cpp). +- Summarization is delegated to a separate LLM via /v1/chat/completions. -Available Classes: -- Scraibe: Main class for performing transcription and diarization. - Includes methods for loading models, processing audio files, - and formatting the transcription output. +Public tasks: +- transcribe +- transcript_and_summarize (transcribe + generate a detailed summary) -Usage: - from scraibe import Scraibe - - model = Scraibe() - transcript = model.autotranscribe("path/to/audiofile.wav") +Previous task/whisper/pyannote-specific settings are kept for compatibility +but ignored when not relevant. """ -# Standard Library Imports import os -from glob import iglob -from subprocess import run -from typing import TypeVar, Union -from warnings import warn +from typing import Union, Optional -# Third-Party Imports -import torch -from numpy import ndarray -from tqdm import trange - -# Application-Specific Imports -from .audio import AudioProcessor -from .diarisation import Diariser -from .transcriber import Transcriber, load_transcriber, whisper +from .localai_client import LocalAIClient, LocalAIError +from .summarizer import SummarizerClient, SummarizerError from .transcript_exporter import Transcript -from .misc import SCRAIBE_TORCH_DEVICE - - -DiarisationType = TypeVar('DiarisationType') class Scraibe: """ - Scraibe is a class responsible for managing the transcription and diarization of audio files. - It serves as the core of the transcription system, incorporating pretrained models - for speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio), - allowing for comprehensive audio processing. + Scraibe now: + - Uses LocalAI for transcription + diarization. + - Uses a separate LLM for summarization (when requested). - Attributes: - transcriber (Transcriber): The transcriber object to handle transcription. - diariser (Diariser): The diariser object to handle diarization. - - Methods: - __init__: Initializes the Scraibe class with appropriate models. - transcribe: Transcribes an audio file using the whisper model and pyannote diarization model. - remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. - get_audio_file: Gets an audio file as an AudioProcessor object. + Public methods: + - transcribe(audio_file, ...) + - transcript_and_summarize(audio_file, ...) """ - def __init__(self, - whisper_model: Union[bool, str, whisper] = None, - whisper_type: str = "whisper", - dia_model: Union[bool, str, DiarisationType] = None, - **kwargs) -> None: - """Initializes the Scraibe class. + def __init__( + self, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + model: Optional[str] = None, + whisper_model: Union[bool, str] = None, + whisper_type: str = "whisper", + dia_model: Union[bool, str] = None, + use_auth_token: str = None, + verbose: bool = False, + **kwargs, + ) -> None: + """ + Initialize Scraibe with LocalAI client and summarizer client. Args: - whisper_model (Union[bool, str, whisper], optional): - Path to whisper model or whisper model itself. - whisper_type (str): - Type of whisper model to load. "whisper" or "faster-whisper". - diarisation_model (Union[bool, str, DiarisationType], optional): - Path to pyannote diarization model or model itself. - **kwargs: Additional keyword arguments for whisper - and pyannote diarization models. - e.g.: + api_url: LocalAI server URL for transcription/diarization. + Falls back to LOCALAI_API_URL env var. + api_key: API key for LocalAI. Falls back to LOCALAI_API_KEY. + model: Model name for LocalAI (e.g., vibevoice-diarize). + Falls back to LOCALAI_MODEL env var. - - verbose: If True, the class will print additional information. - - save_kwargs: If True, the keyword arguments will be saved - for autotranscribe. So you can unload the class and reload it again. + Summarizer uses: + - SUMMARIZER_API_URL + - SUMMARIZER_API_KEY + - SUMMARIZER_MODEL + These can be overridden via environment or via the transcript_and_summarize + method if needed. + + Backward-compat (ignored): + - whisper_model, whisper_type, dia_model, use_auth_token, etc. """ + self.verbose = verbose or kwargs.get("verbose", False) - if whisper_model is None: - self.transcriber = load_transcriber( - "medium", whisper_type, **kwargs) - elif isinstance(whisper_model, str): - self.transcriber = load_transcriber( - whisper_model, whisper_type, **kwargs) - else: - self.transcriber = whisper_model + try: + self.client = LocalAIClient( + api_url=api_url, + api_key=api_key, + model=model, + ) + except LocalAIError as e: + raise LocalAIError(f"Failed to initialize LocalAI client: {e}") - if dia_model is None: - self.diariser = Diariser.load_model(**kwargs) - elif isinstance(dia_model, str): - self.diariser = Diariser.load_model(dia_model, **kwargs) - else: - self.diariser: Diariser = dia_model - - if kwargs.get("verbose"): - print("Scraibe initialized all models successfully loaded.") - self.verbose = True - else: - self.verbose = False - - # Save kwargs for autotranscribe if you want to unload the class and load it again. - if kwargs.get('save_setup'): - self.params = dict(whisper_model=whisper_model, - dia_model=dia_model, - **kwargs) - else: - self.params = {} - - self.device = kwargs.get( - "device", SCRAIBE_TORCH_DEVICE) - - def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], - remove_original: bool = False, - **kwargs) -> Transcript: - """ - Transcribes an audio file using the whisper model and pyannote diarization model. - - Args: - audio_file (Union[str, torch.Tensor, ndarray]): - Path to audio file or a tensor representing the audio. - remove_original (bool, optional): If True, the original audio file will - be removed after transcription. - *args: Additional positional arguments for diarization and transcription. - **kwargs: Additional keyword arguments for diarization and transcription. - - Returns: - Transcript: A Transcript object containing the transcription, - which can be exported to different formats. - """ - if kwargs.get("verbose"): - self.verbose = kwargs.get("verbose") - # Get audio file as an AudioProcessor object - audio_file: AudioProcessor = self.get_audio_file(audio_file) - - # Prepare waveform and sample rate for diarization - dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), - "sample_rate": audio_file.sr - } - - if self.verbose: - print("Starting diarisation.") - - diarisation = self.diariser.diarization(dia_audio, **kwargs) - - if not diarisation["segments"]: - print("No segments found. Try to run transcription without diarisation.") - - transcript = self.transcriber.transcribe( - audio_file.waveform, **kwargs) - - final_transcript = {0: {"speakers": 'SPEAKER_01', - "segments": [0, len(audio_file.waveform)], - "text": transcript}} - - return Transcript(final_transcript) + # Summarizer is lazy-initialized if needed + self._summarizer: Optional[SummarizerClient] = None if self.verbose: - print("Diarisation finished. Starting transcription.") + print("Scraibe initialized. Using LocalAI for transcription and diarization.") - - # Transcribe each segment and store the results - final_transcript = dict() - - for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose): - - seg = diarisation["segments"][i] - - audio = audio_file.cut(seg[0], seg[1]) - - transcript = self.transcriber.transcribe(audio, **kwargs) - - final_transcript[i] = {"speakers": diarisation["speakers"][i], - "segments": seg, - "text": transcript} - - # Remove original file if needed - if remove_original: - if kwargs.get("shred") is True: - self.remove_audio_file(audio_file, shred=True) - else: - self.remove_audio_file(audio_file, shred=False) - - return Transcript(final_transcript) - - def diarization(self, audio_file: Union[str, torch.Tensor, ndarray], - **kwargs) -> dict: + def _ensure_summarizer( + self, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + model: Optional[str] = None, + ) -> SummarizerClient: """ - Perform diarization on an audio file using the pyannote diarization model. + Lazy-init summarizer client. + """ + if self._summarizer is not None: + return self._summarizer + + try: + self._summarizer = SummarizerClient( + api_url=api_url, + api_key=api_key, + model=model, + ) + except SummarizerError as e: + raise SummarizerError(f"Failed to initialize Summarizer client: {e}") + + return self._summarizer + + # ----------------- + # Primary public API + # ----------------- + + def transcribe( + self, + audio_file: Union[str], + **kwargs, + ) -> str: + """ + Transcribe the provided audio file using LocalAI. + + Uses /v1/audio/diarization with vibevoice.cpp, then concatenates + all segment texts. Args: - audio_file (Union[str, torch.Tensor, ndarray]): - The audio source which can either be a path to the audio file or a tensor representation. - **kwargs: - Additional keyword arguments for diarization. + audio_file (str): Path to the audio file. + **kwargs: Additional keyword arguments (some forwarded, others ignored). Returns: - dict: - A dictionary containing the results of the diarization process. + str: The concatenated transcribed text. """ + if isinstance(audio_file, str): + if not os.path.exists(audio_file): + raise FileNotFoundError(f"Audio file not found: {audio_file}") + else: + raise TypeError( + "In LocalAI mode, audio_file must be a file path (str)." + ) - # Get audio file as an AudioProcessor object - audio_file: AudioProcessor = self.get_audio_file(audio_file) + verbose = kwargs.get("verbose", self.verbose) - # Prepare waveform and sample rate for diarization - dia_audio = { - "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device), - "sample_rate": audio_file.sr + try: + result = self.client.diarize_and_transcribe( + audio_path=audio_file, + include_text=True, + verbose=verbose, + **kwargs, + ) + except LocalAIError as e: + raise LocalAIError(f"Error during LocalAI transcription: {e}") + + transcripts = result.get("transcripts", []) + return " ".join(t.strip() for t in transcripts if t.strip()) + + def transcript_and_summarize( + self, + audio_file: Union[str], + *, + summarizer_api_url: Optional[str] = None, + summarizer_api_key: Optional[str] = None, + summarizer_model: Optional[str] = None, + **kwargs, + ) -> dict: + """ + Transcribe the audio file and generate a detailed summary. + + Steps: + - Transcribe via LocalAI. + - Build a plain-text transcript (with speaker labels). + - Summarize the transcript using the configured LLM. + + Returns: + dict with: + - transcript: full transcript text (with speaker labels) + - summary: final detailed summary (markdown-ready) + """ + if isinstance(audio_file, str): + if not os.path.exists(audio_file): + raise FileNotFoundError(f"Audio file not found: {audio_file}") + else: + raise TypeError( + "In LocalAI mode, audio_file must be a file path (str)." + ) + + verbose = kwargs.get("verbose", self.verbose) + + # 1) Get diarized + transcribed result + try: + result = self.client.diarize_and_transcribe( + audio_path=audio_file, + include_text=True, + verbose=verbose, + **kwargs, + ) + except LocalAIError as e: + raise LocalAIError(f"Error during LocalAI transcription: {e}") + + segments = result.get("segments", []) + speakers = result.get("speakers", []) + transcripts = result.get("transcripts", []) + + if not segments: + return { + "transcript": "", + "summary": "No transcript content to summarize.", + } + + # 2) Build full transcript text with speaker labels + lines = [] + for seg, speaker, text in zip(segments, speakers, transcripts): + start, end = seg + ts = self._format_timestamp(start) + line = f"[{ts}] {speaker}: {text.strip()}" + lines.append(line) + + full_transcript = "\n\n".join(lines) + + # 3) Summarize + try: + summarizer = self._ensure_summarizer( + api_url=summarizer_api_url, + api_key=summarizer_api_key, + model=summarizer_model, + ) + except SummarizerError as e: + raise SummarizerError(f"Failed to initialize summarizer: {e}") + + try: + summary = summarizer.summarize_transcript(full_transcript) + except SummarizerError as e: + raise SummarizerError(f"Error during summarization: {e}") + + return { + "transcript": full_transcript, + "summary": summary, } - print("Starting diarisation.") - - diarisation = self.diariser.diarization(dia_audio, **kwargs) - - return diarisation - - def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray], - **kwargs): - """ - Transcribe the provided audio file. - - Args: - audio_file (Union[str, torch.Tensor, ndarray]): - The audio source, which can either be a path or a tensor representation. - **kwargs: - Additional keyword arguments for transcription. - - Returns: - str: - The transcribed text from the audio source. - """ - audio_file: AudioProcessor = self.get_audio_file(audio_file) - - return self.transcriber.transcribe(audio_file.waveform, **kwargs) - - def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None: - """ - Update the transcriber model. - - Args: - whisper_model (Union[str, whisper]): - The new whisper model to use for transcription. - **kwargs: - Additional keyword arguments for the transcriber model. - - Returns: - None - """ - _old_model = self.transcriber.model_name - - if isinstance(whisper_model, str): - self.transcriber = load_transcriber(whisper_model, **kwargs) - elif isinstance(whisper_model, Transcriber): - self.transcriber = whisper_model - else: - warn( - f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) - - return None - - def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None: - """ - Update the diariser model. - - Args: - dia_model (Union[str, DiarisationType]): - The new diariser model to use for diarization. - **kwargs: - Additional keyword arguments for the diariser model. - - Returns: - None - """ - if isinstance(dia_model, str): - self.diariser = Diariser.load_model(dia_model, **kwargs) - elif isinstance(dia_model, Diariser): - self.diariser = dia_model - else: - warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) - - return None + # ----------------- + # Helpers + # ----------------- @staticmethod - def remove_audio_file(audio_file: str, - shred: bool = False) -> None: + def _format_timestamp(seconds: float) -> str: """ - Removes the original audio file to avoid disk space issues or ensure data privacy. + Format seconds into MM:SS or HH:MM:SS. + """ + m, s = divmod(int(seconds), 60) + h, m = divmod(m, 60) + if h > 0: + return f"{h:02d}:{m:02d}:{s:02d}" + return f"{m:02d}:{s:02d}" - Args: - audio_file_path (str): Path to the audio file. - shred (bool, optional): If True, the audio file will be shredded, - not just removed. + @staticmethod + def remove_audio_file(audio_file: str, shred: bool = False) -> None: + """ + Remove the original audio file. """ if not os.path.exists(audio_file): raise ValueError(f"Audiofile {audio_file} does not exist.") if shred: + import subprocess + import warnings + from glob import iglob - warn("Shredding audiofile can take a long time.", RuntimeWarning) + warnings.warn("Shredding audiofile can take a long time.", RuntimeWarning) - gen = iglob(f'{audio_file}', recursive=True) - cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}'] + gen = iglob(f"{audio_file}", recursive=True) + cmd = ["shred", "-zvu", "-n", "10", f"{audio_file}"] if os.path.isdir(audio_file): raise ValueError(f"Audiofile {audio_file} is a directory.") for file in gen: - print(f'shredding {file} now\n') - - run(cmd, check=True) - + print(f"shredding {file} now\n") + subprocess.run(cmd, check=True) else: os.remove(audio_file) print(f"Audiofile {audio_file} removed.") - @staticmethod - def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor: - """Gets an audio file as TorchAudioProcessor. - - Args: - audio_file (Union[str, torch.Tensor, ndarray]): Path to the audio file or - a tensor representing the audio. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - AudioProcessor: An object containing the waveform and sample rate in - torch.Tensor format. - """ - - if isinstance(audio_file, str): - audio_file = AudioProcessor.from_file(audio_file) - - elif isinstance(audio_file, torch.Tensor): - audio_file = AudioProcessor(audio_file[0], audio_file[1]) - elif isinstance(audio_file, ndarray): - audio_file = AudioProcessor(torch.Tensor(audio_file[0]), - audio_file[1]) - - if not isinstance(audio_file, AudioProcessor): - raise ValueError(f'Audiofile must be of type AudioProcessor,' - f'not {type(audio_file)}') - - return audio_file - def __repr__(self): - return f"Scraibe(transcriber={self.transcriber}, diariser={self.diariser})" + return "Scraibe(LocalAI-backed)" diff --git a/scraibe/cli.py b/scraibe/cli.py index e4eeaad..01b5659 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -1,25 +1,23 @@ """ Command-Line Interface (CLI) for the Scraibe class, -allowing for user interaction to transcribe and diarize audio files. +allowing for user interaction to transcribe and diarize audio files. The function includes arguments for specifying the audio files, model paths, output formats, and other options necessary for transcription. + +This version is adapted for LocalAI-based transcription and diarization. """ + import os import json from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter -from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE -from torch.cuda import is_available from .autotranscript import Scraibe from .misc import set_threads + def cli(): """ - Command-Line Interface (CLI) for the Scraibe class, allowing for user interaction to transcribe - and diarize audio files. The function includes arguments for specifying the audio files, model paths, - output formats, and other options necessary for transcription. - - This function can be executed from the command line to perform transcription tasks, providing a - user-friendly way to access the Scraibe class functionalities. + Command-Line Interface (CLI) for the Scraibe class, allowing for user interaction to transcribe + and diarize audio files via a LocalAI server. """ def str2bool(string): @@ -28,59 +26,160 @@ def cli(): return str2val[string] else: raise ValueError( - f"Expected one of {set(str2val.keys())}, got {string}") + f"Expected one of {set(str2val.keys())}, got {string}" + ) parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None, - help="List of audio files to transcribe.") + parser.add_argument( + "-f", + "--audio-files", + nargs="+", + type=str, + default=None, + help="List of audio files to transcribe.", + ) - parser.add_argument("--whisper-type", type=str, default="whisper", - choices=["whisper", "faster-whisper"], - help="Type of Whisper model to use ('whisper' or 'faster-whisper').") - - parser.add_argument("--whisper-model-name", default="medium", - help="Name of the Whisper model to use.") + # LocalAI connection (env vars preferred, but CLI overrides allowed) + parser.add_argument( + "--localai-api-url", + type=str, + default=None, + help="LocalAI server URL (e.g., http://localhost:8080). " + "Overrides LOCALAI_API_URL env var if provided.", + ) + parser.add_argument( + "--localai-api-key", + type=str, + default=None, + help="LocalAI API key. Overrides LOCALAI_API_KEY env var if provided.", + ) + parser.add_argument( + "--localai-model", + type=str, + default=None, + help="Model name to use on LocalAI (e.g., vibevoice-diarize). " + "Overrides LOCALAI_MODEL env var if provided.", + ) - parser.add_argument("--whisper-model-directory", type=str, default=None, - help="Path to save Whisper model files; defaults to ./models/whisper.") + # Summarizer overrides (env vars are primary) + parser.add_argument( + "--summarizer-api-url", + type=str, + default=None, + help="Summarization LLM API URL (e.g., http://localhost:8080). " + "Overrides SUMMARIZER_API_URL env var if provided.", + ) + parser.add_argument( + "--summarizer-api-key", + type=str, + default=None, + help="Summarization LLM API key. Overrides SUMMARIZER_API_KEY env var if provided.", + ) + parser.add_argument( + "--summarizer-model", + type=str, + default=None, + help="Model name for summarization. Overrides SUMMARIZER_MODEL env var if provided.", + ) - parser.add_argument("--diarization-directory", type=str, default=None, - help="Path to the diarization model directory.") + # Kept for backward compatibility with UI / existing scripts; ignored by LocalAI client. + parser.add_argument( + "--whisper-type", + type=str, + default="whisper", + choices=["whisper", "faster-whisper"], + help="[Backward compatibility] Type of Whisper model. Ignored when using LocalAI.", + ) - parser.add_argument("--hf-token", default=None, type=str, - help="HuggingFace token for private model download.") + parser.add_argument( + "--whisper-model-name", + default="medium", + help="[Backward compatibility] Whisper model name. Ignored when using LocalAI.", + ) - parser.add_argument("--inference-device", - default="cuda" if is_available() else "cpu", - help="Device to use for PyTorch inference.") + parser.add_argument( + "--whisper-model-directory", + type=str, + default=None, + help="[Backward compatibility] Whisper model directory. Ignored when using LocalAI.", + ) - parser.add_argument("--num-threads", type=int, default=None, - help="Number of threads used by torch for CPU inference; '\ - 'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") + parser.add_argument( + "--diarization-directory", + type=str, + default=None, + help="[Backward compatibility] Diarization model directory. Ignored when using LocalAI.", + ) - parser.add_argument("--output-directory", "-o", type=str, default=".", - help="Directory to save the transcription outputs.") + parser.add_argument( + "--hf-token", + default=None, + type=str, + help="[Backward compatibility] HuggingFace token. Ignored when using LocalAI.", + ) - parser.add_argument("--output-format", "-of", type=str, default="txt", - choices=["txt", "json", "md", "html"], - help="Format of the output file; defaults to txt.") + parser.add_argument( + "--inference-device", + default="cpu", + help="[Backward compatibility] Device for inference. Ignored when using LocalAI.", + ) - parser.add_argument("--verbose-output", type=str2bool, default=True, - help="Enable or disable progress and debug messages.") + parser.add_argument( + "--num-threads", + type=int, + default=None, + help="Number of threads used for CPU operations; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.", + ) - parser.add_argument("--task", type=str, default='autotranscribe', - choices=["autotranscribe", "diarization", - "autotranscribe+translate", "translate", 'transcribe'], - help="Choose to perform transcription, diarization, or translation. \ - If set to translate, the output will be translated to English.") + parser.add_argument( + "--output-directory", + "-o", + type=str, + default=".", + help="Directory to save the transcription outputs.", + ) - parser.add_argument("--language", type=str, default=None, - choices=sorted( - LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), - help="Language spoken in the audio. Specify None to perform language detection.") - parser.add_argument("--num-speakers", type=int, default=2, - help="Number of speakers in the audio.") + parser.add_argument( + "--output-format", + "-of", + type=str, + default="txt", + choices=["txt", "json", "md", "html"], + help="Format of the output file; defaults to txt.", + ) + + parser.add_argument( + "--verbose-output", + type=str2bool, + default=True, + help="Enable or disable progress and debug messages.", + ) + + parser.add_argument( + "--task", + type=str, + default="transcribe", + choices=[ + "transcribe", + "transcript_and_summarize", + ], + help="Task to perform: 'transcribe' or 'transcript_and_summarize'.", + ) + + parser.add_argument( + "--language", + type=str, + default=None, + help="Language spoken in the audio. Specify None to perform language detection.", + ) + + parser.add_argument( + "--num-speakers", + type=int, + default=None, + help="Number of speakers in the audio.", + ) args = parser.parse_args() @@ -96,65 +195,64 @@ def cli(): set_threads(arg_dict.pop("num_threads")) - class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), - 'whisper_type':arg_dict.pop("whisper_type"), - 'dia_model': arg_dict.pop("diarization_directory"), - 'use_auth_token': arg_dict.pop("hf_token"), - } - - if arg_dict["whisper_model_directory"]: - class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") - + # Build kwargs for Scraibe (LocalAI-backed) + class_kwargs = { + "api_url": arg_dict.pop("localai_api_url"), + "api_key": arg_dict.pop("localai_api_key"), + "model": arg_dict.pop("localai_model"), + # kept for backward compatibility, but ignored: + "whisper_model": arg_dict.pop("whisper_model_name"), + "whisper_type": arg_dict.pop("whisper_type"), + "dia_model": arg_dict.pop("diarization_directory"), + "use_auth_token": arg_dict.pop("hf_token"), + "verbose": arg_dict.pop("verbose_output"), + } model = Scraibe(**class_kwargs) if arg_dict["audio_files"]: audio_files = arg_dict.pop("audio_files") - if task == "autotranscribe" or task == "autotranscribe+translate": + if task == "transcribe": for audio in audio_files: - if task == "autotranscribe+translate": - task = "translate" - else: - task = "transcribe" - - out = model.autotranscribe( - audio, - task=task, - language=arg_dict.pop("language"), - verbose=arg_dict.pop("verbose_output"), - num_speakers=arg_dict.pop("num_speakers") - ) - basename = audio.split("/")[-1].split(".")[0] - print(f'Saving {basename}.{out_format} to {out_folder}') - out.save(os.path.join( - out_folder, f"{basename}.{out_format}")) - - elif task == "diarization": - for audio in audio_files: - if arg_dict.pop("verbose_output"): - print("Verbose not implemented for diarization.") - - out = model.diarization(audio) + out = model.transcribe( + audio, + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output"), + num_speakers=arg_dict.pop("num_speakers"), + ) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") - - print(f'Saving {basename}.{out_format} to {out_folder}') - - with open(path, "w") as f: - json.dump(json.dumps(out, indent=1), f) - - elif task == "transcribe" or task == "translate": - - for audio in audio_files: - - out = model.transcribe(audio, task=task, - language=arg_dict.pop("language"), - verbose=arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") - with open(path, "w") as f: + print(f"Saving {basename}.{out_format} to {out_folder}") + with open(path, "w", encoding="utf-8") as f: f.write(out) + elif task == "transcript_and_summarize": + for audio in audio_files: + result = model.transcript_and_summarize( + audio, + summarizer_api_url=arg_dict.pop("summarizer_api_url"), + summarizer_api_key=arg_dict.pop("summarizer_api_key"), + summarizer_model=arg_dict.pop("summarizer_model"), + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output"), + num_speakers=arg_dict.pop("num_speakers"), + ) + + transcript_text = result.get("transcript", "") + summary_text = result.get("summary", "") + + basename = audio.split("/")[-1].split(".")[0] + + # Always use .md for transcript_and_summarize + md_path = os.path.join(out_folder, f"{basename}.md") + print(f"Saving {basename}.md (transcript + summary) to {out_folder}") + + with open(md_path, "w", encoding="utf-8") as f: + f.write("# Transcript\n\n") + f.write(transcript_text) + f.write("\n\n# Summary\n\n") + f.write(summary_text) + if __name__ == "__main__": cli() diff --git a/scraibe/localai_client.py b/scraibe/localai_client.py new file mode 100644 index 0000000..7f6d24c --- /dev/null +++ b/scraibe/localai_client.py @@ -0,0 +1,237 @@ +""" +LocalAI Client Module +--------------------- + +This module provides a client for communicating with a LocalAI server +running vibevoice.cpp for transcription and speaker diarization. + +It replaces the previous local Whisper + Pyannote pipeline by sending +audio files to the /v1/audio/diarization endpoint and mapping the +response into the same Transcript format used by the UI. + +Environment Variables: + LOCALAI_API_URL: (required) Base URL of the LocalAI server + (e.g., http://localhost:8080) + LOCALAI_API_KEY: (optional) API key, if configured + LOCALAI_MODEL: (optional) Model name to use (default: vibevoice-diarize) +""" + +import os +import io +import json +from typing import Dict, List, Any, Optional + +import httpx + + +class LocalAIError(Exception): + """Raised when the LocalAI API returns an error or unexpected response.""" + pass + + +class LocalAIClient: + """ + Thin HTTP client for LocalAI /v1/audio/diarization with vibevoice.cpp. + + Responsibilities: + - Read configuration from environment. + - Upload audio file as multipart/form-data. + - Parse diarization + transcription response. + - Map response into the same structure expected by Scraibe's Transcript. + """ + + def __init__( + self, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + model: Optional[str] = None, + timeout: float = 600.0, + ): + """ + Args: + api_url: LocalAI server URL (e.g., http://localhost:8080). + Falls back to LOCALAI_API_URL env var. + api_key: API key, if required. Falls back to LOCALAI_API_KEY. + model: Model name (e.g., vibevoice-diarize). + Falls back to LOCALAI_MODEL or default. + timeout: Request timeout in seconds. + """ + self.api_url = (api_url or os.getenv("LOCALAI_API_URL")).strip().rstrip("/") + self.api_key = api_key or os.getenv("LOCALAI_API_KEY") or None + self.model = model or os.getenv("LOCALAI_MODEL") or "vibevoice-diarize" + self.timeout = timeout + + if not self.api_url: + raise LocalAIError( + "LOCALAI_API_URL is not set. " + "Provide the LocalAI server URL via environment or constructor." + ) + + self._client = httpx.Client( + base_url=self.api_url, + timeout=self.timeout, + follow_redirects=True, + ) + + def close(self): + """Close the underlying HTTP client.""" + self._client.close() + + def __del__(self): + try: + self._client.close() + except Exception: + pass + + def diarize_and_transcribe( + self, + audio_path: str, + *, + language: Optional[str] = None, + num_speakers: Optional[int] = None, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, + clustering_threshold: Optional[float] = None, + min_duration_on: Optional[float] = None, + min_duration_off: Optional[float] = None, + response_format: Optional[str] = None, + include_text: Optional[bool] = None, + verbose: bool = False, + **_ignored, + ) -> Dict[str, Any]: + """ + Send audio to LocalAI /v1/audio/diarization and return a dict + in the same style as the previous internal diarization output: + + { + "segments": [ [start, end], ... ], + "speakers": [ "SPEAKER_00", ... ], + "transcripts": [ "text for segment", ... ] + } + + Extra kwargs that the old UI used (e.g., whisper-specific) are + accepted but ignored. + + Args: + audio_path: Path to the audio file. + language: Language hint, forwarded if set. + num_speakers: Optional exact speaker count. + min_speakers: Optional hint. + max_speakers: Optional hint. + clustering_threshold: Optional clustering threshold. + min_duration_on: Optional min segment duration. + min_duration_off: Optional min gap duration. + response_format: "json", "verbose_json", or "rttm". + Defaults to "verbose_json" if not set. + include_text: Whether to request per-segment text. + Defaults to True. + verbose: If True, prints progress messages. + """ + if verbose: + print("Starting diarization and transcription via LocalAI.") + + # Defaults: use verbose_json + include_text to get both diarization and transcription. + if response_format is None: + response_format = "verbose_json" + if include_text is None: + include_text = True + + # Prepare form data + data = { + "model": self.model, + "response_format": response_format, + "include_text": str(include_text).lower(), + } + + if language is not None: + data["language"] = language + if num_speakers is not None: + data["num_speakers"] = str(num_speakers) + if min_speakers is not None: + data["min_speakers"] = str(min_speakers) + if max_speakers is not None: + data["max_speakers"] = str(max_speakers) + if clustering_threshold is not None: + data["clustering_threshold"] = str(clustering_threshold) + if min_duration_on is not None: + data["min_duration_on"] = str(min_duration_on) + if min_duration_off is not None: + data["min_duration_off"] = str(min_duration_off) + + # Open file + if not os.path.exists(audio_path): + raise LocalAIError(f"Audio file not found: {audio_path}") + + with open(audio_path, "rb") as f: + files = { + "file": (os.path.basename(audio_path), f, "application/octet-stream") + } + + headers = {} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + # POST /v1/audio/diarization + resp = self._client.post( + "/v1/audio/diarization", + data=data, + files=files, + headers=headers, + ) + + if resp.status_code >= 400: + body = resp.text + raise LocalAIError( + f"LocalAI request failed with status {resp.status_code}: {body}" + ) + + try: + result = resp.json() + except json.JSONDecodeError: + raise LocalAIError( + "Failed to parse LocalAI response as JSON." + ) + + if verbose: + print("Diarization and transcription finished. Starting post-processing.") + + return self._parse_diarization_response(result) + + def _parse_diarization_response(self, result: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert LocalAI response into the internal format used by Scraibe: + { + "segments": [ [start, end], ... ], + "speakers": [ "SPEAKER_00", ... ], + "transcripts": [ "text for segment", ... ] + } + """ + segments = result.get("segments", []) + + if not segments: + # If no segments, return empty but valid structure + return { + "segments": [], + "speakers": [], + "transcripts": [], + } + + out_segments = [] + out_speakers = [] + out_transcripts = [] + + for seg in segments: + start = float(seg.get("start", 0.0)) + end = float(seg.get("end", 0.0)) + speaker = seg.get("speaker", "SPEAKER_00") + text = seg.get("text", "").strip() + + out_segments.append([start, end]) + out_speakers.append(speaker) + out_transcripts.append(text) + + return { + "segments": out_segments, + "speakers": out_speakers, + "transcripts": out_transcripts, + } diff --git a/scraibe/misc.py b/scraibe/misc.py index f5d2bfe..3857a1c 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -1,77 +1,52 @@ import os -import yaml from argparse import Action from ast import literal_eval -from torch.cuda import is_available -from torch import get_num_threads, set_num_threads CACHE_DIR = os.getenv( "AUTOT_CACHE", os.path.expanduser("~/.cache/torch/models"), ) -os.environ["PYANNOTE_CACHE"] = os.getenv( - "PYANNOTE_CACHE", - os.path.join(CACHE_DIR, "pyannote"), -) +# Legacy paths kept for backward compatibility (ignored by LocalAI client) WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") -PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ - if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') +PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") -SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu") -SCRAIBE_NUM_THREADS = os.getenv("SCRAIBE_NUM_THREADS", min(8, get_num_threads())) - -def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: - """Configure diarization pipeline from a YAML file. - - This function updates the YAML file to use the given segmentation model - offline, and avoids manual file manipulation. - - Args: - file_path (str): Path to the YAML file. - path_to_segmentation (str, optional): Optional path to the segmentation model. - - Raises: - FileNotFoundError: If the segmentation model file is not found. +def set_threads(parse_threads=None, yaml_threads=None): """ - with open(file_path, "r") as stream: - yml = yaml.safe_load(stream) + Configure number of threads. - segmentation_path = path_to_segmentation or os.path.join( - PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") - yml["pipeline"]["params"]["segmentation"] = segmentation_path - - if not os.path.exists(segmentation_path): - raise FileNotFoundError( - f"Segmentation model not found at {segmentation_path}") - - with open(file_path, "w") as stream: - yaml.dump(yml, stream) - - -def set_threads(parse_threads=None, - yaml_threads=None): - global SCRAIBE_NUM_THREADS + In LocalAI mode, this is mainly kept for backward compatibility. + """ + chosen = None if parse_threads is not None: if not isinstance(parse_threads, int): - # probably covered with int type of parser arg - raise ValueError(f"Type of --num-threads must be int, but the type is {type(parse_threads)}") + raise ValueError( + f"Type of --num-threads must be int, but the type is {type(parse_threads)}" + ) elif parse_threads < 1: - raise ValueError(f"Number of threads must be a positive integer, {parse_threads} was given") + raise ValueError( + f"Number of threads must be a positive integer, {parse_threads} was given" + ) else: - set_num_threads(parse_threads) - SCRAIBE_NUM_THREADS = parse_threads + chosen = parse_threads elif yaml_threads is not None: if not isinstance(yaml_threads, int): - raise ValueError(f"Type of num_threads must be int, but the type is {type(yaml_threads)}") + raise ValueError( + f"Type of num_threads must be int, but the type is {type(yaml_threads)}" + ) elif yaml_threads < 1: - raise ValueError(f"Number of threads must be a positive integer, {yaml_threads} was given") + raise ValueError( + f"Number of threads must be a positive integer, {yaml_threads} was given" + ) else: - set_num_threads(yaml_threads) - SCRAIBE_NUM_THREADS = yaml_threads + chosen = yaml_threads + + if chosen is not None: + os.environ["OMP_NUM_THREADS"] = str(chosen) + os.environ["MKL_NUM_THREADS"] = str(chosen) + class ParseKwargs(Action): """ @@ -81,7 +56,7 @@ class ParseKwargs(Action): def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for value in values: - key, value = value.split('=') + key, value = value.split("=") try: value = literal_eval(value) except: diff --git a/scraibe/summarizer.py b/scraibe/summarizer.py new file mode 100644 index 0000000..cd52cb3 --- /dev/null +++ b/scraibe/summarizer.py @@ -0,0 +1,212 @@ +""" +Summarizer Module +----------------- + +Provides a client to summarize long transcripts via an LLM endpoint. + +Behavior: +- Chunks transcript into 10,240-character segments. +- Generates a summary for each chunk. +- Combines all chunk summaries and produces a final, detailed summary. + +Environment Variables: +- SUMMARIZER_API_URL: (required) Base URL of the LLM API (e.g., http://localhost:8080) +- SUMMARIZER_API_KEY: (optional) API key, if required +- SUMMARIZER_MODEL: (optional) Model name (e.g., llama-3.1-8b-instruct) +""" + +import os +import json +from typing import Optional + +import httpx + + +class SummarizerError(Exception): + """Raised when the summarization API call fails.""" + pass + + +class SummarizerClient: + """ + HTTP client for an OpenAI-compatible chat completions endpoint. + Used to summarize long transcripts in chunks. + """ + + CHUNK_SIZE = 10_240 # characters per chunk + + def __init__( + self, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + model: Optional[str] = None, + timeout: float = 600.0, + ): + self.api_url = (api_url or os.getenv("SUMMARIZER_API_URL")).strip().rstrip("/") + self.api_key = api_key or os.getenv("SUMMARIZER_API_KEY") or None + self.model = model or os.getenv("SUMMARIZER_MODEL") or "llama-3.1-8b-instruct" + self.timeout = timeout + + if not self.api_url: + raise SummarizerError( + "SUMMARIZER_API_URL is not set. " + "Provide the summarization LLM URL via environment or constructor." + ) + + self._client = httpx.Client( + base_url=self.api_url, + timeout=self.timeout, + follow_redirects=True, + ) + + def close(self): + self._client.close() + + def __del__(self): + try: + self._client.close() + except Exception: + pass + + def summarize_transcript(self, transcript: str) -> str: + """ + Summarize a (possibly very long) transcript. + + Strategy: + - Split transcript into chunks of CHUNK_SIZE characters. + - Generate a detailed summary for each chunk. + - Combine all chunk summaries and generate a final, concise but thorough summary. + + The final summary should make it clear: + - What was discussed + - Main issues + - Outcomes / decisions + - Next steps / action items + """ + if not transcript.strip(): + return "No transcript provided to summarize." + + # 1) Chunk the transcript + chunks = self._chunk_text(transcript) + + # 2) Summarize each chunk + chunk_summaries = [] + for i, chunk in enumerate(chunks): + summary = self._summarize_chunk(chunk, i, len(chunks)) + chunk_summaries.append(summary) + + # 3) Combine and summarize summaries + combined = "\n\n".join(chunk_summaries) + final_summary = self._summarize_combined(combined) + + return final_summary + + def _chunk_text(self, text: str) -> list[str]: + """Split text into chunks of CHUNK_SIZE characters.""" + chunks = [] + start = 0 + while start < len(text): + end = start + self.CHUNK_SIZE + if end >= len(text): + chunks.append(text[start:]) + break + # Try to break at a reasonable boundary (newline or space) + break_pos = text.rfind("\n", start, end) + if break_pos == -1: + break_pos = text.rfind(" ", start, end) + if break_pos == -1 or break_pos <= start: + break_pos = end + chunks.append(text[start:break_pos].strip()) + start = break_pos + return chunks + + def _summarize_chunk(self, chunk: str, index: int, total: int) -> str: + system_prompt = ( + "You are an expert legal and business meeting summarizer. " + "You will receive a segment of a longer transcript. " + "Provide a detailed, structured summary of this segment, focusing on: " + "- Topics discussed\n" + "- Key points and arguments\n" + "- Decisions and agreements\n" + "- Action items and responsibilities\n" + "- Any risks, conflicts, or open issues\n\n" + "Be concise but complete. Use bullet points when helpful. " + "Do not add information that is not present in the transcript." + ) + + user_prompt = ( + f"This is segment {index + 1} of {total} from a longer conversation.\n\n" + f"{chunk}" + ) + + return self._chat_completion(system_prompt, user_prompt) + + def _summarize_combined(self, combined_summaries: str) -> str: + system_prompt = ( + "You are an expert legal and business meeting summarizer. " + "You will receive several intermediate summaries of a longer conversation. " + "Produce a single, comprehensive summary that makes it clear: " + "- The overall purpose and context of the discussion\n" + "- The main issues and topics addressed\n" + "- Key arguments and positions (briefly)\n" + "- Decisions and outcomes\n" + "- Action items, responsibilities, and next steps\n" + "- Any unresolved issues or risks\n\n" + "The summary should be detailed enough that a reader who was not present " + "can understand what happened and what is expected going forward. " + "Use clear, concise language and bullet points where appropriate." + ) + + user_prompt = ( + "Here are the intermediate summaries from different parts of the same conversation:\n\n" + f"{combined_summaries}" + ) + + return self._chat_completion(system_prompt, user_prompt) + + def _chat_completion(self, system_prompt: str, user_prompt: str) -> str: + """ + Call OpenAI-compatible /v1/chat/completions endpoint. + """ + payload = { + "model": self.model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "temperature": 0.3, + } + + headers = { + "Content-Type": "application/json", + } + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + resp = self._client.post( + "/v1/chat/completions", + json=payload, + headers=headers, + ) + + if resp.status_code >= 400: + raise SummarizerError( + f"Summarizer API error {resp.status_code}: {resp.text}" + ) + + try: + data = resp.json() + except json.JSONDecodeError: + raise SummarizerError( + "Failed to parse summarizer response as JSON." + ) + + # Extract assistant message + try: + content = data["choices"][0]["message"]["content"] + return content.strip() + except (KeyError, IndexError, TypeError): + raise SummarizerError( + "Unexpected summarizer response format: " + f"{json.dumps(data, indent=2)}" + )