Merge pull request #127 from JSchmie/develop

Update main to release version 0.3.0
This commit is contained in:
Marko Henning
2024-10-24 10:32:02 +02:00
committed by GitHub
18 changed files with 245 additions and 401 deletions
+95
View File
@@ -0,0 +1,95 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.
name: Publish Docker image
on:
push:
tags:
- v*
workflow_dispatch:
env:
image: hadr0n/scraibe
jobs:
push_to_registry:
name: Push Docker image to Docker Hub
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
security-events: write
steps:
- name: Check out the repo
uses: actions/checkout@v4
with:
fetch-tags: true
fetch-depth: 0
- name: Get Version Tag
id: version
run: |
echo "tag=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT
- name: Overwrite label tag
run: sed -i 's/LABEL version=".*"/LABEL version="'${{ steps.version.outputs.tag }}'"/' Dockerfile
- name: Test name and tag
run: |
echo "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}"
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push Docker image
id: push
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: true
tags: "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}"
- name: SBOM Generation
uses: anchore/sbom-action@v0
with:
image: ${{ env.image }}:latest
- name: Scan image
id: scan
uses: anchore/scan-action@v3
with:
image: ${{ env.image }}:latest
fail-build: false
- name: upload Anchore scan SARIF report
uses: github/codeql-action/upload-sarif@v3
with:
sarif_file: ${{ steps.scan.outputs.sarif }}
# - name: Inspect action SARIF report
# run: cat ${{ steps.scan.outputs.sarif }}
- uses: actions/upload-artifact@v4
with:
name: SARIF report
path: ${{ steps.scan.outputs.sarif }}
# - name: Generate artifact attestation
# uses: actions/attest-build-provenance@v1
# with:
# subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
# subject-digest: ${{ steps.push.outputs.digest }}
# push-to-registry: false
+11 -33
View File
@@ -1,18 +1,14 @@
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
on: on:
pull_request_target:
branches:
- develop
types:
- closed
paths:
- scraibe/**
- pyproject.toml
push: push:
tags: tags:
- 'v*.*.*' - 'v*.*.*'
branches:
- "develop"
paths:
- "scraibe/**"
- "pyproject.toml"
workflow_dispatch: workflow_dispatch:
inputs: inputs:
@@ -27,13 +23,7 @@ on:
jobs: jobs:
Build-and-publish-to-Test-PyPI: Build-and-publish-to-Test-PyPI:
if: | if: github.event_name != 'workflow_dispatch' || github.event.inputs.test == 'true'
(github.event_name == 'workflow_dispatch' &&
github.event.inputs.test == 'true') ||
(github.event_name == 'pull_request_target' &&
github.event.pull_request.merged &&
contains(github.event.pull_request.labels.*.name, 'release')) ||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -72,28 +62,16 @@ jobs:
needs: Test-PyPi-install needs: Test-PyPi-install
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: | if: |
always() && always() &&
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' && (( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
needs.Test-PyPi-install.result != 'failure' ) && needs.Test-PyPi-install.result != 'failure' ) ||
((github.event_name == 'workflow_dispatch' && ((github.event_name == 'workflow_dispatch' &&
github.event.inputs.publish_to_pypi == 'true') || github.event.inputs.publish_to_pypi == 'true')))
(github.event_name == 'pull_request_target' &&
github.event.pull_request.merged &&
contains(github.event.pull_request.labels.*.name, 'release')) ||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))))
steps: steps:
- name: Checkout Repository Tags
uses: actions/checkout@v4
if: github.ref == 'refs/heads/main'
with:
fetch-depth: '0'
branch: 'main'
- name: Checkout Repository (Develop) - name: Checkout Repository (Develop)
uses: actions/checkout@v4 uses: actions/checkout@v4
if: github.ref == 'refs/heads/develop'
with: with:
fetch-depth: '0' fetch-depth: '0'
branch: 'develop'
- name: Set up Poetry 📦 - name: Set up Poetry 📦
uses: JRubics/poetry-publish@v1.16 uses: JRubics/poetry-publish@v1.16
with: with:
+18 -20
View File
@@ -1,5 +1,5 @@
#pytorch Image #pytorch Image
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
# Labels # Labels
@@ -14,33 +14,31 @@ LABEL url="https://github.com/JSchmie/ScrAIbe"
# Install dependencies # Install dependencies
WORKDIR /app WORKDIR /app
ARG model_name=medium #Enviorment dependencies
#Enviorment Dependncies ENV TRANSFORMERS_CACHE=/app/models
ENV TRANSFORMERS_CACHE /app/models ENV HF_HOME=/app/models
ENV HF_HOME /app/models ENV AUTOT_CACHE=/app/models
ENV AUTOT_CACHE /app/models ENV PYANNOTE_CACHE=/app/models/pyannote
ENV PYANNOTE_CACHE /app/models/pyannote
#Copy all necessary files #Copy all necessary files
COPY requirements.txt /app/requirements.txt COPY requirements.txt /app/requirements.txt
COPY README.md /app/README.md COPY README.md /app/README.md
COPY models /app/models
COPY scraibe /app/scraibe COPY scraibe /app/scraibe
COPY setup.py /app/setup.py
#Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token #Installing all necessary dependencies and running the application with a personalised Hugging-Face-Token
RUN apt update && apt-get install -y libsm6 libxrender1 libfontconfig1 RUN apt update -y && apt upgrade -y && \
RUN conda update --all apt install -y libsm6 libxrender1 libfontconfig1 && \
apt clean && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
RUN conda install pip RUN conda update --all && \
RUN conda install -y ffmpeg # conda install -y pip ffmpeg && \
RUN conda install -c conda-forge libsndfile conda install -c conda-forge libsndfile && \
RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html conda clean --all -y
RUN pip install -r requirements.txt # RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install markupsafe==2.0.1 --force-reinstall RUN pip install --no-cache-dir -r requirements.txt
RUN python3 -m 'scraibe.cli' --whisper-model-name $model_name
# Expose port # Expose port
EXPOSE 7860 EXPOSE 7860
# Run the application # Run the application
ENTRYPOINT ["python3", "-m", "scraibe.cli" ,"--whisper-model-name", "$model_name"] ENTRYPOINT ["python3", "-m", "scraibe.cli"]
-256
View File
@@ -1,256 +0,0 @@
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotlipy=0.7.0=py39h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.05.30=h06a4308_0
- certifi=2023.5.7=py39h06a4308_0
- cffi=1.15.1=py39h5eee18b_3
- cryptography=39.0.1=py39h9ce1e76_2
- cudatoolkit=11.3.1=h2bc3f7f_2
- ffmpeg=4.2.2=h20bf706_0
- flit-core=3.8.0=py39h06a4308_0
- freetype=2.12.1=h4a9f257_0
- giflib=5.2.1=h5eee18b_3
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- idna=3.4=py39h06a4308_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h5eee18b_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_0
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libidn2=2.3.2=h7f8727e_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.5.0=h6a678d5_2
- libunistring=0.9.10=h27cfd23_0
- libuv=1.44.2=h5eee18b_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.2.4=h11a3e52_1
- libwebp-base=1.2.4=h5eee18b_1
- lz4-c=1.9.4=h6a678d5_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- numpy=1.23.5=py39h14f4228_0
- numpy-base=1.23.5=py39h31eccc5_0
- openh264=2.1.1=h4ff587b_0
- openssl=3.0.9=h7f8727e_0
- pillow=9.4.0=py39h6a678d5_0
- pip=23.0.1=py39h06a4308_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=23.0.0=py39h06a4308_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.16=h955ad1f_3
- pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- readline=8.2=h5eee18b_0
- requests=2.28.1=py39h06a4308_1
- setuptools=65.6.3=py39h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.11.0=py39_cu113
- torchvision=0.12.0=py39_cu113
- tzdata=2023c=h04d1e81_0
- wheel=0.38.4=py39h06a4308_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.4=hc292b87_0
- pip:
- absl-py==1.3.0
- aiofiles==23.1.0
- aiohttp==3.8.3
- aiosignal==1.3.1
- alembic==1.9.1
- altair==5.0.1
- annotated-types==0.5.0
- ansi2html==1.8.0
- antlr4-python3-runtime==4.9.3
- anyio==3.7.1
- appdirs==1.4.4
- asteroid-filterbanks==0.4.0
- async-timeout==4.0.2
- attrs==22.2.0
- audioread==3.0.0
- autopage==0.5.1
- backports-cached-property==1.0.2
- cachetools==5.2.0
- charset-normalizer==2.1.1
- click==8.1.3
- cliff==4.1.0
- cmaes==0.9.0
- cmake==3.26.4
- cmd2==2.4.2
- colorama==0.4.6
- colorlog==6.7.0
- commonmark==0.9.1
- contourpy==1.0.6
- cycler==0.11.0
- dash==2.12.1
- dash-core-components==2.0.0
- dash-html-components==2.0.0
- dash-table==5.0.0
- decorator==4.4.2
- docopt==0.6.2
- einops==0.3.2
- exceptiongroup==1.1.1
- fastapi==0.100.0
- ffmpeg-python==0.2.0
- ffmpy==0.3.0
- filelock==3.8.0
- flask==2.2.5
- fonttools==4.38.0
- frozenlist==1.3.3
- fsspec==2022.11.0
- future==0.18.2
- google-auth==2.15.0
- google-auth-oauthlib==0.4.6
- gradio==3.36.1
- gradio-client==0.2.7
- greenlet==2.0.1
- grpcio==1.51.1
- h11==0.14.0
- hmmlearn==0.2.8
- httpcore==0.17.3
- httpx==0.24.1
- huggingface-hub==0.16.4
- humanize==4.7.0
- hyperpyyaml==1.1.0
- imageio==2.23.0
- imageio-ffmpeg==0.4.7
- importlib-metadata==4.13.0
- importlib-resources==5.12.0
- iniconfig==2.0.0
- itsdangerous==2.1.2
- jinja2==3.1.2
- joblib==1.2.0
- jsonschema==4.18.0
- jsonschema-specifications==2023.6.1
- julius==0.2.7
- kiwisolver==1.4.4
- librosa==0.9.2
- linkify-it-py==2.0.2
- lit==16.0.5.post0
- llvmlite==0.39.1
- mako==1.2.4
- markdown==3.4.1
- markdown-it-py==2.2.0
- markupsafe==2.1.1
- matplotlib==3.7.1
- mdit-py-plugins==0.3.3
- mdurl==0.1.2
- more-itertools==9.0.0
- moviepy==1.0.3
- mpmath==1.2.1
- multidict==6.0.4
- nest-asyncio==1.5.7
- networkx==2.8.8
- numba==0.56.4
- oauthlib==3.2.2
- omegaconf==2.3.0
- openai-whisper==20230314
- optuna==3.0.5
- orjson==3.9.2
- packaging==21.3
- pandas==1.5.2
- pbr==5.11.0
- plotly==5.15.0
- pluggy==1.0.0
- pooch==1.6.0
- prettytable==3.5.0
- primepy==1.3
- proglog==0.1.10
- protobuf==3.20.1
- pyannote-audio==2.1.1
- pyannote-core==4.5
- pyannote-database==4.1.3
- pyannote-metrics==3.2.1
- pyannote-pipeline==2.3
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydantic==2.0.2
- pydantic-core==2.1.2
- pydeprecate==0.3.2
- pydub==0.25.1
- pygments==2.13.0
- pyparsing==3.0.9
- pyperclip==1.8.2
- pytest==7.3.1
- python-dateutil==2.8.2
- python-multipart==0.0.6
- pytorch-lightning==1.6.5
- pytorch-metric-learning==1.6.3
- pytz==2022.7
- pyyaml==6.0
- qtfaststart==1.8
- referencing==0.29.1
- regex==2022.10.31
- requests-oauthlib==1.3.1
- resampy==0.4.2
- retrying==1.3.4
- rich==12.6.0
- rpds-py==0.8.10
- rsa==4.9
- ruamel-yaml==0.17.21
- ruamel-yaml-clib==0.2.7
- ruff==0.0.272
- scikit-learn==1.2.0
- scipy==1.8.1
- semantic-version==2.10.0
- semver==2.13.0
- sentencepiece==0.1.97
- setuptools-rust==1.5.2
- shellingham==1.5.0
- simplejson==3.18.0
- singledispatchmethod==1.0
- sniffio==1.3.0
- sortedcontainers==2.4.0
- soundfile==0.10.3.post1
- speechbrain==0.5.14
- sqlalchemy==1.4.45
- starlette==0.27.0
- stevedore==4.1.1
- sympy==1.11.1
- tabulate==0.9.0
- tenacity==8.2.2
- tensorboard==2.11.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- threadpoolctl==3.1.0
- tiktoken==0.3.1
- tokenizers==0.13.2
- tomli==2.0.1
- toolz==0.12.0
- torch-audiomentations==0.11.0
- torch-pitch-shift==1.2.2
- torchmetrics==0.11.0
- tqdm==4.64.1
- transformers==4.24.0
- triton==2.0.0
- typer==0.7.0
- typing-extensions==4.7.1
- uc-micro-py==1.0.2
- urllib3==1.26.12
- uvicorn==0.22.0
- wcwidth==0.2.5
- websockets==11.0.3
- werkzeug==2.2.2
- yarl==1.8.2
- zipp==3.11.0
+6 -6
View File
@@ -31,12 +31,12 @@ exclude =[
] ]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9" python = "^3.9"
tqdm = "^4.66.4" tqdm = "^4.66.5"
numpy = "^1.26.4" numpy = "^1.26.4"
openai-whisper = "^20231117" openai-whisper = ">=20231117,<20240931"
whisperx = "^3.1.3" faster-whisper = "^1.0.3"
"pyannote.audio" = "^3.1.1" "pyannote.audio" = "^3.3.1"
torch = "^2.3.0" torch = "^2.1.2"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pytest = "^8.1.1" pytest = "^8.1.1"
@@ -57,7 +57,7 @@ format-jinja = """
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
sphinx = "^7.3.7" sphinx = "^7.3.7"
sphinx-rtd-theme = "^2.0.0" sphinx-rtd-theme = ">=2,<4"
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]} markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
myst-parser = "^3.0.1" myst-parser = "^3.0.1"
mdit-py-plugins = "^0.4.1" mdit-py-plugins = "^0.4.1"
+4 -4
View File
@@ -1,14 +1,14 @@
tqdm>=4.65.0 tqdm>=4.66.5
numpy>=1.26.4 numpy>=1.26.4
openai-whisper==20231117 openai-whisper==20231117
whisperx~=3.1.3 faster-whisper~=1.0.3
pyannote.audio~=3.1.1 pyannote.audio~=3.3.1
pyannote.core~=5.0.0 pyannote.core~=5.0.0
pyannote.database~=5.0.1 pyannote.database~=5.0.1
pyannote.metrics~=3.2.1 pyannote.metrics~=3.2.1
pyannote.pipeline~=3.0.1 pyannote.pipeline~=3.0.1
torch>=2.0.0 torchaudio>=2.1.2
+3 -9
View File
@@ -41,26 +41,20 @@ class AudioProcessor:
The sample rate of the audio. The sample rate of the audio.
""" """
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor,
*args, **kwargs) -> None: sr: int = SAMPLE_RATE) -> None:
""" """
Initialize the AudioProcessor object. Initialize the AudioProcessor object.
Args: Args:
waveform (torch.Tensor): The audio waveform tensor. waveform (torch.Tensor): The audio waveform tensor.
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE. sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
args: Additional arguments.
kwargs: Additional keyword arguments, e.g., device to use for processing.
If CUDA is available, it defaults to CUDA.
Raises: Raises:
ValueError: If the provided sample rate is not of type int. ValueError: If the provided sample rate is not of type int.
""" """
device = kwargs.get( self.waveform = waveform
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
if not isinstance(self.sr, int): if not isinstance(self.sr, int):
+8 -7
View File
@@ -40,6 +40,7 @@ from .audio import AudioProcessor
from .diarisation import Diariser from .diarisation import Diariser
from .transcriber import Transcriber, load_transcriber, whisper from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript from .transcript_exporter import Transcript
from .misc import SCRAIBE_TORCH_DEVICE
DiarisationType = TypeVar('DiarisationType') DiarisationType = TypeVar('DiarisationType')
@@ -74,7 +75,7 @@ class Scraibe:
whisper_model (Union[bool, str, whisper], optional): whisper_model (Union[bool, str, whisper], optional):
Path to whisper model or whisper model itself. Path to whisper model or whisper model itself.
whisper_type (str): whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx". Type of whisper model to load. "whisper" or "faster-whisper".
diarisation_model (Union[bool, str, DiarisationType], optional): diarisation_model (Union[bool, str, DiarisationType], optional):
Path to pyannote diarization model or model itself. Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper **kwargs: Additional keyword arguments for whisper
@@ -116,6 +117,9 @@ class Scraibe:
else: else:
self.params = {} self.params = {}
self.device = kwargs.get(
"device", SCRAIBE_TORCH_DEVICE)
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False, remove_original: bool = False,
**kwargs) -> Transcript: **kwargs) -> Transcript:
@@ -141,7 +145,7 @@ class Scraibe:
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
@@ -165,8 +169,6 @@ class Scraibe:
if self.verbose: if self.verbose:
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results # Transcribe each segment and store the results
final_transcript = dict() final_transcript = dict()
@@ -213,7 +215,7 @@ class Scraibe:
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
@@ -323,8 +325,7 @@ class Scraibe:
print(f"Audiofile {audio_file} removed.") print(f"Audiofile {audio_file} removed.")
@staticmethod @staticmethod
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
*args, **kwargs) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor. """Gets an audio file as TorchAudioProcessor.
Args: Args:
+11 -4
View File
@@ -36,8 +36,8 @@ def cli():
help="List of audio files to transcribe.") help="List of audio files to transcribe.")
parser.add_argument("--whisper-type", type=str, default="whisper", parser.add_argument("--whisper-type", type=str, default="whisper",
choices=["whisper", "whisperx"], choices=["whisper", "faster-whisper"],
help="Type of Whisper model to use ('whisper' or 'whisperx').") help="Type of Whisper model to use ('whisper' or 'faster-whisper').")
parser.add_argument("--whisper-model-name", default="medium", parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.") help="Name of the Whisper model to use.")
@@ -79,6 +79,8 @@ def cli():
choices=sorted( choices=sorted(
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.") help="Language spoken in the audio. Specify None to perform language detection.")
parser.add_argument("--num-speakers", type=int, default=2,
help="Number of speakers in the audio.")
args = parser.parse_args() args = parser.parse_args()
@@ -117,8 +119,13 @@ def cli():
else: else:
task = "transcribe" task = "transcribe"
out = model.autotranscribe(audio, task=task, language=arg_dict.pop( out = model.autotranscribe(
"language"), verbose=arg_dict.pop("verbose_output")) audio,
task=task,
language=arg_dict.pop("language"),
verbose=arg_dict.pop("verbose_output"),
num_speakers=arg_dict.pop("num_speakers")
)
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}') print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join( out.save(os.path.join(
+3 -8
View File
@@ -37,11 +37,11 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from torch import device as torch_device from torch import device as torch_device
from torch.cuda import is_available
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
@@ -190,8 +190,7 @@ class Diariser:
cache_token: bool = False, cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None, hparams_file: Union[str, Path] = None,
device: str = None, device: str = SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> Pipeline: ) -> Pipeline:
""" """
Loads a pretrained model from pyannote.audio, Loads a pretrained model from pyannote.audio,
@@ -283,10 +282,6 @@ class Diariser:
'or from huggingface.co models. Please check your token' 'or from huggingface.co models. Please check your token'
'or your local model path') 'or your local model path')
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
# torch_device is renamed from torch.device to avoid name conflict # torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device)) _model = _model.to(torch_device(device))
+7 -5
View File
@@ -1,23 +1,25 @@
import os import os
import yaml import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action from argparse import Action
from ast import literal_eval from ast import literal_eval
from torch.cuda import is_available
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
os.path.expanduser("~/.cache/torch/models"), os.path.expanduser("~/.cache/torch/models"),
) )
os.environ["PYANNOTE_CACHE"] = os.getenv(
if CACHE_DIR != PYANNOTE_CACHE_DIR: "PYANNOTE_CACHE",
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote") os.path.join(CACHE_DIR, "pyannote"),
)
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.
+58 -28
View File
@@ -26,17 +26,17 @@ Usage:
from whisper import Whisper from whisper import Whisper
from whisper import load_model as whisper_load_model from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel from whisper.tokenizer import TO_LANGUAGE_CODE
from whisperx import load_model as whisperx_load_model from faster_whisper import WhisperModel as FasterWhisperModel
from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
from typing import TypeVar, Union, Optional from typing import TypeVar, Union, Optional
from torch import Tensor, device from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray from numpy import ndarray
from inspect import signature from inspect import signature
from abc import abstractmethod from abc import abstractmethod
import warnings import warnings
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
@@ -123,7 +123,7 @@ class Transcriber:
model: str = "medium", model: str = "medium",
whisper_type: str = 'whisper', whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> None: ) -> None:
@@ -145,7 +145,7 @@ class Transcriber:
- 'large-v3' - 'large-v3'
- 'large' - 'large'
whisper_type (str): whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx". Type of whisper model to load. "whisper" or "faster-whisper".
download_root (str, optional): Path to download the model. download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH. Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional): device (Optional[Union[str, torch.device]], optional):
@@ -205,7 +205,7 @@ class WhisperTranscriber(Transcriber):
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> 'WhisperTranscriber': ) -> 'WhisperTranscriber':
@@ -272,7 +272,7 @@ class WhisperTranscriber(Transcriber):
return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
class WhisperXTranscriber(Transcriber): class FasterWhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None: def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name) super().__init__(model, model_name)
@@ -294,19 +294,19 @@ class WhisperXTranscriber(Transcriber):
if isinstance(audio, Tensor): if isinstance(audio, Tensor):
audio = audio.cpu().numpy() audio = audio.cpu().numpy()
result = self.model.transcribe(audio, *args, **kwargs) result, _ = self.model.transcribe(audio, *args, **kwargs)
text = "" text = ""
for seg in result['segments']: for seg in result:
text += seg['text'] text += seg.text
return text return text
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
*args, **kwargs *args, **kwargs
) -> 'WhisperXTranscriber': ) -> 'FasterWhisperModel':
""" """
Load whisper model. Load whisper model.
@@ -329,7 +329,7 @@ class WhisperXTranscriber(Transcriber):
Defaults to WHISPER_DEFAULT_PATH. Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional): device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None. Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory. in_memory (bool, optional): Whether to load model in memory.
Defaults to False. Defaults to False.
args: Additional arguments only to avoid errors. args: Additional arguments only to avoid errors.
@@ -338,17 +338,17 @@ class WhisperXTranscriber(Transcriber):
Returns: Returns:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.
""" """
if device is None:
device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str): if not isinstance(device, str):
device = str(device) device = str(device)
compute_type = kwargs.get('compute_type', 'float16') compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16': if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with ' warnings.warn(f'Compute type {compute_type} not compatible with '
f'device {device}! Changing compute type to int8.') f'device {device}! Changing compute type to int8.')
compute_type = 'int8' compute_type = 'int8'
_model = whisperx_load_model(model, download_root=download_root, _model = FasterWhisperModel(model, download_root=download_root,
device=device, compute_type=compute_type) device=device, compute_type=compute_type)
return cls(_model, model_name=model) return cls(_model, model_name=model)
@@ -361,7 +361,7 @@ class WhisperXTranscriber(Transcriber):
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() _possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys()
whisper_kwargs = {k: v for k, whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs} v in kwargs.items() if k in _possible_kwargs}
@@ -370,21 +370,51 @@ class WhisperXTranscriber(Transcriber):
whisper_kwargs["task"] = task whisper_kwargs["task"] = task
if (language := kwargs.get("language")): if (language := kwargs.get("language")):
language = FasterWhisperTranscriber.convert_to_language_code(language)
whisper_kwargs["language"] = language whisper_kwargs["language"] = language
return whisper_kwargs return whisper_kwargs
@staticmethod
def convert_to_language_code(lang : str) -> str:
"""
Load whisper model.
Args:
lang (str): language as code or language name
Returns:
language (str) code of language
"""
# If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly
if lang in FASTER_WHISPER_LANGUAGE_CODES:
return lang
# Normalize the input to lowercase
lang = lang.lower()
# Check if the language name is in the TO_LANGUAGE_CODE mapping
if lang in TO_LANGUAGE_CODE:
return TO_LANGUAGE_CODE[lang]
# If the language is not recognized, raise a ValueError with the available options
available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES)
raise ValueError(f"Language '{lang}' is not a valid language code or name. "
f"Available language codes are: {available_codes}.")
def __repr__(self) -> str: def __repr__(self) -> str:
return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})"
def load_transcriber(model: str = "medium", def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper', whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> Union[WhisperTranscriber, WhisperXTranscriber]: ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
""" """
Load whisper model. Load whisper model.
@@ -403,28 +433,28 @@ def load_transcriber(model: str = "medium",
- 'large-v3' - 'large-v3'
- 'large' - 'large'
whisper_type (str): whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx". Type of whisper model to load. "whisper" or "faster-whisper".
download_root (str, optional): Path to download the model. download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH. Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional): device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None. Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory. in_memory (bool, optional): Whether to load model in memory.
Defaults to False. Defaults to False.
args: Additional arguments only to avoid errors. args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors.
Returns: Returns:
Union[WhisperTranscriber, WhisperXTranscriber]: Union[WhisperTranscriber, FasterWhisperTranscriber]:
One of the Whisper variants as Transcrbier object initialized with the specified model. One of the Whisper variants as Transcrbier object initialized with the specified model.
""" """
if whisper_type.lower() == 'whisper': if whisper_type.lower() == 'whisper':
_model = WhisperTranscriber.load_model( _model = WhisperTranscriber.load_model(
model, download_root, device, in_memory, *args, **kwargs) model, download_root, device, in_memory, *args, **kwargs)
return _model return _model
elif whisper_type.lower() == 'whisperx': elif whisper_type.lower() == 'faster-whisper':
_model = WhisperXTranscriber.load_model( _model = FasterWhisperTranscriber.load_model(
model, download_root, device, *args, **kwargs) model, download_root, device, *args, **kwargs)
return _model return _model
else: else:
raise ValueError(f'Model type not recognized, exptected "whisper" ' raise ValueError(f'Model type not recognized, exptected "whisper" '
f'or "whisperx", got {whisper_type}.') f'or "faster-whisper", got {whisper_type}.')
@@ -6,7 +6,7 @@ import os
@pytest.fixture @pytest.fixture
def create_scraibe_instance(): def create_scraibe_instance():
if "HF_TOKEN" in os.environ: if "HF_TOKEN" in os.environ:
return Scraibe(use_auth_token=os.environ["HF_TOKEN"]) return Scraibe(use_auth_token=os.environ["HF_TOKEN"], whisper_model= "tiny")
else: else:
return Scraibe() return Scraibe()
@@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance):
def test_scraibe_autotranscribe(create_scraibe_instance): def test_scraibe_autotranscribe(create_scraibe_instance):
model = create_scraibe_instance model = create_scraibe_instance
transcript = model.autotranscribe('test/audio_test_2.mp4') transcript = model.autotranscribe('tests/audio_test_2.mp4')
assert isinstance(transcript, Transcript) assert isinstance(transcript, Transcript)
def test_scraibe_diarization(create_scraibe_instance): def test_scraibe_diarization(create_scraibe_instance):
model = create_scraibe_instance model = create_scraibe_instance
diarisation_result = model.diarization('test/audio_test_2.mp4') diarisation_result = model.diarization('tests/audio_test_2.mp4')
assert isinstance(diarisation_result, dict) assert isinstance(diarisation_result, dict)
def test_scraibe_transcribe(create_scraibe_instance): def test_scraibe_transcribe(create_scraibe_instance):
model = create_scraibe_instance model = create_scraibe_instance
transcription_result = model.transcribe('test/audio_test_2.mp4') transcription_result = model.transcribe('tests/audio_test_2.mp4')
assert isinstance(transcription_result, str) assert isinstance(transcription_result, str)
@@ -1,6 +1,6 @@
import pytest import pytest
from scraibe import (Transcriber, WhisperTranscriber, from scraibe import (Transcriber, WhisperTranscriber,
WhisperXTranscriber, load_transcriber) FasterWhisperTranscriber, load_transcriber)
import torch import torch
@@ -31,33 +31,33 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
@pytest.fixture @pytest.fixture
def whisper_instance(): def whisper_instance():
return load_transcriber('medium', whisper_type='whisper') return load_transcriber('tiny', whisper_type='whisper')
@pytest.fixture @pytest.fixture
def whisperx_instance(): def faster_whisper_instance():
return load_transcriber('medium', whisper_type='whisperx') return load_transcriber('tiny', whisper_type='faster-whisper')
def test_whisper_base_initialization(whisper_instance): def test_whisper_base_initialization(whisper_instance):
assert isinstance(whisper_instance, Transcriber) assert isinstance(whisper_instance, Transcriber)
def test_whisperx_base_initialization(whisperx_instance): def test_faster_whisper_base_initialization(faster_whisper_instance):
assert isinstance(whisperx_instance, Transcriber) assert isinstance(faster_whisper_instance, Transcriber)
def test_whisper_transcriber_initialization(whisper_instance): def test_whisper_transcriber_initialization(whisper_instance):
assert isinstance(whisper_instance, WhisperTranscriber) assert isinstance(whisper_instance, WhisperTranscriber)
def test_whisperx_transcriber_initialization(whisperx_instance): def test_faster_whisper_transcriber_initialization(faster_whisper_instance):
assert isinstance(whisperx_instance, WhisperXTranscriber) assert isinstance(faster_whisper_instance, FasterWhisperTranscriber)
def test_wrong_transcriber_initialization(): def test_wrong_transcriber_initialization():
with pytest.raises(ValueError): with pytest.raises(ValueError):
load_transcriber('medium', whisper_type='wrong_whisper') load_transcriber('tiny', whisper_type='wrong_whisper')
def test_get_whisper_kwargs(): def test_get_whisper_kwargs():
@@ -69,12 +69,12 @@ def test_get_whisper_kwargs():
def test_whisper_transcribe(whisper_instance): def test_whisper_transcribe(whisper_instance):
model = whisper_instance model = whisper_instance
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4') transcript = model.transcribe('tests/audio_test_2.mp4')
assert isinstance(transcript, str) assert isinstance(transcript, str)
def test_whisperx_transcribe(whisperx_instance): def test_faster_whisper_transcribe(faster_whisper_instance):
model = whisperx_instance model = faster_whisper_instance
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4') transcript = model.transcribe('tests/audio_test_2.mp4')
assert isinstance(transcript, str) assert isinstance(transcript, str)