Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 574124558b | |||
| 46d119b63b | |||
| d00ec2d44f | |||
| de883bc062 | |||
| 663675c7b2 | |||
| e5d189fdd0 | |||
| 9528468ebb | |||
| 3fe13803b9 | |||
| 81fefd5568 | |||
| fec46aa563 | |||
| 3851311ffc | |||
| f0989a574b | |||
| de9071762e | |||
| 08f14883e2 | |||
| 101e913f84 | |||
| e7c1a5a2b0 | |||
| af99a655e5 | |||
| 44ff678e06 | |||
| 8813662d4d | |||
| 6fadf3d851 | |||
| ce2f3ebde2 | |||
| fa1dad69d1 | |||
| 575a8de48d | |||
| a4b8546033 | |||
| 81fb9af461 | |||
| 6326d0f156 | |||
| 5f6f681edf | |||
| 2adbfaef51 | |||
| 9df05033da | |||
| df9c5109f3 | |||
| ab7b43ac48 | |||
| 929f916077 | |||
| 51bf211d27 | |||
| 5c0386edac | |||
| 18666adda4 | |||
| 9ce47ac4c2 | |||
| 95c145c74a | |||
| 4e7b7e748b | |||
| ae1bae750f | |||
| 885d0c864e | |||
| de9c81b313 | |||
| 5b56b54da2 | |||
| 53e57a06d7 | |||
| 129f0ce390 | |||
| d25fda5802 | |||
| 533b199f4c | |||
| cf63ac8e2e | |||
| 0ddb52cc95 | |||
| f5ef26432b | |||
| ba058c3e02 |
@@ -0,0 +1,95 @@
|
|||||||
|
# This workflow uses actions that are not certified by GitHub.
|
||||||
|
# They are provided by a third-party and are governed by
|
||||||
|
# separate terms of service, privacy policy, and support
|
||||||
|
# documentation.
|
||||||
|
|
||||||
|
# GitHub recommends pinning actions to a commit SHA.
|
||||||
|
# To get a newer version, you will need to update the SHA.
|
||||||
|
# You can also reference a tag or branch, but the action may change without warning.
|
||||||
|
|
||||||
|
name: Publish Docker image
|
||||||
|
|
||||||
|
on:
|
||||||
|
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
env:
|
||||||
|
image: hadr0n/scraibe
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
push_to_registry:
|
||||||
|
name: Push Docker image to Docker Hub
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
packages: write
|
||||||
|
contents: read
|
||||||
|
security-events: write
|
||||||
|
steps:
|
||||||
|
- name: Check out the repo
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-tags: true
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Get Version Tag
|
||||||
|
id: version
|
||||||
|
run: |
|
||||||
|
echo "tag=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Overwrite label tag
|
||||||
|
run: sed -i 's/LABEL version=".*"/LABEL version="'${{ steps.version.outputs.tag }}'"/' Dockerfile
|
||||||
|
|
||||||
|
- name: Test name and tag
|
||||||
|
run: |
|
||||||
|
echo "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}"
|
||||||
|
|
||||||
|
- name: Log in to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Build and push Docker image
|
||||||
|
id: push
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}"
|
||||||
|
|
||||||
|
- name: SBOM Generation
|
||||||
|
uses: anchore/sbom-action@v0
|
||||||
|
with:
|
||||||
|
image: ${{ env.image }}:latest
|
||||||
|
|
||||||
|
- name: Scan image
|
||||||
|
id: scan
|
||||||
|
uses: anchore/scan-action@v3
|
||||||
|
with:
|
||||||
|
image: ${{ env.image }}:latest
|
||||||
|
fail-build: false
|
||||||
|
|
||||||
|
- name: upload Anchore scan SARIF report
|
||||||
|
uses: github/codeql-action/upload-sarif@v3
|
||||||
|
with:
|
||||||
|
sarif_file: ${{ steps.scan.outputs.sarif }}
|
||||||
|
|
||||||
|
# - name: Inspect action SARIF report
|
||||||
|
# run: cat ${{ steps.scan.outputs.sarif }}
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: SARIF report
|
||||||
|
path: ${{ steps.scan.outputs.sarif }}
|
||||||
|
|
||||||
|
# - name: Generate artifact attestation
|
||||||
|
# uses: actions/attest-build-provenance@v1
|
||||||
|
# with:
|
||||||
|
# subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
|
||||||
|
# subject-digest: ${{ steps.push.outputs.digest }}
|
||||||
|
# push-to-registry: false
|
||||||
@@ -1,18 +1,14 @@
|
|||||||
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
|
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request_target:
|
|
||||||
branches:
|
|
||||||
- develop
|
|
||||||
types:
|
|
||||||
- closed
|
|
||||||
paths:
|
|
||||||
- scraibe/**
|
|
||||||
- pyproject.toml
|
|
||||||
|
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- 'v*.*.*'
|
- 'v*.*.*'
|
||||||
|
branches:
|
||||||
|
- "develop"
|
||||||
|
paths:
|
||||||
|
- "scraibe/**"
|
||||||
|
- "pyproject.toml"
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
@@ -27,13 +23,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
Build-and-publish-to-Test-PyPI:
|
Build-and-publish-to-Test-PyPI:
|
||||||
if: |
|
if: github.event_name != 'workflow_dispatch' || github.event.inputs.test == 'true'
|
||||||
(github.event_name == 'workflow_dispatch' &&
|
|
||||||
github.event.inputs.test == 'true') ||
|
|
||||||
(github.event_name == 'pull_request_target' &&
|
|
||||||
github.event.pull_request.merged &&
|
|
||||||
contains(github.event.pull_request.labels.*.name, 'release')) ||
|
|
||||||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -74,26 +64,14 @@ jobs:
|
|||||||
if: |
|
if: |
|
||||||
always() &&
|
always() &&
|
||||||
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
|
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
|
||||||
needs.Test-PyPi-install.result != 'failure' ) &&
|
needs.Test-PyPi-install.result != 'failure' ) ||
|
||||||
((github.event_name == 'workflow_dispatch' &&
|
((github.event_name == 'workflow_dispatch' &&
|
||||||
github.event.inputs.publish_to_pypi == 'true') ||
|
github.event.inputs.publish_to_pypi == 'true')))
|
||||||
(github.event_name == 'pull_request_target' &&
|
|
||||||
github.event.pull_request.merged &&
|
|
||||||
contains(github.event.pull_request.labels.*.name, 'release')) ||
|
|
||||||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))))
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Repository Tags
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
if: github.ref == 'refs/heads/main'
|
|
||||||
with:
|
|
||||||
fetch-depth: '0'
|
|
||||||
branch: 'main'
|
|
||||||
- name: Checkout Repository (Develop)
|
- name: Checkout Repository (Develop)
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
if: github.ref == 'refs/heads/develop'
|
|
||||||
with:
|
with:
|
||||||
fetch-depth: '0'
|
fetch-depth: '0'
|
||||||
branch: 'develop'
|
|
||||||
- name: Set up Poetry 📦
|
- name: Set up Poetry 📦
|
||||||
uses: JRubics/poetry-publish@v1.16
|
uses: JRubics/poetry-publish@v1.16
|
||||||
with:
|
with:
|
||||||
|
|||||||
+30
-33
@@ -1,46 +1,43 @@
|
|||||||
#pytorch Image
|
# Lightweight Python base image (no GPU/PyTorch needed)
|
||||||
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
|
FROM python:3.11-slim
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
|
|
||||||
LABEL maintainer="Jacob Schmieder"
|
LABEL maintainer="Jacob Schmieder"
|
||||||
LABEL email="Jacob.Schmieder@dbfz.de"
|
LABEL email="Jacob.Schmieder@dbfz.de"
|
||||||
LABEL version="0.1.1.dev"
|
LABEL version="0.1.1.dev"
|
||||||
LABEL description="Scraibe is a tool for automatic speech recognition and speaker diarization. \
|
LABEL description="Scraibe: LocalAI-backed transcription and diarization client with summarization. \
|
||||||
It is based on the Hugging Face Transformers library and the Pyannote library. \
|
Sends audio to a LocalAI server running vibevoice.cpp and uses a second LLM for summarization."
|
||||||
It is designed to be used with the Whisper model, a lightweight model for automatic \
|
|
||||||
speech recognition and speaker diarization."
|
|
||||||
LABEL url="https://github.com/JSchmie/ScrAIbe"
|
LABEL url="https://github.com/JSchmie/ScrAIbe"
|
||||||
|
|
||||||
# Install dependencies
|
# Install system dependencies (ffmpeg required)
|
||||||
|
RUN apt update -y && \
|
||||||
|
apt install -y --no-install-recommends ffmpeg && \
|
||||||
|
apt clean && \
|
||||||
|
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||||
|
|
||||||
|
# Working directory
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
ARG model_name=medium
|
|
||||||
#Enviorment Dependncies
|
# Environment variables for LocalAI (transcription/diarization)
|
||||||
ENV TRANSFORMERS_CACHE /app/models
|
# Set these via docker run -e or docker-compose
|
||||||
ENV HF_HOME /app/models
|
ENV LOCALAI_API_URL=http://localhost:8080
|
||||||
ENV AUTOT_CACHE /app/models
|
ENV LOCALAI_API_KEY=
|
||||||
ENV PYANNOTE_CACHE /app/models/pyannote
|
ENV LOCALAI_MODEL=vibevoice-diarize
|
||||||
#Copy all necessary files
|
|
||||||
|
# Environment variables for Summarizer LLM
|
||||||
|
ENV SUMMARIZER_API_URL=http://localhost:8080
|
||||||
|
ENV SUMMARIZER_API_KEY=
|
||||||
|
ENV SUMMARIZER_MODEL=llama-3.1-8b-instruct
|
||||||
|
|
||||||
|
# Copy and install Python dependencies
|
||||||
COPY requirements.txt /app/requirements.txt
|
COPY requirements.txt /app/requirements.txt
|
||||||
COPY README.md /app/README.md
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
COPY models /app/models
|
|
||||||
|
# Copy application code
|
||||||
COPY scraibe /app/scraibe
|
COPY scraibe /app/scraibe
|
||||||
COPY setup.py /app/setup.py
|
|
||||||
|
|
||||||
#Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token
|
# Expose port (if UI is served)
|
||||||
RUN apt update && apt-get install -y libsm6 libxrender1 libfontconfig1
|
|
||||||
RUN conda update --all
|
|
||||||
|
|
||||||
RUN conda install pip
|
|
||||||
RUN conda install -y ffmpeg
|
|
||||||
RUN conda install -c conda-forge libsndfile
|
|
||||||
RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
|
|
||||||
RUN pip install -r requirements.txt
|
|
||||||
RUN pip install markupsafe==2.0.1 --force-reinstall
|
|
||||||
|
|
||||||
RUN python3 -m 'scraibe.cli' --whisper-model-name $model_name
|
|
||||||
# Expose port
|
|
||||||
EXPOSE 7860
|
EXPOSE 7860
|
||||||
# Run the application
|
|
||||||
|
|
||||||
ENTRYPOINT ["python3", "-m", "scraibe.cli" ,"--whisper-model-name", "$model_name"]
|
# Run the application
|
||||||
|
ENTRYPOINT ["python3", "-m", "scraibe.cli"]
|
||||||
|
|||||||
-256
@@ -1,256 +0,0 @@
|
|||||||
channels:
|
|
||||||
- pytorch
|
|
||||||
- defaults
|
|
||||||
dependencies:
|
|
||||||
- _libgcc_mutex=0.1=main
|
|
||||||
- _openmp_mutex=5.1=1_gnu
|
|
||||||
- blas=1.0=mkl
|
|
||||||
- brotlipy=0.7.0=py39h27cfd23_1003
|
|
||||||
- bzip2=1.0.8=h7b6447c_0
|
|
||||||
- ca-certificates=2023.05.30=h06a4308_0
|
|
||||||
- certifi=2023.5.7=py39h06a4308_0
|
|
||||||
- cffi=1.15.1=py39h5eee18b_3
|
|
||||||
- cryptography=39.0.1=py39h9ce1e76_2
|
|
||||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
|
||||||
- ffmpeg=4.2.2=h20bf706_0
|
|
||||||
- flit-core=3.8.0=py39h06a4308_0
|
|
||||||
- freetype=2.12.1=h4a9f257_0
|
|
||||||
- giflib=5.2.1=h5eee18b_3
|
|
||||||
- gmp=6.2.1=h295c915_3
|
|
||||||
- gnutls=3.6.15=he1e5248_0
|
|
||||||
- idna=3.4=py39h06a4308_0
|
|
||||||
- intel-openmp=2021.4.0=h06a4308_3561
|
|
||||||
- jpeg=9e=h5eee18b_1
|
|
||||||
- lame=3.100=h7b6447c_0
|
|
||||||
- lcms2=2.12=h3be6417_0
|
|
||||||
- ld_impl_linux-64=2.38=h1181459_1
|
|
||||||
- lerc=3.0=h295c915_0
|
|
||||||
- libdeflate=1.17=h5eee18b_0
|
|
||||||
- libffi=3.4.2=h6a678d5_6
|
|
||||||
- libgcc-ng=11.2.0=h1234567_1
|
|
||||||
- libgomp=11.2.0=h1234567_1
|
|
||||||
- libidn2=2.3.2=h7f8727e_0
|
|
||||||
- libopus=1.3.1=h7b6447c_0
|
|
||||||
- libpng=1.6.39=h5eee18b_0
|
|
||||||
- libstdcxx-ng=11.2.0=h1234567_1
|
|
||||||
- libtasn1=4.16.0=h27cfd23_0
|
|
||||||
- libtiff=4.5.0=h6a678d5_2
|
|
||||||
- libunistring=0.9.10=h27cfd23_0
|
|
||||||
- libuv=1.44.2=h5eee18b_0
|
|
||||||
- libvpx=1.7.0=h439df22_0
|
|
||||||
- libwebp=1.2.4=h11a3e52_1
|
|
||||||
- libwebp-base=1.2.4=h5eee18b_1
|
|
||||||
- lz4-c=1.9.4=h6a678d5_0
|
|
||||||
- mkl=2021.4.0=h06a4308_640
|
|
||||||
- mkl-service=2.4.0=py39h7f8727e_0
|
|
||||||
- mkl_fft=1.3.1=py39hd3c417c_0
|
|
||||||
- mkl_random=1.2.2=py39h51133e4_0
|
|
||||||
- ncurses=6.4=h6a678d5_0
|
|
||||||
- nettle=3.7.3=hbbd107a_1
|
|
||||||
- numpy=1.23.5=py39h14f4228_0
|
|
||||||
- numpy-base=1.23.5=py39h31eccc5_0
|
|
||||||
- openh264=2.1.1=h4ff587b_0
|
|
||||||
- openssl=3.0.9=h7f8727e_0
|
|
||||||
- pillow=9.4.0=py39h6a678d5_0
|
|
||||||
- pip=23.0.1=py39h06a4308_0
|
|
||||||
- pycparser=2.21=pyhd3eb1b0_0
|
|
||||||
- pyopenssl=23.0.0=py39h06a4308_0
|
|
||||||
- pysocks=1.7.1=py39h06a4308_0
|
|
||||||
- python=3.9.16=h955ad1f_3
|
|
||||||
- pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
|
|
||||||
- pytorch-mutex=1.0=cuda
|
|
||||||
- readline=8.2=h5eee18b_0
|
|
||||||
- requests=2.28.1=py39h06a4308_1
|
|
||||||
- setuptools=65.6.3=py39h06a4308_0
|
|
||||||
- six=1.16.0=pyhd3eb1b0_1
|
|
||||||
- sqlite=3.41.2=h5eee18b_0
|
|
||||||
- tk=8.6.12=h1ccaba5_0
|
|
||||||
- torchaudio=0.11.0=py39_cu113
|
|
||||||
- torchvision=0.12.0=py39_cu113
|
|
||||||
- tzdata=2023c=h04d1e81_0
|
|
||||||
- wheel=0.38.4=py39h06a4308_0
|
|
||||||
- x264=1!157.20191217=h7b6447c_0
|
|
||||||
- xz=5.4.2=h5eee18b_0
|
|
||||||
- zlib=1.2.13=h5eee18b_0
|
|
||||||
- zstd=1.5.4=hc292b87_0
|
|
||||||
- pip:
|
|
||||||
- absl-py==1.3.0
|
|
||||||
- aiofiles==23.1.0
|
|
||||||
- aiohttp==3.8.3
|
|
||||||
- aiosignal==1.3.1
|
|
||||||
- alembic==1.9.1
|
|
||||||
- altair==5.0.1
|
|
||||||
- annotated-types==0.5.0
|
|
||||||
- ansi2html==1.8.0
|
|
||||||
- antlr4-python3-runtime==4.9.3
|
|
||||||
- anyio==3.7.1
|
|
||||||
- appdirs==1.4.4
|
|
||||||
- asteroid-filterbanks==0.4.0
|
|
||||||
- async-timeout==4.0.2
|
|
||||||
- attrs==22.2.0
|
|
||||||
- audioread==3.0.0
|
|
||||||
- autopage==0.5.1
|
|
||||||
- backports-cached-property==1.0.2
|
|
||||||
- cachetools==5.2.0
|
|
||||||
- charset-normalizer==2.1.1
|
|
||||||
- click==8.1.3
|
|
||||||
- cliff==4.1.0
|
|
||||||
- cmaes==0.9.0
|
|
||||||
- cmake==3.26.4
|
|
||||||
- cmd2==2.4.2
|
|
||||||
- colorama==0.4.6
|
|
||||||
- colorlog==6.7.0
|
|
||||||
- commonmark==0.9.1
|
|
||||||
- contourpy==1.0.6
|
|
||||||
- cycler==0.11.0
|
|
||||||
- dash==2.12.1
|
|
||||||
- dash-core-components==2.0.0
|
|
||||||
- dash-html-components==2.0.0
|
|
||||||
- dash-table==5.0.0
|
|
||||||
- decorator==4.4.2
|
|
||||||
- docopt==0.6.2
|
|
||||||
- einops==0.3.2
|
|
||||||
- exceptiongroup==1.1.1
|
|
||||||
- fastapi==0.100.0
|
|
||||||
- ffmpeg-python==0.2.0
|
|
||||||
- ffmpy==0.3.0
|
|
||||||
- filelock==3.8.0
|
|
||||||
- flask==2.2.5
|
|
||||||
- fonttools==4.38.0
|
|
||||||
- frozenlist==1.3.3
|
|
||||||
- fsspec==2022.11.0
|
|
||||||
- future==0.18.2
|
|
||||||
- google-auth==2.15.0
|
|
||||||
- google-auth-oauthlib==0.4.6
|
|
||||||
- gradio==3.36.1
|
|
||||||
- gradio-client==0.2.7
|
|
||||||
- greenlet==2.0.1
|
|
||||||
- grpcio==1.51.1
|
|
||||||
- h11==0.14.0
|
|
||||||
- hmmlearn==0.2.8
|
|
||||||
- httpcore==0.17.3
|
|
||||||
- httpx==0.24.1
|
|
||||||
- huggingface-hub==0.16.4
|
|
||||||
- humanize==4.7.0
|
|
||||||
- hyperpyyaml==1.1.0
|
|
||||||
- imageio==2.23.0
|
|
||||||
- imageio-ffmpeg==0.4.7
|
|
||||||
- importlib-metadata==4.13.0
|
|
||||||
- importlib-resources==5.12.0
|
|
||||||
- iniconfig==2.0.0
|
|
||||||
- itsdangerous==2.1.2
|
|
||||||
- jinja2==3.1.2
|
|
||||||
- joblib==1.2.0
|
|
||||||
- jsonschema==4.18.0
|
|
||||||
- jsonschema-specifications==2023.6.1
|
|
||||||
- julius==0.2.7
|
|
||||||
- kiwisolver==1.4.4
|
|
||||||
- librosa==0.9.2
|
|
||||||
- linkify-it-py==2.0.2
|
|
||||||
- lit==16.0.5.post0
|
|
||||||
- llvmlite==0.39.1
|
|
||||||
- mako==1.2.4
|
|
||||||
- markdown==3.4.1
|
|
||||||
- markdown-it-py==2.2.0
|
|
||||||
- markupsafe==2.1.1
|
|
||||||
- matplotlib==3.7.1
|
|
||||||
- mdit-py-plugins==0.3.3
|
|
||||||
- mdurl==0.1.2
|
|
||||||
- more-itertools==9.0.0
|
|
||||||
- moviepy==1.0.3
|
|
||||||
- mpmath==1.2.1
|
|
||||||
- multidict==6.0.4
|
|
||||||
- nest-asyncio==1.5.7
|
|
||||||
- networkx==2.8.8
|
|
||||||
- numba==0.56.4
|
|
||||||
- oauthlib==3.2.2
|
|
||||||
- omegaconf==2.3.0
|
|
||||||
- openai-whisper==20230314
|
|
||||||
- optuna==3.0.5
|
|
||||||
- orjson==3.9.2
|
|
||||||
- packaging==21.3
|
|
||||||
- pandas==1.5.2
|
|
||||||
- pbr==5.11.0
|
|
||||||
- plotly==5.15.0
|
|
||||||
- pluggy==1.0.0
|
|
||||||
- pooch==1.6.0
|
|
||||||
- prettytable==3.5.0
|
|
||||||
- primepy==1.3
|
|
||||||
- proglog==0.1.10
|
|
||||||
- protobuf==3.20.1
|
|
||||||
- pyannote-audio==2.1.1
|
|
||||||
- pyannote-core==4.5
|
|
||||||
- pyannote-database==4.1.3
|
|
||||||
- pyannote-metrics==3.2.1
|
|
||||||
- pyannote-pipeline==2.3
|
|
||||||
- pyasn1==0.4.8
|
|
||||||
- pyasn1-modules==0.2.8
|
|
||||||
- pydantic==2.0.2
|
|
||||||
- pydantic-core==2.1.2
|
|
||||||
- pydeprecate==0.3.2
|
|
||||||
- pydub==0.25.1
|
|
||||||
- pygments==2.13.0
|
|
||||||
- pyparsing==3.0.9
|
|
||||||
- pyperclip==1.8.2
|
|
||||||
- pytest==7.3.1
|
|
||||||
- python-dateutil==2.8.2
|
|
||||||
- python-multipart==0.0.6
|
|
||||||
- pytorch-lightning==1.6.5
|
|
||||||
- pytorch-metric-learning==1.6.3
|
|
||||||
- pytz==2022.7
|
|
||||||
- pyyaml==6.0
|
|
||||||
- qtfaststart==1.8
|
|
||||||
- referencing==0.29.1
|
|
||||||
- regex==2022.10.31
|
|
||||||
- requests-oauthlib==1.3.1
|
|
||||||
- resampy==0.4.2
|
|
||||||
- retrying==1.3.4
|
|
||||||
- rich==12.6.0
|
|
||||||
- rpds-py==0.8.10
|
|
||||||
- rsa==4.9
|
|
||||||
- ruamel-yaml==0.17.21
|
|
||||||
- ruamel-yaml-clib==0.2.7
|
|
||||||
- ruff==0.0.272
|
|
||||||
- scikit-learn==1.2.0
|
|
||||||
- scipy==1.8.1
|
|
||||||
- semantic-version==2.10.0
|
|
||||||
- semver==2.13.0
|
|
||||||
- sentencepiece==0.1.97
|
|
||||||
- setuptools-rust==1.5.2
|
|
||||||
- shellingham==1.5.0
|
|
||||||
- simplejson==3.18.0
|
|
||||||
- singledispatchmethod==1.0
|
|
||||||
- sniffio==1.3.0
|
|
||||||
- sortedcontainers==2.4.0
|
|
||||||
- soundfile==0.10.3.post1
|
|
||||||
- speechbrain==0.5.14
|
|
||||||
- sqlalchemy==1.4.45
|
|
||||||
- starlette==0.27.0
|
|
||||||
- stevedore==4.1.1
|
|
||||||
- sympy==1.11.1
|
|
||||||
- tabulate==0.9.0
|
|
||||||
- tenacity==8.2.2
|
|
||||||
- tensorboard==2.11.0
|
|
||||||
- tensorboard-data-server==0.6.1
|
|
||||||
- tensorboard-plugin-wit==1.8.1
|
|
||||||
- threadpoolctl==3.1.0
|
|
||||||
- tiktoken==0.3.1
|
|
||||||
- tokenizers==0.13.2
|
|
||||||
- tomli==2.0.1
|
|
||||||
- toolz==0.12.0
|
|
||||||
- torch-audiomentations==0.11.0
|
|
||||||
- torch-pitch-shift==1.2.2
|
|
||||||
- torchmetrics==0.11.0
|
|
||||||
- tqdm==4.64.1
|
|
||||||
- transformers==4.24.0
|
|
||||||
- triton==2.0.0
|
|
||||||
- typer==0.7.0
|
|
||||||
- typing-extensions==4.7.1
|
|
||||||
- uc-micro-py==1.0.2
|
|
||||||
- urllib3==1.26.12
|
|
||||||
- uvicorn==0.22.0
|
|
||||||
- wcwidth==0.2.5
|
|
||||||
- websockets==11.0.3
|
|
||||||
- werkzeug==2.2.2
|
|
||||||
- yarl==1.8.2
|
|
||||||
- zipp==3.11.0
|
|
||||||
+23
-19
@@ -5,38 +5,42 @@ build-backend = "poetry_dynamic_versioning.backend"
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "scraibe"
|
name = "scraibe"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
description = "Transcription tool for audio files based on Whisper and Pyannote"
|
description = "LocalAI-backed transcription and diarization client using vibevoice.cpp"
|
||||||
authors = ["Schmieder, Jacob <jacob.schmieder@dbfz.de>"]
|
authors = ["Schmieder, Jacob <jacob.schmieder@dbfz.de>"]
|
||||||
license = "GPL-3.0-or-later"
|
license = "GPL-3.0-or-later"
|
||||||
readme = ["README.md", "LICENSE"]
|
readme = ["README.md", "LICENSE"]
|
||||||
repository = "https://github.com/JSchmie/ScAIbe"
|
repository = "https://github.com/JSchmie/ScAIbe"
|
||||||
documentation = "https://jschmie.github.io/ScrAIbe/"
|
documentation = "https://jschmie.github.io/ScrAIbe/"
|
||||||
keywords = ["transcription", "audio", "whisper", "pyannote", "speech-to-text", "speech-recognition"]
|
keywords = [
|
||||||
|
"transcription",
|
||||||
|
"audio",
|
||||||
|
"diarization",
|
||||||
|
"localai",
|
||||||
|
"vibevoice",
|
||||||
|
"speech-to-text",
|
||||||
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
'Development Status :: 4 - Beta',
|
"Development Status :: 4 - Beta",
|
||||||
'Intended Audience :: Developers',
|
"Intended Audience :: Developers",
|
||||||
'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
|
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
|
||||||
'Programming Language :: Python :: 3.8',
|
"Programming Language :: Python :: 3.9",
|
||||||
'Programming Language :: Python :: 3.9',
|
"Programming Language :: Python :: 3.10",
|
||||||
'Programming Language :: Python :: 3.10',
|
"Programming Language :: Python :: 3.11",
|
||||||
'Programming Language :: Python :: 3.11',
|
"Programming Language :: Python :: 3.12",
|
||||||
'Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1',
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
'Topic :: Scientific/Engineering :: Artificial Intelligence'
|
|
||||||
]
|
]
|
||||||
packages = [{include = "scraibe"}]
|
packages = [{include = "scraibe"}]
|
||||||
exclude = [
|
exclude = [
|
||||||
"__pycache__",
|
"__pycache__",
|
||||||
"*.pyc",
|
"*.pyc",
|
||||||
"test"
|
"test",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
tqdm = "^4.66.4"
|
tqdm = "^4.66.5"
|
||||||
numpy = "^1.26.4"
|
numpy = "^1.26.4"
|
||||||
openai-whisper = ">=20231117,<20240931"
|
httpx = ">=0.28.0"
|
||||||
whisperx = "^3.1.3"
|
|
||||||
"pyannote.audio" = "^3.1.1"
|
|
||||||
torch = "^2.3.0"
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^8.1.1"
|
pytest = "^8.1.1"
|
||||||
@@ -57,7 +61,7 @@ format-jinja = """
|
|||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
sphinx = "^7.3.7"
|
sphinx = "^7.3.7"
|
||||||
sphinx-rtd-theme = "^2.0.0"
|
sphinx-rtd-theme = ">=2,<4"
|
||||||
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
|
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
|
||||||
myst-parser = "^3.0.1"
|
myst-parser = "^3.0.1"
|
||||||
mdit-py-plugins = "^0.4.1"
|
mdit-py-plugins = "^0.4.1"
|
||||||
@@ -69,5 +73,5 @@ scraibe = "scraibe.cli:cli"
|
|||||||
app = ["scraibe-webui"]
|
app = ["scraibe-webui"]
|
||||||
|
|
||||||
[tool.ruff.lint.extend-per-file-ignores]
|
[tool.ruff.lint.extend-per-file-ignores]
|
||||||
"__init__.py" = ["E402","F403",'F401']
|
"__init__.py" = ["E402", "F403", "F401"]
|
||||||
"scraibe/misc.py" = ["E722"]
|
"scraibe/misc.py" = ["E722"]
|
||||||
|
|||||||
+2
-13
@@ -1,14 +1,3 @@
|
|||||||
tqdm>=4.65.0
|
tqdm>=4.66.5
|
||||||
numpy>=1.26.4
|
numpy>=1.26.4
|
||||||
|
httpx>=0.28.0
|
||||||
openai-whisper==20231117
|
|
||||||
whisperx~=3.1.3
|
|
||||||
|
|
||||||
pyannote.audio~=3.1.1
|
|
||||||
pyannote.core~=5.0.0
|
|
||||||
pyannote.database~=5.0.1
|
|
||||||
pyannote.metrics~=3.2.1
|
|
||||||
pyannote.pipeline~=3.0.1
|
|
||||||
|
|
||||||
torch>=2.0.0
|
|
||||||
|
|
||||||
|
|||||||
+7
-8
@@ -1,11 +1,10 @@
|
|||||||
from .autotranscript import *
|
from .autotranscript import Scraibe
|
||||||
from .transcriber import *
|
from .localai_client import LocalAIClient, LocalAIError
|
||||||
from .audio import *
|
from .summarizer import SummarizerClient, SummarizerError
|
||||||
from .transcript_exporter import *
|
from .audio import AudioProcessor
|
||||||
from .diarisation import *
|
from .transcript_exporter import Transcript
|
||||||
|
from .misc import set_threads, ParseKwargs
|
||||||
|
|
||||||
from .misc import *
|
from .cli import cli
|
||||||
|
|
||||||
from .cli import *
|
|
||||||
|
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
|
|||||||
+33
-77
@@ -2,28 +2,15 @@
|
|||||||
Audio Processor Module
|
Audio Processor Module
|
||||||
=======================
|
=======================
|
||||||
|
|
||||||
This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files.
|
Simplified audio processor for ScrAIbe.
|
||||||
It includes functionalities to load, cut, and manage audio waveforms, offering efficient and
|
|
||||||
flexible audio processing.
|
|
||||||
|
|
||||||
Available Classes:
|
Previously this used torch and pyannote-style processing. In the LocalAI-backed
|
||||||
- AudioProcessor: Processes audio waveforms and provides methods for loading,
|
version, we primarily pass files to the API, but we keep a lightweight helper
|
||||||
cutting, and handling audio.
|
for backward compatibility.
|
||||||
|
|
||||||
Usage:
|
|
||||||
from .audio_import AudioProcessor
|
|
||||||
|
|
||||||
processor = AudioProcessor.from_file("path/to/audiofile.wav")
|
|
||||||
cut_waveform = processor.cut(start=1.0, end=5.0)
|
|
||||||
|
|
||||||
Constants:
|
|
||||||
- SAMPLE_RATE (int): Default sample rate for processing.
|
|
||||||
- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from subprocess import CalledProcessError, run
|
from subprocess import CalledProcessError, run
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
|
|
||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
NORMALIZATION_FACTOR = 32768.0
|
NORMALIZATION_FACTOR = 32768.0
|
||||||
@@ -31,44 +18,25 @@ NORMALIZATION_FACTOR = 32768.0
|
|||||||
|
|
||||||
class AudioProcessor:
|
class AudioProcessor:
|
||||||
"""
|
"""
|
||||||
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
Lightweight audio processor for loading and cutting audio.
|
||||||
for loading, cutting, and handling audio waveforms.
|
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
waveform: torch.Tensor
|
waveform (np.ndarray): The audio waveform as float32.
|
||||||
The audio waveform tensor.
|
sr (int): The sample rate of the audio.
|
||||||
sr: int
|
|
||||||
The sample rate of the audio.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
def __init__(self, waveform: np.ndarray, sr: int = SAMPLE_RATE):
|
||||||
*args, **kwargs) -> None:
|
self.waveform = waveform
|
||||||
"""
|
|
||||||
Initialize the AudioProcessor object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
waveform (torch.Tensor): The audio waveform tensor.
|
|
||||||
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
|
|
||||||
args: Additional arguments.
|
|
||||||
kwargs: Additional keyword arguments, e.g., device to use for processing.
|
|
||||||
If CUDA is available, it defaults to CUDA.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the provided sample rate is not of type int.
|
|
||||||
"""
|
|
||||||
|
|
||||||
device = kwargs.get(
|
|
||||||
"device", "cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
self.waveform = waveform.to(device)
|
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
|
|
||||||
if not isinstance(self.sr, int):
|
if not isinstance(self.sr, int):
|
||||||
raise ValueError("Sample rate should be a single value of type int,"
|
raise ValueError(
|
||||||
f"not {len(self.sr)} and type {type(self.sr)}")
|
"Sample rate should be a single value of type int, "
|
||||||
|
f"not {len(self.sr)} and type {type(self.sr)}"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
def from_file(cls, file: str, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Create an AudioProcessor instance from an audio file.
|
Create an AudioProcessor instance from an audio file.
|
||||||
|
|
||||||
@@ -76,55 +44,42 @@ class AudioProcessor:
|
|||||||
file (str): The audio file path.
|
file (str): The audio file path.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
|
AudioProcessor: Instance with loaded audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio, sr = cls.load_audio(file, *args, **kwargs)
|
audio, sr = cls.load_audio(file, *args, **kwargs)
|
||||||
|
|
||||||
audio = torch.from_numpy(audio)
|
|
||||||
|
|
||||||
return cls(audio, sr)
|
return cls(audio, sr)
|
||||||
|
|
||||||
def cut(self, start: float, end: float) -> torch.Tensor:
|
def cut(self, start: float, end: float) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Cut a segment from the audio waveform between the specified start and end times.
|
Cut a segment from the audio waveform.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start (float): Start time in seconds.
|
start (float): Start time in seconds.
|
||||||
end (float): End time in seconds.
|
end (float): End time in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The cut waveform segment.
|
np.ndarray: The cut waveform segment.
|
||||||
"""
|
"""
|
||||||
|
start_idx = int(start * self.sr)
|
||||||
start = int(start * self.sr)
|
end_idx = int(np.ceil(end * self.sr))
|
||||||
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
|
return self.waveform[start_idx:end_idx]
|
||||||
end = int(np.ceil(end * self.sr))
|
|
||||||
else:
|
|
||||||
end = int(torch.ceil(end * self.sr))
|
|
||||||
return self.waveform[start:end]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||||
"""
|
"""
|
||||||
Open an audio file and read it as a mono waveform, resampling if necessary.
|
Load an audio file as a mono waveform, resampling if necessary.
|
||||||
This method ensures compatibility with pyannote.audio
|
Requires ffmpeg in PATH.
|
||||||
and requires the ffmpeg CLI in PATH.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str): The audio file to open.
|
file (str): The audio file to open.
|
||||||
sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE.
|
sr (int, optional): The desired sample rate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: A NumPy array containing the audio waveform in float32 dtype
|
tuple: (waveform as np.ndarray[float32], sample rate)
|
||||||
and the sample rate.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If failed to load audio.
|
RuntimeError: If failed to load audio.
|
||||||
"""
|
"""
|
||||||
# This launches a subprocess to decode audio while down-mixing
|
|
||||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
|
||||||
# fmt: off
|
|
||||||
cmd = [
|
cmd = [
|
||||||
"ffmpeg",
|
"ffmpeg",
|
||||||
"-nostdin",
|
"-nostdin",
|
||||||
@@ -134,19 +89,20 @@ class AudioProcessor:
|
|||||||
"-ac", "1",
|
"-ac", "1",
|
||||||
"-acodec", "pcm_s16le",
|
"-acodec", "pcm_s16le",
|
||||||
"-ar", str(sr),
|
"-ar", str(sr),
|
||||||
"-"
|
"-",
|
||||||
]
|
]
|
||||||
# fmt: on
|
|
||||||
try:
|
try:
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
except CalledProcessError as e:
|
except CalledProcessError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Failed to load audio: {e.stderr.decode()}") from e
|
f"Failed to load audio: {e.stderr.decode()}"
|
||||||
|
) from e
|
||||||
|
|
||||||
out = np.frombuffer(out, np.int16).flatten().astype(
|
waveform = np.frombuffer(out, np.int16).flatten().astype(
|
||||||
np.float32) / NORMALIZATION_FACTOR
|
np.float32
|
||||||
|
) / NORMALIZATION_FACTOR
|
||||||
|
|
||||||
return out, sr
|
return waveform, sr
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
return f"AudioProcessor(waveform_len={len(self.waveform)}, sr={self.sr})"
|
||||||
|
|||||||
+218
-294
@@ -1,357 +1,281 @@
|
|||||||
"""
|
"""
|
||||||
Scraibe Class
|
Scraibe Class (LocalAI-backed)
|
||||||
--------------------
|
------------------------------
|
||||||
|
|
||||||
This class serves as the core of the transcription system, responsible for handling
|
Core class for transcription and (optionally) summarization.
|
||||||
transcription and diarization of audio files. It leverages pretrained models for
|
|
||||||
speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
|
|
||||||
providing an accessible interface for audio processing tasks such as transcription,
|
|
||||||
speaker separation, and timestamping.
|
|
||||||
|
|
||||||
By encapsulating the complexities of underlying models, it allows for straightforward
|
- Transcription and diarization are delegated to LocalAI (vibevoice.cpp).
|
||||||
integration into various applications, ranging from transcription services to voice assistants.
|
- Summarization is delegated to a separate LLM via /v1/chat/completions.
|
||||||
|
|
||||||
Available Classes:
|
Public tasks:
|
||||||
- Scraibe: Main class for performing transcription and diarization.
|
- transcribe
|
||||||
Includes methods for loading models, processing audio files,
|
- transcript_and_summarize (transcribe + generate a detailed summary)
|
||||||
and formatting the transcription output.
|
|
||||||
|
|
||||||
Usage:
|
Previous task/whisper/pyannote-specific settings are kept for compatibility
|
||||||
from scraibe import Scraibe
|
but ignored when not relevant.
|
||||||
|
|
||||||
model = Scraibe()
|
|
||||||
transcript = model.autotranscribe("path/to/audiofile.wav")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Standard Library Imports
|
|
||||||
import os
|
import os
|
||||||
from glob import iglob
|
from typing import Union, Optional
|
||||||
from subprocess import run
|
|
||||||
from typing import TypeVar, Union
|
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
# Third-Party Imports
|
from .localai_client import LocalAIClient, LocalAIError
|
||||||
import torch
|
from .summarizer import SummarizerClient, SummarizerError
|
||||||
from numpy import ndarray
|
|
||||||
from tqdm import trange
|
|
||||||
|
|
||||||
# Application-Specific Imports
|
|
||||||
from .audio import AudioProcessor
|
|
||||||
from .diarisation import Diariser
|
|
||||||
from .transcriber import Transcriber, load_transcriber, whisper
|
|
||||||
from .transcript_exporter import Transcript
|
from .transcript_exporter import Transcript
|
||||||
|
|
||||||
|
|
||||||
DiarisationType = TypeVar('DiarisationType')
|
|
||||||
|
|
||||||
|
|
||||||
class Scraibe:
|
class Scraibe:
|
||||||
"""
|
"""
|
||||||
Scraibe is a class responsible for managing the transcription and diarization of audio files.
|
Scraibe now:
|
||||||
It serves as the core of the transcription system, incorporating pretrained models
|
- Uses LocalAI for transcription + diarization.
|
||||||
for speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
|
- Uses a separate LLM for summarization (when requested).
|
||||||
allowing for comprehensive audio processing.
|
|
||||||
|
|
||||||
Attributes:
|
Public methods:
|
||||||
transcriber (Transcriber): The transcriber object to handle transcription.
|
- transcribe(audio_file, ...)
|
||||||
diariser (Diariser): The diariser object to handle diarization.
|
- transcript_and_summarize(audio_file, ...)
|
||||||
|
|
||||||
Methods:
|
|
||||||
__init__: Initializes the Scraibe class with appropriate models.
|
|
||||||
transcribe: Transcribes an audio file using the whisper model and pyannote diarization model.
|
|
||||||
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
|
|
||||||
get_audio_file: Gets an audio file as an AudioProcessor object.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
whisper_model: Union[bool, str, whisper] = None,
|
self,
|
||||||
|
api_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
whisper_model: Union[bool, str] = None,
|
||||||
whisper_type: str = "whisper",
|
whisper_type: str = "whisper",
|
||||||
dia_model: Union[bool, str, DiarisationType] = None,
|
dia_model: Union[bool, str] = None,
|
||||||
**kwargs) -> None:
|
use_auth_token: str = None,
|
||||||
"""Initializes the Scraibe class.
|
verbose: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize Scraibe with LocalAI client and summarizer client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
whisper_model (Union[bool, str, whisper], optional):
|
api_url: LocalAI server URL for transcription/diarization.
|
||||||
Path to whisper model or whisper model itself.
|
Falls back to LOCALAI_API_URL env var.
|
||||||
whisper_type (str):
|
api_key: API key for LocalAI. Falls back to LOCALAI_API_KEY.
|
||||||
Type of whisper model to load. "whisper" or "whisperx".
|
model: Model name for LocalAI (e.g., vibevoice-diarize).
|
||||||
diarisation_model (Union[bool, str, DiarisationType], optional):
|
Falls back to LOCALAI_MODEL env var.
|
||||||
Path to pyannote diarization model or model itself.
|
|
||||||
**kwargs: Additional keyword arguments for whisper
|
|
||||||
and pyannote diarization models.
|
|
||||||
e.g.:
|
|
||||||
|
|
||||||
- verbose: If True, the class will print additional information.
|
Summarizer uses:
|
||||||
- save_kwargs: If True, the keyword arguments will be saved
|
- SUMMARIZER_API_URL
|
||||||
for autotranscribe. So you can unload the class and reload it again.
|
- SUMMARIZER_API_KEY
|
||||||
|
- SUMMARIZER_MODEL
|
||||||
|
These can be overridden via environment or via the transcript_and_summarize
|
||||||
|
method if needed.
|
||||||
|
|
||||||
|
Backward-compat (ignored):
|
||||||
|
- whisper_model, whisper_type, dia_model, use_auth_token, etc.
|
||||||
"""
|
"""
|
||||||
|
self.verbose = verbose or kwargs.get("verbose", False)
|
||||||
|
|
||||||
if whisper_model is None:
|
try:
|
||||||
self.transcriber = load_transcriber(
|
self.client = LocalAIClient(
|
||||||
"medium", whisper_type, **kwargs)
|
api_url=api_url,
|
||||||
elif isinstance(whisper_model, str):
|
api_key=api_key,
|
||||||
self.transcriber = load_transcriber(
|
model=model,
|
||||||
whisper_model, whisper_type, **kwargs)
|
)
|
||||||
else:
|
except LocalAIError as e:
|
||||||
self.transcriber = whisper_model
|
raise LocalAIError(f"Failed to initialize LocalAI client: {e}")
|
||||||
|
|
||||||
if dia_model is None:
|
# Summarizer is lazy-initialized if needed
|
||||||
self.diariser = Diariser.load_model(**kwargs)
|
self._summarizer: Optional[SummarizerClient] = None
|
||||||
elif isinstance(dia_model, str):
|
|
||||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
|
||||||
else:
|
|
||||||
self.diariser: Diariser = dia_model
|
|
||||||
|
|
||||||
if kwargs.get("verbose"):
|
|
||||||
print("Scraibe initialized all models successfully loaded.")
|
|
||||||
self.verbose = True
|
|
||||||
else:
|
|
||||||
self.verbose = False
|
|
||||||
|
|
||||||
# Save kwargs for autotranscribe if you want to unload the class and load it again.
|
|
||||||
if kwargs.get('save_setup'):
|
|
||||||
self.params = dict(whisper_model=whisper_model,
|
|
||||||
dia_model=dia_model,
|
|
||||||
**kwargs)
|
|
||||||
else:
|
|
||||||
self.params = {}
|
|
||||||
|
|
||||||
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
|
||||||
remove_original: bool = False,
|
|
||||||
**kwargs) -> Transcript:
|
|
||||||
"""
|
|
||||||
Transcribes an audio file using the whisper model and pyannote diarization model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
|
||||||
Path to audio file or a tensor representing the audio.
|
|
||||||
remove_original (bool, optional): If True, the original audio file will
|
|
||||||
be removed after transcription.
|
|
||||||
*args: Additional positional arguments for diarization and transcription.
|
|
||||||
**kwargs: Additional keyword arguments for diarization and transcription.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Transcript: A Transcript object containing the transcription,
|
|
||||||
which can be exported to different formats.
|
|
||||||
"""
|
|
||||||
if kwargs.get("verbose"):
|
|
||||||
self.verbose = kwargs.get("verbose")
|
|
||||||
# Get audio file as an AudioProcessor object
|
|
||||||
audio_file: AudioProcessor = self.get_audio_file(audio_file)
|
|
||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
|
||||||
dia_audio = {
|
|
||||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
|
||||||
"sample_rate": audio_file.sr
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Starting diarisation.")
|
print("Scraibe initialized. Using LocalAI for transcription and diarization.")
|
||||||
|
|
||||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
def _ensure_summarizer(
|
||||||
|
self,
|
||||||
if not diarisation["segments"]:
|
api_url: Optional[str] = None,
|
||||||
print("No segments found. Try to run transcription without diarisation.")
|
api_key: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
transcript = self.transcriber.transcribe(
|
) -> SummarizerClient:
|
||||||
audio_file.waveform, **kwargs)
|
|
||||||
|
|
||||||
final_transcript = {0: {"speakers": 'SPEAKER_01',
|
|
||||||
"segments": [0, len(audio_file.waveform)],
|
|
||||||
"text": transcript}}
|
|
||||||
|
|
||||||
return Transcript(final_transcript)
|
|
||||||
|
|
||||||
if self.verbose:
|
|
||||||
print("Diarisation finished. Starting transcription.")
|
|
||||||
|
|
||||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(
|
|
||||||
audio_file.waveform.device)
|
|
||||||
|
|
||||||
# Transcribe each segment and store the results
|
|
||||||
final_transcript = dict()
|
|
||||||
|
|
||||||
for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose):
|
|
||||||
|
|
||||||
seg = diarisation["segments"][i]
|
|
||||||
|
|
||||||
audio = audio_file.cut(seg[0], seg[1])
|
|
||||||
|
|
||||||
transcript = self.transcriber.transcribe(audio, **kwargs)
|
|
||||||
|
|
||||||
final_transcript[i] = {"speakers": diarisation["speakers"][i],
|
|
||||||
"segments": seg,
|
|
||||||
"text": transcript}
|
|
||||||
|
|
||||||
# Remove original file if needed
|
|
||||||
if remove_original:
|
|
||||||
if kwargs.get("shred") is True:
|
|
||||||
self.remove_audio_file(audio_file, shred=True)
|
|
||||||
else:
|
|
||||||
self.remove_audio_file(audio_file, shred=False)
|
|
||||||
|
|
||||||
return Transcript(final_transcript)
|
|
||||||
|
|
||||||
def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
|
|
||||||
**kwargs) -> dict:
|
|
||||||
"""
|
"""
|
||||||
Perform diarization on an audio file using the pyannote diarization model.
|
Lazy-init summarizer client.
|
||||||
|
"""
|
||||||
|
if self._summarizer is not None:
|
||||||
|
return self._summarizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._summarizer = SummarizerClient(
|
||||||
|
api_url=api_url,
|
||||||
|
api_key=api_key,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
except SummarizerError as e:
|
||||||
|
raise SummarizerError(f"Failed to initialize Summarizer client: {e}")
|
||||||
|
|
||||||
|
return self._summarizer
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Primary public API
|
||||||
|
# -----------------
|
||||||
|
|
||||||
|
def transcribe(
|
||||||
|
self,
|
||||||
|
audio_file: Union[str],
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Transcribe the provided audio file using LocalAI.
|
||||||
|
|
||||||
|
Uses /v1/audio/diarization with vibevoice.cpp, then concatenates
|
||||||
|
all segment texts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
audio_file (str): Path to the audio file.
|
||||||
The audio source which can either be a path to the audio file or a tensor representation.
|
**kwargs: Additional keyword arguments (some forwarded, others ignored).
|
||||||
**kwargs:
|
|
||||||
Additional keyword arguments for diarization.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict:
|
str: The concatenated transcribed text.
|
||||||
A dictionary containing the results of the diarization process.
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(audio_file, str):
|
||||||
|
if not os.path.exists(audio_file):
|
||||||
|
raise FileNotFoundError(f"Audio file not found: {audio_file}")
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"In LocalAI mode, audio_file must be a file path (str)."
|
||||||
|
)
|
||||||
|
|
||||||
# Get audio file as an AudioProcessor object
|
verbose = kwargs.get("verbose", self.verbose)
|
||||||
audio_file: AudioProcessor = self.get_audio_file(audio_file)
|
|
||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
try:
|
||||||
dia_audio = {
|
result = self.client.diarize_and_transcribe(
|
||||||
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
audio_path=audio_file,
|
||||||
"sample_rate": audio_file.sr
|
include_text=True,
|
||||||
|
verbose=verbose,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except LocalAIError as e:
|
||||||
|
raise LocalAIError(f"Error during LocalAI transcription: {e}")
|
||||||
|
|
||||||
|
transcripts = result.get("transcripts", [])
|
||||||
|
return " ".join(t.strip() for t in transcripts if t.strip())
|
||||||
|
|
||||||
|
def transcript_and_summarize(
|
||||||
|
self,
|
||||||
|
audio_file: Union[str],
|
||||||
|
*,
|
||||||
|
summarizer_api_url: Optional[str] = None,
|
||||||
|
summarizer_api_key: Optional[str] = None,
|
||||||
|
summarizer_model: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Transcribe the audio file and generate a detailed summary.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
- Transcribe via LocalAI.
|
||||||
|
- Build a plain-text transcript (with speaker labels).
|
||||||
|
- Summarize the transcript using the configured LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with:
|
||||||
|
- transcript: full transcript text (with speaker labels)
|
||||||
|
- summary: final detailed summary (markdown-ready)
|
||||||
|
"""
|
||||||
|
if isinstance(audio_file, str):
|
||||||
|
if not os.path.exists(audio_file):
|
||||||
|
raise FileNotFoundError(f"Audio file not found: {audio_file}")
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"In LocalAI mode, audio_file must be a file path (str)."
|
||||||
|
)
|
||||||
|
|
||||||
|
verbose = kwargs.get("verbose", self.verbose)
|
||||||
|
|
||||||
|
# 1) Get diarized + transcribed result
|
||||||
|
try:
|
||||||
|
result = self.client.diarize_and_transcribe(
|
||||||
|
audio_path=audio_file,
|
||||||
|
include_text=True,
|
||||||
|
verbose=verbose,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except LocalAIError as e:
|
||||||
|
raise LocalAIError(f"Error during LocalAI transcription: {e}")
|
||||||
|
|
||||||
|
segments = result.get("segments", [])
|
||||||
|
speakers = result.get("speakers", [])
|
||||||
|
transcripts = result.get("transcripts", [])
|
||||||
|
|
||||||
|
if not segments:
|
||||||
|
return {
|
||||||
|
"transcript": "",
|
||||||
|
"summary": "No transcript content to summarize.",
|
||||||
}
|
}
|
||||||
|
|
||||||
print("Starting diarisation.")
|
# 2) Build full transcript text with speaker labels
|
||||||
|
lines = []
|
||||||
|
for seg, speaker, text in zip(segments, speakers, transcripts):
|
||||||
|
start, end = seg
|
||||||
|
ts = self._format_timestamp(start)
|
||||||
|
line = f"[{ts}] {speaker}: {text.strip()}"
|
||||||
|
lines.append(line)
|
||||||
|
|
||||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
full_transcript = "\n\n".join(lines)
|
||||||
|
|
||||||
return diarisation
|
# 3) Summarize
|
||||||
|
try:
|
||||||
|
summarizer = self._ensure_summarizer(
|
||||||
|
api_url=summarizer_api_url,
|
||||||
|
api_key=summarizer_api_key,
|
||||||
|
model=summarizer_model,
|
||||||
|
)
|
||||||
|
except SummarizerError as e:
|
||||||
|
raise SummarizerError(f"Failed to initialize summarizer: {e}")
|
||||||
|
|
||||||
def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
try:
|
||||||
**kwargs):
|
summary = summarizer.summarize_transcript(full_transcript)
|
||||||
"""
|
except SummarizerError as e:
|
||||||
Transcribe the provided audio file.
|
raise SummarizerError(f"Error during summarization: {e}")
|
||||||
|
|
||||||
Args:
|
return {
|
||||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
"transcript": full_transcript,
|
||||||
The audio source, which can either be a path or a tensor representation.
|
"summary": summary,
|
||||||
**kwargs:
|
}
|
||||||
Additional keyword arguments for transcription.
|
|
||||||
|
|
||||||
Returns:
|
# -----------------
|
||||||
str:
|
# Helpers
|
||||||
The transcribed text from the audio source.
|
# -----------------
|
||||||
"""
|
|
||||||
audio_file: AudioProcessor = self.get_audio_file(audio_file)
|
|
||||||
|
|
||||||
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
|
||||||
|
|
||||||
def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None:
|
|
||||||
"""
|
|
||||||
Update the transcriber model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
whisper_model (Union[str, whisper]):
|
|
||||||
The new whisper model to use for transcription.
|
|
||||||
**kwargs:
|
|
||||||
Additional keyword arguments for the transcriber model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
_old_model = self.transcriber.model_name
|
|
||||||
|
|
||||||
if isinstance(whisper_model, str):
|
|
||||||
self.transcriber = load_transcriber(whisper_model, **kwargs)
|
|
||||||
elif isinstance(whisper_model, Transcriber):
|
|
||||||
self.transcriber = whisper_model
|
|
||||||
else:
|
|
||||||
warn(
|
|
||||||
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None:
|
|
||||||
"""
|
|
||||||
Update the diariser model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dia_model (Union[str, DiarisationType]):
|
|
||||||
The new diariser model to use for diarization.
|
|
||||||
**kwargs:
|
|
||||||
Additional keyword arguments for the diariser model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
if isinstance(dia_model, str):
|
|
||||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
|
||||||
elif isinstance(dia_model, Diariser):
|
|
||||||
self.diariser = dia_model
|
|
||||||
else:
|
|
||||||
warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_audio_file(audio_file: str,
|
def _format_timestamp(seconds: float) -> str:
|
||||||
shred: bool = False) -> None:
|
|
||||||
"""
|
"""
|
||||||
Removes the original audio file to avoid disk space issues or ensure data privacy.
|
Format seconds into MM:SS or HH:MM:SS.
|
||||||
|
"""
|
||||||
|
m, s = divmod(int(seconds), 60)
|
||||||
|
h, m = divmod(m, 60)
|
||||||
|
if h > 0:
|
||||||
|
return f"{h:02d}:{m:02d}:{s:02d}"
|
||||||
|
return f"{m:02d}:{s:02d}"
|
||||||
|
|
||||||
Args:
|
@staticmethod
|
||||||
audio_file_path (str): Path to the audio file.
|
def remove_audio_file(audio_file: str, shred: bool = False) -> None:
|
||||||
shred (bool, optional): If True, the audio file will be shredded,
|
"""
|
||||||
not just removed.
|
Remove the original audio file.
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(audio_file):
|
if not os.path.exists(audio_file):
|
||||||
raise ValueError(f"Audiofile {audio_file} does not exist.")
|
raise ValueError(f"Audiofile {audio_file} does not exist.")
|
||||||
|
|
||||||
if shred:
|
if shred:
|
||||||
|
import subprocess
|
||||||
|
import warnings
|
||||||
|
from glob import iglob
|
||||||
|
|
||||||
warn("Shredding audiofile can take a long time.", RuntimeWarning)
|
warnings.warn("Shredding audiofile can take a long time.", RuntimeWarning)
|
||||||
|
|
||||||
gen = iglob(f'{audio_file}', recursive=True)
|
gen = iglob(f"{audio_file}", recursive=True)
|
||||||
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
|
cmd = ["shred", "-zvu", "-n", "10", f"{audio_file}"]
|
||||||
|
|
||||||
if os.path.isdir(audio_file):
|
if os.path.isdir(audio_file):
|
||||||
raise ValueError(f"Audiofile {audio_file} is a directory.")
|
raise ValueError(f"Audiofile {audio_file} is a directory.")
|
||||||
|
|
||||||
for file in gen:
|
for file in gen:
|
||||||
print(f'shredding {file} now\n')
|
print(f"shredding {file} now\n")
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
run(cmd, check=True)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
print(f"Audiofile {audio_file} removed.")
|
print(f"Audiofile {audio_file} removed.")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
|
|
||||||
*args, **kwargs) -> AudioProcessor:
|
|
||||||
"""Gets an audio file as TorchAudioProcessor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio_file (Union[str, torch.Tensor, ndarray]): Path to the audio file or
|
|
||||||
a tensor representing the audio.
|
|
||||||
*args: Additional positional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AudioProcessor: An object containing the waveform and sample rate in
|
|
||||||
torch.Tensor format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(audio_file, str):
|
|
||||||
audio_file = AudioProcessor.from_file(audio_file)
|
|
||||||
|
|
||||||
elif isinstance(audio_file, torch.Tensor):
|
|
||||||
audio_file = AudioProcessor(audio_file[0], audio_file[1])
|
|
||||||
elif isinstance(audio_file, ndarray):
|
|
||||||
audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
|
|
||||||
audio_file[1])
|
|
||||||
|
|
||||||
if not isinstance(audio_file, AudioProcessor):
|
|
||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,'
|
|
||||||
f'not {type(audio_file)}')
|
|
||||||
|
|
||||||
return audio_file
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Scraibe(transcriber={self.transcriber}, diariser={self.diariser})"
|
return "Scraibe(LocalAI-backed)"
|
||||||
|
|||||||
+191
-87
@@ -3,23 +3,21 @@ Command-Line Interface (CLI) for the Scraibe class,
|
|||||||
allowing for user interaction to transcribe and diarize audio files.
|
allowing for user interaction to transcribe and diarize audio files.
|
||||||
The function includes arguments for specifying the audio files, model paths,
|
The function includes arguments for specifying the audio files, model paths,
|
||||||
output formats, and other options necessary for transcription.
|
output formats, and other options necessary for transcription.
|
||||||
|
|
||||||
|
This version is adapted for LocalAI-based transcription and diarization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
|
||||||
from torch.cuda import is_available
|
|
||||||
from torch import set_num_threads
|
|
||||||
from .autotranscript import Scraibe
|
from .autotranscript import Scraibe
|
||||||
|
from .misc import set_threads
|
||||||
|
|
||||||
|
|
||||||
def cli():
|
def cli():
|
||||||
"""
|
"""
|
||||||
Command-Line Interface (CLI) for the Scraibe class, allowing for user interaction to transcribe
|
Command-Line Interface (CLI) for the Scraibe class, allowing for user interaction to transcribe
|
||||||
and diarize audio files. The function includes arguments for specifying the audio files, model paths,
|
and diarize audio files via a LocalAI server.
|
||||||
output formats, and other options necessary for transcription.
|
|
||||||
|
|
||||||
This function can be executed from the command line to perform transcription tasks, providing a
|
|
||||||
user-friendly way to access the Scraibe class functionalities.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def str2bool(string):
|
def str2bool(string):
|
||||||
@@ -28,57 +26,160 @@ def cli():
|
|||||||
return str2val[string]
|
return str2val[string]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected one of {set(str2val.keys())}, got {string}")
|
f"Expected one of {set(str2val.keys())}, got {string}"
|
||||||
|
)
|
||||||
|
|
||||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None,
|
parser.add_argument(
|
||||||
help="List of audio files to transcribe.")
|
"-f",
|
||||||
|
"--audio-files",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="List of audio files to transcribe.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--whisper-type", type=str, default="whisper",
|
# LocalAI connection (env vars preferred, but CLI overrides allowed)
|
||||||
choices=["whisper", "whisperx"],
|
parser.add_argument(
|
||||||
help="Type of Whisper model to use ('whisper' or 'whisperx').")
|
"--localai-api-url",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="LocalAI server URL (e.g., http://localhost:8080). "
|
||||||
|
"Overrides LOCALAI_API_URL env var if provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--localai-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="LocalAI API key. Overrides LOCALAI_API_KEY env var if provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--localai-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model name to use on LocalAI (e.g., vibevoice-diarize). "
|
||||||
|
"Overrides LOCALAI_MODEL env var if provided.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--whisper-model-name", default="medium",
|
# Summarizer overrides (env vars are primary)
|
||||||
help="Name of the Whisper model to use.")
|
parser.add_argument(
|
||||||
|
"--summarizer-api-url",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Summarization LLM API URL (e.g., http://localhost:8080). "
|
||||||
|
"Overrides SUMMARIZER_API_URL env var if provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--summarizer-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Summarization LLM API key. Overrides SUMMARIZER_API_KEY env var if provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--summarizer-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model name for summarization. Overrides SUMMARIZER_MODEL env var if provided.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--whisper-model-directory", type=str, default=None,
|
# Kept for backward compatibility with UI / existing scripts; ignored by LocalAI client.
|
||||||
help="Path to save Whisper model files; defaults to ./models/whisper.")
|
parser.add_argument(
|
||||||
|
"--whisper-type",
|
||||||
|
type=str,
|
||||||
|
default="whisper",
|
||||||
|
choices=["whisper", "faster-whisper"],
|
||||||
|
help="[Backward compatibility] Type of Whisper model. Ignored when using LocalAI.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--diarization-directory", type=str, default=None,
|
parser.add_argument(
|
||||||
help="Path to the diarization model directory.")
|
"--whisper-model-name",
|
||||||
|
default="medium",
|
||||||
|
help="[Backward compatibility] Whisper model name. Ignored when using LocalAI.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--hf-token", default=None, type=str,
|
parser.add_argument(
|
||||||
help="HuggingFace token for private model download.")
|
"--whisper-model-directory",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="[Backward compatibility] Whisper model directory. Ignored when using LocalAI.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--inference-device",
|
parser.add_argument(
|
||||||
default="cuda" if is_available() else "cpu",
|
"--diarization-directory",
|
||||||
help="Device to use for PyTorch inference.")
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="[Backward compatibility] Diarization model directory. Ignored when using LocalAI.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--num-threads", type=int, default=0,
|
parser.add_argument(
|
||||||
help="Number of threads used by torch for CPU inference; '\
|
"--hf-token",
|
||||||
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="[Backward compatibility] HuggingFace token. Ignored when using LocalAI.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--output-directory", "-o", type=str, default=".",
|
parser.add_argument(
|
||||||
help="Directory to save the transcription outputs.")
|
"--inference-device",
|
||||||
|
default="cpu",
|
||||||
|
help="[Backward compatibility] Device for inference. Ignored when using LocalAI.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--output-format", "-of", type=str, default="txt",
|
parser.add_argument(
|
||||||
|
"--num-threads",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Number of threads used for CPU operations; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-directory",
|
||||||
|
"-o",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="Directory to save the transcription outputs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-format",
|
||||||
|
"-of",
|
||||||
|
type=str,
|
||||||
|
default="txt",
|
||||||
choices=["txt", "json", "md", "html"],
|
choices=["txt", "json", "md", "html"],
|
||||||
help="Format of the output file; defaults to txt.")
|
help="Format of the output file; defaults to txt.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--verbose-output", type=str2bool, default=True,
|
parser.add_argument(
|
||||||
help="Enable or disable progress and debug messages.")
|
"--verbose-output",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Enable or disable progress and debug messages.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default='autotranscribe',
|
parser.add_argument(
|
||||||
choices=["autotranscribe", "diarization",
|
"--task",
|
||||||
"autotranscribe+translate", "translate", 'transcribe'],
|
type=str,
|
||||||
help="Choose to perform transcription, diarization, or translation. \
|
default="transcribe",
|
||||||
If set to translate, the output will be translated to English.")
|
choices=[
|
||||||
|
"transcribe",
|
||||||
|
"transcript_and_summarize",
|
||||||
|
],
|
||||||
|
help="Task to perform: 'transcribe' or 'transcript_and_summarize'.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--language", type=str, default=None,
|
parser.add_argument(
|
||||||
choices=sorted(
|
"--language",
|
||||||
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
type=str,
|
||||||
help="Language spoken in the audio. Specify None to perform language detection.")
|
default=None,
|
||||||
|
help="Language spoken in the audio. Specify None to perform language detection.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-speakers",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Number of speakers in the audio.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -92,63 +193,66 @@ def cli():
|
|||||||
|
|
||||||
task = arg_dict.pop("task")
|
task = arg_dict.pop("task")
|
||||||
|
|
||||||
if args.num_threads > 0:
|
set_threads(arg_dict.pop("num_threads"))
|
||||||
set_num_threads(arg_dict.pop("num_threads"))
|
|
||||||
|
|
||||||
class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
|
# Build kwargs for Scraibe (LocalAI-backed)
|
||||||
'whisper_type':arg_dict.pop("whisper_type"),
|
class_kwargs = {
|
||||||
'dia_model': arg_dict.pop("diarization_directory"),
|
"api_url": arg_dict.pop("localai_api_url"),
|
||||||
'use_auth_token': arg_dict.pop("hf_token"),
|
"api_key": arg_dict.pop("localai_api_key"),
|
||||||
|
"model": arg_dict.pop("localai_model"),
|
||||||
|
# kept for backward compatibility, but ignored:
|
||||||
|
"whisper_model": arg_dict.pop("whisper_model_name"),
|
||||||
|
"whisper_type": arg_dict.pop("whisper_type"),
|
||||||
|
"dia_model": arg_dict.pop("diarization_directory"),
|
||||||
|
"use_auth_token": arg_dict.pop("hf_token"),
|
||||||
|
"verbose": arg_dict.pop("verbose_output"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if arg_dict["whisper_model_directory"]:
|
|
||||||
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
|
||||||
|
|
||||||
|
|
||||||
model = Scraibe(**class_kwargs)
|
model = Scraibe(**class_kwargs)
|
||||||
|
|
||||||
if arg_dict["audio_files"]:
|
if arg_dict["audio_files"]:
|
||||||
audio_files = arg_dict.pop("audio_files")
|
audio_files = arg_dict.pop("audio_files")
|
||||||
|
|
||||||
if task == "autotranscribe" or task == "autotranscribe+translate":
|
if task == "transcribe":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
if task == "autotranscribe+translate":
|
out = model.transcribe(
|
||||||
task = "translate"
|
audio,
|
||||||
else:
|
|
||||||
task = "transcribe"
|
|
||||||
|
|
||||||
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
|
|
||||||
"language"), verbose=arg_dict.pop("verbose_output"))
|
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
|
||||||
out.save(os.path.join(
|
|
||||||
out_folder, f"{basename}.{out_format}"))
|
|
||||||
|
|
||||||
elif task == "diarization":
|
|
||||||
for audio in audio_files:
|
|
||||||
if arg_dict.pop("verbose_output"):
|
|
||||||
print("Verbose not implemented for diarization.")
|
|
||||||
|
|
||||||
out = model.diarization(audio)
|
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
|
||||||
|
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
|
||||||
|
|
||||||
with open(path, "w") as f:
|
|
||||||
json.dump(json.dumps(out, indent=1), f)
|
|
||||||
|
|
||||||
elif task == "transcribe" or task == "translate":
|
|
||||||
|
|
||||||
for audio in audio_files:
|
|
||||||
|
|
||||||
out = model.transcribe(audio, task=task,
|
|
||||||
language=arg_dict.pop("language"),
|
language=arg_dict.pop("language"),
|
||||||
verbose=arg_dict.pop("verbose_output"))
|
verbose=arg_dict.pop("verbose_output"),
|
||||||
|
num_speakers=arg_dict.pop("num_speakers"),
|
||||||
|
)
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
with open(path, "w") as f:
|
print(f"Saving {basename}.{out_format} to {out_folder}")
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
f.write(out)
|
f.write(out)
|
||||||
|
|
||||||
|
elif task == "transcript_and_summarize":
|
||||||
|
for audio in audio_files:
|
||||||
|
result = model.transcript_and_summarize(
|
||||||
|
audio,
|
||||||
|
summarizer_api_url=arg_dict.pop("summarizer_api_url"),
|
||||||
|
summarizer_api_key=arg_dict.pop("summarizer_api_key"),
|
||||||
|
summarizer_model=arg_dict.pop("summarizer_model"),
|
||||||
|
language=arg_dict.pop("language"),
|
||||||
|
verbose=arg_dict.pop("verbose_output"),
|
||||||
|
num_speakers=arg_dict.pop("num_speakers"),
|
||||||
|
)
|
||||||
|
|
||||||
|
transcript_text = result.get("transcript", "")
|
||||||
|
summary_text = result.get("summary", "")
|
||||||
|
|
||||||
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
|
|
||||||
|
# Always use .md for transcript_and_summarize
|
||||||
|
md_path = os.path.join(out_folder, f"{basename}.md")
|
||||||
|
print(f"Saving {basename}.md (transcript + summary) to {out_folder}")
|
||||||
|
|
||||||
|
with open(md_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write("# Transcript\n\n")
|
||||||
|
f.write(transcript_text)
|
||||||
|
f.write("\n\n# Summary\n\n")
|
||||||
|
f.write(summary_text)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ from pyannote.audio import Pipeline
|
|||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch import device as torch_device
|
from torch import device as torch_device
|
||||||
from torch.cuda import is_available
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
@@ -190,8 +190,7 @@ class Diariser:
|
|||||||
cache_token: bool = False,
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None,
|
hparams_file: Union[str, Path] = None,
|
||||||
device: str = None,
|
device: str = SCRAIBE_TORCH_DEVICE,
|
||||||
*args, **kwargs
|
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
Loads a pretrained model from pyannote.audio,
|
Loads a pretrained model from pyannote.audio,
|
||||||
@@ -283,10 +282,6 @@ class Diariser:
|
|||||||
'or from huggingface.co models. Please check your token'
|
'or from huggingface.co models. Please check your token'
|
||||||
'or your local model path')
|
'or your local model path')
|
||||||
|
|
||||||
# try to move the model to the device
|
|
||||||
if device is None:
|
|
||||||
device = "cuda" if is_available() else "cpu"
|
|
||||||
|
|
||||||
# torch_device is renamed from torch.device to avoid name conflict
|
# torch_device is renamed from torch.device to avoid name conflict
|
||||||
_model = _model.to(torch_device(device))
|
_model = _model.to(torch_device(device))
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,237 @@
|
|||||||
|
"""
|
||||||
|
LocalAI Client Module
|
||||||
|
---------------------
|
||||||
|
|
||||||
|
This module provides a client for communicating with a LocalAI server
|
||||||
|
running vibevoice.cpp for transcription and speaker diarization.
|
||||||
|
|
||||||
|
It replaces the previous local Whisper + Pyannote pipeline by sending
|
||||||
|
audio files to the /v1/audio/diarization endpoint and mapping the
|
||||||
|
response into the same Transcript format used by the UI.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
LOCALAI_API_URL: (required) Base URL of the LocalAI server
|
||||||
|
(e.g., http://localhost:8080)
|
||||||
|
LOCALAI_API_KEY: (optional) API key, if configured
|
||||||
|
LOCALAI_MODEL: (optional) Model name to use (default: vibevoice-diarize)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAIError(Exception):
    """Error signalling that the LocalAI API failed or answered with an unexpected response."""
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAIClient:
    """
    Thin HTTP client for LocalAI /v1/audio/diarization with vibevoice.cpp.

    Responsibilities:
    - Read configuration from environment.
    - Upload audio file as multipart/form-data.
    - Parse diarization + transcription response.
    - Map response into the same structure expected by Scraibe's Transcript.
    """

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        timeout: float = 600.0,
    ):
        """
        Args:
            api_url: LocalAI server URL (e.g., http://localhost:8080).
                     Falls back to LOCALAI_API_URL env var.
            api_key: API key, if required. Falls back to LOCALAI_API_KEY.
            model: Model name (e.g., vibevoice-diarize).
                   Falls back to LOCALAI_MODEL or default.
            timeout: Request timeout in seconds.

        Raises:
            LocalAIError: If no server URL is given via argument or environment.
        """
        # The trailing `or ""` keeps .strip() safe when neither the argument
        # nor the environment variable is set, so the explicit LocalAIError
        # below is raised instead of an AttributeError on None.
        self.api_url = (api_url or os.getenv("LOCALAI_API_URL") or "").strip().rstrip("/")
        self.api_key = api_key or os.getenv("LOCALAI_API_KEY") or None
        self.model = model or os.getenv("LOCALAI_MODEL") or "vibevoice-diarize"
        self.timeout = timeout

        if not self.api_url:
            raise LocalAIError(
                "LOCALAI_API_URL is not set. "
                "Provide the LocalAI server URL via environment or constructor."
            )

        self._client = httpx.Client(
            base_url=self.api_url,
            timeout=self.timeout,
            follow_redirects=True,
        )

    def close(self):
        """Close the underlying HTTP client."""
        self._client.close()

    def __del__(self):
        # Best-effort cleanup: the client attribute may not exist if __init__
        # raised before creating it, so any failure here is ignored on purpose.
        try:
            self._client.close()
        except Exception:
            pass

    def diarize_and_transcribe(
        self,
        audio_path: str,
        *,
        language: Optional[str] = None,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
        clustering_threshold: Optional[float] = None,
        min_duration_on: Optional[float] = None,
        min_duration_off: Optional[float] = None,
        response_format: Optional[str] = None,
        include_text: Optional[bool] = None,
        verbose: bool = False,
        **_ignored,
    ) -> Dict[str, Any]:
        """
        Send audio to LocalAI /v1/audio/diarization and return a dict
        in the same style as the previous internal diarization output:

            {
                "segments": [ [start, end], ... ],
                "speakers": [ "SPEAKER_00", ... ],
                "transcripts": [ "text for segment", ... ]
            }

        Extra kwargs that the old UI used (e.g., whisper-specific) are
        accepted but ignored.

        Args:
            audio_path: Path to the audio file.
            language: Language hint, forwarded if set.
            num_speakers: Optional exact speaker count.
            min_speakers: Optional hint.
            max_speakers: Optional hint.
            clustering_threshold: Optional clustering threshold.
            min_duration_on: Optional min segment duration.
            min_duration_off: Optional min gap duration.
            response_format: "json", "verbose_json", or "rttm".
                             Defaults to "verbose_json" if not set.
            include_text: Whether to request per-segment text.
                          Defaults to True.
            verbose: If True, prints progress messages.

        Returns:
            Dict with the parallel lists "segments", "speakers" and "transcripts".

        Raises:
            LocalAIError: If the audio file does not exist, the server answers
                with a status >= 400, or the response body is not usable JSON.
        """
        if verbose:
            print("Starting diarization and transcription via LocalAI.")

        # Defaults: use verbose_json + include_text to get both diarization and transcription.
        if response_format is None:
            response_format = "verbose_json"
        if include_text is None:
            include_text = True

        # Prepare form data
        data = {
            "model": self.model,
            "response_format": response_format,
            "include_text": str(include_text).lower(),
        }

        if language is not None:
            data["language"] = language

        # Remaining tuning knobs are only sent when set, stringified for the form.
        optional_fields = (
            ("num_speakers", num_speakers),
            ("min_speakers", min_speakers),
            ("max_speakers", max_speakers),
            ("clustering_threshold", clustering_threshold),
            ("min_duration_on", min_duration_on),
            ("min_duration_off", min_duration_off),
        )
        for field_name, field_value in optional_fields:
            if field_value is not None:
                data[field_name] = str(field_value)

        if not os.path.exists(audio_path):
            raise LocalAIError(f"Audio file not found: {audio_path}")

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # The request is issued while the file handle is still open so the
        # multipart upload can read from it.
        with open(audio_path, "rb") as f:
            files = {
                "file": (os.path.basename(audio_path), f, "application/octet-stream")
            }

            # POST /v1/audio/diarization
            resp = self._client.post(
                "/v1/audio/diarization",
                data=data,
                files=files,
                headers=headers,
            )

        if resp.status_code >= 400:
            body = resp.text
            raise LocalAIError(
                f"LocalAI request failed with status {resp.status_code}: {body}"
            )

        try:
            result = resp.json()
        except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
            raise LocalAIError(
                "Failed to parse LocalAI response as JSON."
            ) from exc

        if verbose:
            print("Diarization and transcription finished. Starting post-processing.")

        return self._parse_diarization_response(result)

    def _parse_diarization_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert LocalAI response into the internal format used by Scraibe:

            {
                "segments": [ [start, end], ... ],
                "speakers": [ "SPEAKER_00", ... ],
                "transcripts": [ "text for segment", ... ]
            }

        Missing or null fields fall back to 0.0 (start/end), "SPEAKER_00"
        (speaker) and "" (text), so a segment without text does not crash
        the conversion.

        Raises:
            LocalAIError: If the decoded response is not a JSON object.
        """
        if not isinstance(result, dict):
            raise LocalAIError(
                "Unexpected LocalAI response format: expected a JSON object."
            )

        # `or []` also covers an explicit null "segments" value.
        segments = result.get("segments") or []

        out_segments = []
        out_speakers = []
        out_transcripts = []

        for seg in segments:
            # `or` guards against explicit JSON nulls, which dict.get's default
            # argument would not replace.
            start = float(seg.get("start") or 0.0)
            end = float(seg.get("end") or 0.0)

            speaker = seg.get("speaker")
            if speaker is None:
                speaker = "SPEAKER_00"

            text = str(seg.get("text") or "").strip()

            out_segments.append([start, end])
            out_speakers.append(speaker)
            out_transcripts.append(text)

        return {
            "segments": out_segments,
            "speakers": out_speakers,
            "transcripts": out_transcripts,
        }
|
||||||
+33
-32
@@ -1,6 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import yaml
|
|
||||||
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
|
|
||||||
from argparse import Action
|
from argparse import Action
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
|
|
||||||
@@ -9,42 +7,45 @@ CACHE_DIR = os.getenv(
|
|||||||
os.path.expanduser("~/.cache/torch/models"),
|
os.path.expanduser("~/.cache/torch/models"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
# Legacy paths kept for backward compatibility (ignored by LocalAI client)
|
||||||
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
|
|
||||||
|
|
||||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
|
||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
|
||||||
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
|
||||||
|
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def set_threads(parse_threads=None, yaml_threads=None):
|
||||||
"""Configure diarization pipeline from a YAML file.
|
|
||||||
|
|
||||||
This function updates the YAML file to use the given segmentation model
|
|
||||||
offline, and avoids manual file manipulation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path (str): Path to the YAML file.
|
|
||||||
path_to_segmentation (str, optional): Optional path to the segmentation model.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: If the segmentation model file is not found.
|
|
||||||
"""
|
"""
|
||||||
with open(file_path, "r") as stream:
|
Configure number of threads.
|
||||||
yml = yaml.safe_load(stream)
|
|
||||||
|
|
||||||
segmentation_path = path_to_segmentation or os.path.join(
|
In LocalAI mode, this is mainly kept for backward compatibility.
|
||||||
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
"""
|
||||||
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
chosen = None
|
||||||
|
if parse_threads is not None:
|
||||||
|
if not isinstance(parse_threads, int):
|
||||||
|
raise ValueError(
|
||||||
|
f"Type of --num-threads must be int, but the type is {type(parse_threads)}"
|
||||||
|
)
|
||||||
|
elif parse_threads < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of threads must be a positive integer, {parse_threads} was given"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chosen = parse_threads
|
||||||
|
elif yaml_threads is not None:
|
||||||
|
if not isinstance(yaml_threads, int):
|
||||||
|
raise ValueError(
|
||||||
|
f"Type of num_threads must be int, but the type is {type(yaml_threads)}"
|
||||||
|
)
|
||||||
|
elif yaml_threads < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of threads must be a positive integer, {yaml_threads} was given"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chosen = yaml_threads
|
||||||
|
|
||||||
if not os.path.exists(segmentation_path):
|
if chosen is not None:
|
||||||
raise FileNotFoundError(
|
os.environ["OMP_NUM_THREADS"] = str(chosen)
|
||||||
f"Segmentation model not found at {segmentation_path}")
|
os.environ["MKL_NUM_THREADS"] = str(chosen)
|
||||||
|
|
||||||
with open(file_path, "w") as stream:
|
|
||||||
yaml.dump(yml, stream)
|
|
||||||
|
|
||||||
|
|
||||||
class ParseKwargs(Action):
|
class ParseKwargs(Action):
|
||||||
@@ -55,7 +56,7 @@ class ParseKwargs(Action):
|
|||||||
def __call__(self, parser, namespace, values, option_string=None):
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
setattr(namespace, self.dest, dict())
|
setattr(namespace, self.dest, dict())
|
||||||
for value in values:
|
for value in values:
|
||||||
key, value = value.split('=')
|
key, value = value.split("=")
|
||||||
try:
|
try:
|
||||||
value = literal_eval(value)
|
value = literal_eval(value)
|
||||||
except:
|
except:
|
||||||
|
|||||||
@@ -0,0 +1,212 @@
|
|||||||
|
"""
|
||||||
|
Summarizer Module
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
Provides a client to summarize long transcripts via an LLM endpoint.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
- Chunks transcript into 10,240-character segments.
|
||||||
|
- Generates a summary for each chunk.
|
||||||
|
- Combines all chunk summaries and produces a final, detailed summary.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
- SUMMARIZER_API_URL: (required) Base URL of the LLM API (e.g., http://localhost:8080)
|
||||||
|
- SUMMARIZER_API_KEY: (optional) API key, if required
|
||||||
|
- SUMMARIZER_MODEL: (optional) Model name (e.g., llama-3.1-8b-instruct)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizerError(Exception):
    """Error signalling that a call to the summarization API failed."""
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizerClient:
    """
    HTTP client for an OpenAI-compatible chat completions endpoint.
    Used to summarize long transcripts in chunks.
    """

    CHUNK_SIZE = 10_240  # characters per chunk

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        timeout: float = 600.0,
    ):
        """
        Args:
            api_url: Base URL of the LLM API. Falls back to SUMMARIZER_API_URL.
            api_key: API key, if required. Falls back to SUMMARIZER_API_KEY.
            model: Model name. Falls back to SUMMARIZER_MODEL or
                   "llama-3.1-8b-instruct".
            timeout: Request timeout in seconds.

        Raises:
            SummarizerError: If no API URL is given via argument or environment.
        """
        # The trailing `or ""` keeps .strip() safe when neither the argument
        # nor the environment variable is set, so the explicit SummarizerError
        # below is raised instead of an AttributeError on None.
        self.api_url = (api_url or os.getenv("SUMMARIZER_API_URL") or "").strip().rstrip("/")
        self.api_key = api_key or os.getenv("SUMMARIZER_API_KEY") or None
        self.model = model or os.getenv("SUMMARIZER_MODEL") or "llama-3.1-8b-instruct"
        self.timeout = timeout

        if not self.api_url:
            raise SummarizerError(
                "SUMMARIZER_API_URL is not set. "
                "Provide the summarization LLM URL via environment or constructor."
            )

        self._client = httpx.Client(
            base_url=self.api_url,
            timeout=self.timeout,
            follow_redirects=True,
        )

    def close(self):
        """Close the underlying HTTP client."""
        self._client.close()

    def __del__(self):
        # Best-effort cleanup: the client attribute may not exist if __init__
        # raised before creating it, so any failure here is ignored on purpose.
        try:
            self._client.close()
        except Exception:
            pass

    def summarize_transcript(self, transcript: str) -> str:
        """
        Summarize a (possibly very long) transcript.

        Strategy:
        - Split transcript into chunks of CHUNK_SIZE characters.
        - Generate a detailed summary for each chunk.
        - Combine all chunk summaries and generate a final, concise but thorough summary.

        The final summary should make it clear:
        - What was discussed
        - Main issues
        - Outcomes / decisions
        - Next steps / action items

        Returns:
            The final summary text, or a fixed placeholder message when the
            transcript is empty, None, or whitespace only.
        """
        if not transcript or not transcript.strip():
            return "No transcript provided to summarize."

        # 1) Chunk the transcript
        chunks = self._chunk_text(transcript)

        # 2) Summarize each chunk
        chunk_summaries = [
            self._summarize_chunk(chunk, position, len(chunks))
            for position, chunk in enumerate(chunks)
        ]

        # 3) Combine and summarize summaries
        combined = "\n\n".join(chunk_summaries)
        return self._summarize_combined(combined)

    def _chunk_text(self, text: str) -> list[str]:
        """
        Split text into chunks of at most CHUNK_SIZE characters.

        Each chunk ends at the last newline inside its window, else at the
        last space, else at the hard CHUNK_SIZE limit. Chunks are stripped of
        surrounding whitespace and whitespace-only chunks are dropped.
        """
        chunks = []
        total = len(text)
        start = 0
        while start < total:
            end = start + self.CHUNK_SIZE
            if end < total:
                # Search strictly after `start`: a window that begins on the
                # previous break character must not re-find that character,
                # otherwise the space fallback is skipped and the text is cut
                # mid-word.
                cut = text.rfind("\n", start + 1, end)
                if cut == -1:
                    cut = text.rfind(" ", start + 1, end)
                if cut != -1:
                    end = cut
            piece = text[start:end].strip()
            if piece:
                chunks.append(piece)
            start = end
        return chunks

    def _summarize_chunk(self, chunk: str, index: int, total: int) -> str:
        """Summarize one transcript segment; `index` is the zero-based position among `total` segments."""
        system_prompt = (
            "You are an expert legal and business meeting summarizer. "
            "You will receive a segment of a longer transcript. "
            "Provide a detailed, structured summary of this segment, focusing on: "
            "- Topics discussed\n"
            "- Key points and arguments\n"
            "- Decisions and agreements\n"
            "- Action items and responsibilities\n"
            "- Any risks, conflicts, or open issues\n\n"
            "Be concise but complete. Use bullet points when helpful. "
            "Do not add information that is not present in the transcript."
        )

        user_prompt = (
            f"This is segment {index + 1} of {total} from a longer conversation.\n\n"
            f"{chunk}"
        )

        return self._chat_completion(system_prompt, user_prompt)

    def _summarize_combined(self, combined_summaries: str) -> str:
        """Produce the final summary from the concatenated per-chunk summaries."""
        system_prompt = (
            "You are an expert legal and business meeting summarizer. "
            "You will receive several intermediate summaries of a longer conversation. "
            "Produce a single, comprehensive summary that makes it clear: "
            "- The overall purpose and context of the discussion\n"
            "- The main issues and topics addressed\n"
            "- Key arguments and positions (briefly)\n"
            "- Decisions and outcomes\n"
            "- Action items, responsibilities, and next steps\n"
            "- Any unresolved issues or risks\n\n"
            "The summary should be detailed enough that a reader who was not present "
            "can understand what happened and what is expected going forward. "
            "Use clear, concise language and bullet points where appropriate."
        )

        user_prompt = (
            "Here are the intermediate summaries from different parts of the same conversation:\n\n"
            f"{combined_summaries}"
        )

        return self._chat_completion(system_prompt, user_prompt)

    def _chat_completion(self, system_prompt: str, user_prompt: str) -> str:
        """
        Call OpenAI-compatible /v1/chat/completions endpoint.

        Returns:
            The assistant message content, stripped of surrounding whitespace.

        Raises:
            SummarizerError: On a status >= 400, a non-JSON body, or a response
                that lacks a string at choices[0].message.content.
        """
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.3,
        }

        headers = {
            "Content-Type": "application/json",
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        resp = self._client.post(
            "/v1/chat/completions",
            json=payload,
            headers=headers,
        )

        if resp.status_code >= 400:
            raise SummarizerError(
                f"Summarizer API error {resp.status_code}: {resp.text}"
            )

        try:
            data = resp.json()
        except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
            raise SummarizerError(
                "Failed to parse summarizer response as JSON."
            ) from exc

        # Extract assistant message
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            content = None

        # A null or non-string content would otherwise fail on .strip() with
        # an AttributeError that callers do not expect.
        if not isinstance(content, str):
            raise SummarizerError(
                "Unexpected summarizer response format: "
                f"{json.dumps(data, indent=2)}"
            )

        return content.strip()
|
||||||
+59
-28
@@ -26,17 +26,17 @@ Usage:
|
|||||||
|
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
from whisper import load_model as whisper_load_model
|
from whisper import load_model as whisper_load_model
|
||||||
from whisperx.asr import WhisperModel
|
from whisper.tokenizer import TO_LANGUAGE_CODE
|
||||||
from whisperx import load_model as whisperx_load_model
|
from faster_whisper import WhisperModel as FasterWhisperModel
|
||||||
|
from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
|
||||||
from typing import TypeVar, Union, Optional
|
from typing import TypeVar, Union, Optional
|
||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
from torch.cuda import is_available as cuda_is_available
|
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .misc import WHISPER_DEFAULT_PATH
|
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE, SCRAIBE_NUM_THREADS
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
|
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ class Transcriber:
|
|||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
whisper_type: str = 'whisper',
|
whisper_type: str = 'whisper',
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -145,7 +145,7 @@ class Transcriber:
|
|||||||
- 'large-v3'
|
- 'large-v3'
|
||||||
- 'large'
|
- 'large'
|
||||||
whisper_type (str):
|
whisper_type (str):
|
||||||
Type of whisper model to load. "whisper" or "whisperx".
|
Type of whisper model to load. "whisper" or "faster-whisper".
|
||||||
download_root (str, optional): Path to download the model.
|
download_root (str, optional): Path to download the model.
|
||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
@@ -205,7 +205,7 @@ class WhisperTranscriber(Transcriber):
|
|||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> 'WhisperTranscriber':
|
) -> 'WhisperTranscriber':
|
||||||
@@ -272,7 +272,7 @@ class WhisperTranscriber(Transcriber):
|
|||||||
return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
class WhisperXTranscriber(Transcriber):
|
class FasterWhisperTranscriber(Transcriber):
|
||||||
def __init__(self, model: whisper, model_name: str) -> None:
|
def __init__(self, model: whisper, model_name: str) -> None:
|
||||||
super().__init__(model, model_name)
|
super().__init__(model, model_name)
|
||||||
|
|
||||||
@@ -294,19 +294,19 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
|
|
||||||
if isinstance(audio, Tensor):
|
if isinstance(audio, Tensor):
|
||||||
audio = audio.cpu().numpy()
|
audio = audio.cpu().numpy()
|
||||||
result = self.model.transcribe(audio, *args, **kwargs)
|
result, _ = self.model.transcribe(audio, *args, **kwargs)
|
||||||
text = ""
|
text = ""
|
||||||
for seg in result['segments']:
|
for seg in result:
|
||||||
text += seg['text']
|
text += seg.text
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> 'WhisperXTranscriber':
|
) -> 'FasterWhisperModel':
|
||||||
"""
|
"""
|
||||||
Load whisper model.
|
Load whisper model.
|
||||||
|
|
||||||
@@ -329,7 +329,7 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
|
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
Device to load model on. Defaults to None.
|
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||||
in_memory (bool, optional): Whether to load model in memory.
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
args: Additional arguments only to avoid errors.
|
args: Additional arguments only to avoid errors.
|
||||||
@@ -338,17 +338,18 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
Returns:
|
Returns:
|
||||||
Transcriber: A Transcriber object initialized with the specified model.
|
Transcriber: A Transcriber object initialized with the specified model.
|
||||||
"""
|
"""
|
||||||
if device is None:
|
|
||||||
device = "cuda" if cuda_is_available() else "cpu"
|
|
||||||
if not isinstance(device, str):
|
if not isinstance(device, str):
|
||||||
device = str(device)
|
device = str(device)
|
||||||
|
|
||||||
compute_type = kwargs.get('compute_type', 'float16')
|
compute_type = kwargs.get('compute_type', 'float16')
|
||||||
if device == 'cpu' and compute_type == 'float16':
|
if device == 'cpu' and compute_type == 'float16':
|
||||||
warnings.warn(f'Compute type {compute_type} not compatible with '
|
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||||
f'device {device}! Changing compute type to int8.')
|
f'device {device}! Changing compute type to int8.')
|
||||||
compute_type = 'int8'
|
compute_type = 'int8'
|
||||||
_model = whisperx_load_model(model, download_root=download_root,
|
_model = FasterWhisperModel(model, download_root=download_root,
|
||||||
device=device, compute_type=compute_type)
|
device=device, compute_type=compute_type,
|
||||||
|
cpu_threads=SCRAIBE_NUM_THREADS)
|
||||||
|
|
||||||
return cls(_model, model_name=model)
|
return cls(_model, model_name=model)
|
||||||
|
|
||||||
@@ -361,7 +362,7 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
dict: Keyword arguments for whisper model.
|
dict: Keyword arguments for whisper model.
|
||||||
"""
|
"""
|
||||||
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||||
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
|
_possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys()
|
||||||
|
|
||||||
whisper_kwargs = {k: v for k,
|
whisper_kwargs = {k: v for k,
|
||||||
v in kwargs.items() if k in _possible_kwargs}
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
@@ -370,21 +371,51 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
whisper_kwargs["task"] = task
|
whisper_kwargs["task"] = task
|
||||||
|
|
||||||
if (language := kwargs.get("language")):
|
if (language := kwargs.get("language")):
|
||||||
|
language = FasterWhisperTranscriber.convert_to_language_code(language)
|
||||||
whisper_kwargs["language"] = language
|
whisper_kwargs["language"] = language
|
||||||
|
|
||||||
return whisper_kwargs
|
return whisper_kwargs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_to_language_code(lang : str) -> str:
|
||||||
|
"""
|
||||||
|
Load whisper model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lang (str): language as code or language name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
language (str) code of language
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly
|
||||||
|
if lang in FASTER_WHISPER_LANGUAGE_CODES:
|
||||||
|
return lang
|
||||||
|
|
||||||
|
# Normalize the input to lowercase
|
||||||
|
lang = lang.lower()
|
||||||
|
|
||||||
|
# Check if the language name is in the TO_LANGUAGE_CODE mapping
|
||||||
|
if lang in TO_LANGUAGE_CODE:
|
||||||
|
return TO_LANGUAGE_CODE[lang]
|
||||||
|
|
||||||
|
# If the language is not recognized, raise a ValueError with the available options
|
||||||
|
available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES)
|
||||||
|
raise ValueError(f"Language '{lang}' is not a valid language code or name. "
|
||||||
|
f"Available language codes are: {available_codes}.")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})"
|
return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_transcriber(model: str = "medium",
|
def load_transcriber(model: str = "medium",
|
||||||
whisper_type: str = 'whisper',
|
whisper_type: str = 'whisper',
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> Union[WhisperTranscriber, WhisperXTranscriber]:
|
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||||
"""
|
"""
|
||||||
Load whisper model.
|
Load whisper model.
|
||||||
|
|
||||||
@@ -403,28 +434,28 @@ def load_transcriber(model: str = "medium",
|
|||||||
- 'large-v3'
|
- 'large-v3'
|
||||||
- 'large'
|
- 'large'
|
||||||
whisper_type (str):
|
whisper_type (str):
|
||||||
Type of whisper model to load. "whisper" or "whisperx".
|
Type of whisper model to load. "whisper" or "faster-whisper".
|
||||||
download_root (str, optional): Path to download the model.
|
download_root (str, optional): Path to download the model.
|
||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
Device to load model on. Defaults to None.
|
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
|
||||||
in_memory (bool, optional): Whether to load model in memory.
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
args: Additional arguments only to avoid errors.
|
args: Additional arguments only to avoid errors.
|
||||||
kwargs: Additional keyword arguments only to avoid errors.
|
kwargs: Additional keyword arguments only to avoid errors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[WhisperTranscriber, WhisperXTranscriber]:
|
Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||||
One of the Whisper variants as Transcrbier object initialized with the specified model.
|
One of the Whisper variants as Transcrbier object initialized with the specified model.
|
||||||
"""
|
"""
|
||||||
if whisper_type.lower() == 'whisper':
|
if whisper_type.lower() == 'whisper':
|
||||||
_model = WhisperTranscriber.load_model(
|
_model = WhisperTranscriber.load_model(
|
||||||
model, download_root, device, in_memory, *args, **kwargs)
|
model, download_root, device, in_memory, *args, **kwargs)
|
||||||
return _model
|
return _model
|
||||||
elif whisper_type.lower() == 'whisperx':
|
elif whisper_type.lower() == 'faster-whisper':
|
||||||
_model = WhisperXTranscriber.load_model(
|
_model = FasterWhisperTranscriber.load_model(
|
||||||
model, download_root, device, *args, **kwargs)
|
model, download_root, device, *args, **kwargs)
|
||||||
return _model
|
return _model
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Model type not recognized, exptected "whisper" '
|
raise ValueError(f'Model type not recognized, exptected "whisper" '
|
||||||
f'or "whisperx", got {whisper_type}.')
|
f'or "faster-whisper", got {whisper_type}.')
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import os
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def create_scraibe_instance():
|
def create_scraibe_instance():
|
||||||
if "HF_TOKEN" in os.environ:
|
if "HF_TOKEN" in os.environ:
|
||||||
return Scraibe(use_auth_token=os.environ["HF_TOKEN"])
|
return Scraibe(use_auth_token=os.environ["HF_TOKEN"], whisper_model= "tiny")
|
||||||
else:
|
else:
|
||||||
return Scraibe()
|
return Scraibe()
|
||||||
|
|
||||||
@@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance):
|
|||||||
|
|
||||||
def test_scraibe_autotranscribe(create_scraibe_instance):
|
def test_scraibe_autotranscribe(create_scraibe_instance):
|
||||||
model = create_scraibe_instance
|
model = create_scraibe_instance
|
||||||
transcript = model.autotranscribe('test/audio_test_2.mp4')
|
transcript = model.autotranscribe('tests/audio_test_2.mp4')
|
||||||
assert isinstance(transcript, Transcript)
|
assert isinstance(transcript, Transcript)
|
||||||
|
|
||||||
|
|
||||||
def test_scraibe_diarization(create_scraibe_instance):
|
def test_scraibe_diarization(create_scraibe_instance):
|
||||||
model = create_scraibe_instance
|
model = create_scraibe_instance
|
||||||
diarisation_result = model.diarization('test/audio_test_2.mp4')
|
diarisation_result = model.diarization('tests/audio_test_2.mp4')
|
||||||
assert isinstance(diarisation_result, dict)
|
assert isinstance(diarisation_result, dict)
|
||||||
|
|
||||||
|
|
||||||
def test_scraibe_transcribe(create_scraibe_instance):
|
def test_scraibe_transcribe(create_scraibe_instance):
|
||||||
model = create_scraibe_instance
|
model = create_scraibe_instance
|
||||||
transcription_result = model.transcribe('test/audio_test_2.mp4')
|
transcription_result = model.transcribe('tests/audio_test_2.mp4')
|
||||||
assert isinstance(transcription_result, str)
|
assert isinstance(transcription_result, str)
|
||||||
|
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from scraibe import (Transcriber, WhisperTranscriber,
|
from scraibe import (Transcriber, WhisperTranscriber,
|
||||||
WhisperXTranscriber, load_transcriber)
|
FasterWhisperTranscriber, load_transcriber)
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -31,33 +31,33 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def whisper_instance():
|
def whisper_instance():
|
||||||
return load_transcriber('medium', whisper_type='whisper')
|
return load_transcriber('tiny', whisper_type='whisper')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def whisperx_instance():
|
def faster_whisper_instance():
|
||||||
return load_transcriber('medium', whisper_type='whisperx')
|
return load_transcriber('tiny', whisper_type='faster-whisper')
|
||||||
|
|
||||||
|
|
||||||
def test_whisper_base_initialization(whisper_instance):
|
def test_whisper_base_initialization(whisper_instance):
|
||||||
assert isinstance(whisper_instance, Transcriber)
|
assert isinstance(whisper_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
def test_whisperx_base_initialization(whisperx_instance):
|
def test_faster_whisper_base_initialization(faster_whisper_instance):
|
||||||
assert isinstance(whisperx_instance, Transcriber)
|
assert isinstance(faster_whisper_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
def test_whisper_transcriber_initialization(whisper_instance):
|
def test_whisper_transcriber_initialization(whisper_instance):
|
||||||
assert isinstance(whisper_instance, WhisperTranscriber)
|
assert isinstance(whisper_instance, WhisperTranscriber)
|
||||||
|
|
||||||
|
|
||||||
def test_whisperx_transcriber_initialization(whisperx_instance):
|
def test_faster_whisper_transcriber_initialization(faster_whisper_instance):
|
||||||
assert isinstance(whisperx_instance, WhisperXTranscriber)
|
assert isinstance(faster_whisper_instance, FasterWhisperTranscriber)
|
||||||
|
|
||||||
|
|
||||||
def test_wrong_transcriber_initialization():
|
def test_wrong_transcriber_initialization():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
load_transcriber('medium', whisper_type='wrong_whisper')
|
load_transcriber('tiny', whisper_type='wrong_whisper')
|
||||||
|
|
||||||
|
|
||||||
def test_get_whisper_kwargs():
|
def test_get_whisper_kwargs():
|
||||||
@@ -69,12 +69,12 @@ def test_get_whisper_kwargs():
|
|||||||
def test_whisper_transcribe(whisper_instance):
|
def test_whisper_transcribe(whisper_instance):
|
||||||
model = whisper_instance
|
model = whisper_instance
|
||||||
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
||||||
transcript = model.transcribe('test/audio_test_2.mp4')
|
transcript = model.transcribe('tests/audio_test_2.mp4')
|
||||||
assert isinstance(transcript, str)
|
assert isinstance(transcript, str)
|
||||||
|
|
||||||
|
|
||||||
def test_whisperx_transcribe(whisperx_instance):
|
def test_faster_whisper_transcribe(faster_whisper_instance):
|
||||||
model = whisperx_instance
|
model = faster_whisper_instance
|
||||||
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
||||||
transcript = model.transcribe('test/audio_test_2.mp4')
|
transcript = model.transcribe('tests/audio_test_2.mp4')
|
||||||
assert isinstance(transcript, str)
|
assert isinstance(transcript, str)
|
||||||
Reference in New Issue
Block a user