Merge pull request #127 from JSchmie/develop

Update main to release version 0.3.0
This commit is contained in:
Marko Henning
2024-10-24 10:32:02 +02:00
committed by GitHub
18 changed files with 245 additions and 401 deletions
+95
View File
@@ -0,0 +1,95 @@
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
# GitHub recommends pinning actions to a commit SHA.
# To get a newer version, you will need to update the SHA.
# You can also reference a tag or branch, but the action may change without warning.
# Builds the Docker image from ./Dockerfile, pushes it to Docker Hub under
# both `latest` and the most recent git tag, then generates an SBOM and a
# vulnerability scan (SARIF) for the pushed image.
name: Publish Docker image

on:
  push:
    tags:
      - 'v*'
  workflow_dispatch:

env:
  image: hadr0n/scraibe

jobs:
  push_to_registry:
    name: Push Docker image to Docker Hub
    runs-on: ubuntu-latest
    permissions:
      packages: write
      contents: read
      # Required by github/codeql-action/upload-sarif.
      security-events: write
    steps:
      - name: Check out the repo
        uses: actions/checkout@v4
        with:
          # Full history plus tags so `git describe` can find the latest tag.
          fetch-tags: true
          fetch-depth: 0

      - name: Get Version Tag
        id: version
        run: |
          echo "tag=$(git describe --tags --abbrev=0)" >> "$GITHUB_OUTPUT"

      - name: Overwrite label tag
        # The tag is passed through an environment variable instead of being
        # interpolated into the script text, so its content is never parsed
        # as shell syntax.
        env:
          VERSION_TAG: ${{ steps.version.outputs.tag }}
        run: sed -i 's/LABEL version=".*"/LABEL version="'"$VERSION_TAG"'"/' Dockerfile

      - name: Test name and tag
        env:
          VERSION_TAG: ${{ steps.version.outputs.tag }}
        run: |
          echo "${image}:latest,${image}:${VERSION_TAG}"

      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build and push Docker image
        id: push
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          push: true
          tags: "${{ env.image }}:latest,${{ env.image }}:${{ steps.version.outputs.tag }}"

      - name: SBOM Generation
        uses: anchore/sbom-action@v0
        with:
          image: ${{ env.image }}:latest

      - name: Scan image
        id: scan
        uses: anchore/scan-action@v3
        with:
          image: ${{ env.image }}:latest
          # Report findings without failing the release.
          fail-build: false

      - name: upload Anchore scan SARIF report
        uses: github/codeql-action/upload-sarif@v3
        with:
          sarif_file: ${{ steps.scan.outputs.sarif }}

      # - name: Inspect action SARIF report
      #   run: cat ${{ steps.scan.outputs.sarif }}

      - uses: actions/upload-artifact@v4
        with:
          name: SARIF report
          path: ${{ steps.scan.outputs.sarif }}

      # NOTE: env.REGISTRY and env.IMAGE_NAME are not defined in this
      # workflow (only env.image is); define them before enabling this step.
      # - name: Generate artifact attestation
      #   uses: actions/attest-build-provenance@v1
      #   with:
      #     subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
      #     subject-digest: ${{ steps.push.outputs.digest }}
      #     push-to-registry: false
+11 -33
View File
@@ -1,18 +1,14 @@
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
on:
pull_request_target:
branches:
- develop
types:
- closed
paths:
- scraibe/**
- pyproject.toml
push:
tags:
- 'v*.*.*'
branches:
- "develop"
paths:
- "scraibe/**"
- "pyproject.toml"
workflow_dispatch:
inputs:
@@ -27,13 +23,7 @@ on:
jobs:
Build-and-publish-to-Test-PyPI:
if: |
(github.event_name == 'workflow_dispatch' &&
github.event.inputs.test == 'true') ||
(github.event_name == 'pull_request_target' &&
github.event.pull_request.merged &&
contains(github.event.pull_request.labels.*.name, 'release')) ||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))
if: github.event_name != 'workflow_dispatch' || github.event.inputs.test == 'true'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
@@ -72,28 +62,16 @@ jobs:
needs: Test-PyPi-install
runs-on: ubuntu-latest
if: |
always() &&
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
needs.Test-PyPi-install.result != 'failure' ) &&
((github.event_name == 'workflow_dispatch' &&
github.event.inputs.publish_to_pypi == 'true') ||
(github.event_name == 'pull_request_target' &&
github.event.pull_request.merged &&
contains(github.event.pull_request.labels.*.name, 'release')) ||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))))
always() &&
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
needs.Test-PyPi-install.result != 'failure' ) ||
((github.event_name == 'workflow_dispatch' &&
github.event.inputs.publish_to_pypi == 'true')))
steps:
- name: Checkout Repository Tags
uses: actions/checkout@v4
if: github.ref == 'refs/heads/main'
with:
fetch-depth: '0'
branch: 'main'
- name: Checkout Repository (Develop)
uses: actions/checkout@v4
if: github.ref == 'refs/heads/develop'
with:
fetch-depth: '0'
branch: 'develop'
- name: Set up Poetry 📦
uses: JRubics/poetry-publish@v1.16
with:
+18 -20
View File
@@ -1,5 +1,5 @@
#pytorch Image
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
# Labels
@@ -14,33 +14,31 @@ LABEL url="https://github.com/JSchmie/ScrAIbe"
# Install dependencies
WORKDIR /app
ARG model_name=medium
#Enviorment Dependncies
ENV TRANSFORMERS_CACHE /app/models
ENV HF_HOME /app/models
ENV AUTOT_CACHE /app/models
ENV PYANNOTE_CACHE /app/models/pyannote
#Environment dependencies
ENV TRANSFORMERS_CACHE=/app/models
ENV HF_HOME=/app/models
ENV AUTOT_CACHE=/app/models
ENV PYANNOTE_CACHE=/app/models/pyannote
#Copy all necessary files
COPY requirements.txt /app/requirements.txt
COPY README.md /app/README.md
COPY models /app/models
COPY scraibe /app/scraibe
COPY setup.py /app/setup.py
#Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token
RUN apt update && apt-get install -y libsm6 libxrender1 libfontconfig1
RUN conda update --all
#Installing all necessary dependencies and running the application with a personalised Hugging-Face-Token
RUN apt update -y && apt upgrade -y && \
apt install -y libsm6 libxrender1 libfontconfig1 && \
apt clean && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
RUN conda install pip
RUN conda install -y ffmpeg
RUN conda install -c conda-forge libsndfile
RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install -r requirements.txt
RUN pip install markupsafe==2.0.1 --force-reinstall
RUN conda update --all && \
# conda install -y pip ffmpeg && \
conda install -c conda-forge libsndfile && \
conda clean --all -y
# RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install --no-cache-dir -r requirements.txt
RUN python3 -m 'scraibe.cli' --whisper-model-name $model_name
# Expose port
EXPOSE 7860
# Run the application
ENTRYPOINT ["python3", "-m", "scraibe.cli" ,"--whisper-model-name", "$model_name"]
ENTRYPOINT ["python3", "-m", "scraibe.cli"]
-256
View File
@@ -1,256 +0,0 @@
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotlipy=0.7.0=py39h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.05.30=h06a4308_0
- certifi=2023.5.7=py39h06a4308_0
- cffi=1.15.1=py39h5eee18b_3
- cryptography=39.0.1=py39h9ce1e76_2
- cudatoolkit=11.3.1=h2bc3f7f_2
- ffmpeg=4.2.2=h20bf706_0
- flit-core=3.8.0=py39h06a4308_0
- freetype=2.12.1=h4a9f257_0
- giflib=5.2.1=h5eee18b_3
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- idna=3.4=py39h06a4308_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h5eee18b_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_0
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libidn2=2.3.2=h7f8727e_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.5.0=h6a678d5_2
- libunistring=0.9.10=h27cfd23_0
- libuv=1.44.2=h5eee18b_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.2.4=h11a3e52_1
- libwebp-base=1.2.4=h5eee18b_1
- lz4-c=1.9.4=h6a678d5_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- numpy=1.23.5=py39h14f4228_0
- numpy-base=1.23.5=py39h31eccc5_0
- openh264=2.1.1=h4ff587b_0
- openssl=3.0.9=h7f8727e_0
- pillow=9.4.0=py39h6a678d5_0
- pip=23.0.1=py39h06a4308_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=23.0.0=py39h06a4308_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.16=h955ad1f_3
- pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- readline=8.2=h5eee18b_0
- requests=2.28.1=py39h06a4308_1
- setuptools=65.6.3=py39h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.11.0=py39_cu113
- torchvision=0.12.0=py39_cu113
- tzdata=2023c=h04d1e81_0
- wheel=0.38.4=py39h06a4308_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.4.2=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.4=hc292b87_0
- pip:
- absl-py==1.3.0
- aiofiles==23.1.0
- aiohttp==3.8.3
- aiosignal==1.3.1
- alembic==1.9.1
- altair==5.0.1
- annotated-types==0.5.0
- ansi2html==1.8.0
- antlr4-python3-runtime==4.9.3
- anyio==3.7.1
- appdirs==1.4.4
- asteroid-filterbanks==0.4.0
- async-timeout==4.0.2
- attrs==22.2.0
- audioread==3.0.0
- autopage==0.5.1
- backports-cached-property==1.0.2
- cachetools==5.2.0
- charset-normalizer==2.1.1
- click==8.1.3
- cliff==4.1.0
- cmaes==0.9.0
- cmake==3.26.4
- cmd2==2.4.2
- colorama==0.4.6
- colorlog==6.7.0
- commonmark==0.9.1
- contourpy==1.0.6
- cycler==0.11.0
- dash==2.12.1
- dash-core-components==2.0.0
- dash-html-components==2.0.0
- dash-table==5.0.0
- decorator==4.4.2
- docopt==0.6.2
- einops==0.3.2
- exceptiongroup==1.1.1
- fastapi==0.100.0
- ffmpeg-python==0.2.0
- ffmpy==0.3.0
- filelock==3.8.0
- flask==2.2.5
- fonttools==4.38.0
- frozenlist==1.3.3
- fsspec==2022.11.0
- future==0.18.2
- google-auth==2.15.0
- google-auth-oauthlib==0.4.6
- gradio==3.36.1
- gradio-client==0.2.7
- greenlet==2.0.1
- grpcio==1.51.1
- h11==0.14.0
- hmmlearn==0.2.8
- httpcore==0.17.3
- httpx==0.24.1
- huggingface-hub==0.16.4
- humanize==4.7.0
- hyperpyyaml==1.1.0
- imageio==2.23.0
- imageio-ffmpeg==0.4.7
- importlib-metadata==4.13.0
- importlib-resources==5.12.0
- iniconfig==2.0.0
- itsdangerous==2.1.2
- jinja2==3.1.2
- joblib==1.2.0
- jsonschema==4.18.0
- jsonschema-specifications==2023.6.1
- julius==0.2.7
- kiwisolver==1.4.4
- librosa==0.9.2
- linkify-it-py==2.0.2
- lit==16.0.5.post0
- llvmlite==0.39.1
- mako==1.2.4
- markdown==3.4.1
- markdown-it-py==2.2.0
- markupsafe==2.1.1
- matplotlib==3.7.1
- mdit-py-plugins==0.3.3
- mdurl==0.1.2
- more-itertools==9.0.0
- moviepy==1.0.3
- mpmath==1.2.1
- multidict==6.0.4
- nest-asyncio==1.5.7
- networkx==2.8.8
- numba==0.56.4
- oauthlib==3.2.2
- omegaconf==2.3.0
- openai-whisper==20230314
- optuna==3.0.5
- orjson==3.9.2
- packaging==21.3
- pandas==1.5.2
- pbr==5.11.0
- plotly==5.15.0
- pluggy==1.0.0
- pooch==1.6.0
- prettytable==3.5.0
- primepy==1.3
- proglog==0.1.10
- protobuf==3.20.1
- pyannote-audio==2.1.1
- pyannote-core==4.5
- pyannote-database==4.1.3
- pyannote-metrics==3.2.1
- pyannote-pipeline==2.3
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydantic==2.0.2
- pydantic-core==2.1.2
- pydeprecate==0.3.2
- pydub==0.25.1
- pygments==2.13.0
- pyparsing==3.0.9
- pyperclip==1.8.2
- pytest==7.3.1
- python-dateutil==2.8.2
- python-multipart==0.0.6
- pytorch-lightning==1.6.5
- pytorch-metric-learning==1.6.3
- pytz==2022.7
- pyyaml==6.0
- qtfaststart==1.8
- referencing==0.29.1
- regex==2022.10.31
- requests-oauthlib==1.3.1
- resampy==0.4.2
- retrying==1.3.4
- rich==12.6.0
- rpds-py==0.8.10
- rsa==4.9
- ruamel-yaml==0.17.21
- ruamel-yaml-clib==0.2.7
- ruff==0.0.272
- scikit-learn==1.2.0
- scipy==1.8.1
- semantic-version==2.10.0
- semver==2.13.0
- sentencepiece==0.1.97
- setuptools-rust==1.5.2
- shellingham==1.5.0
- simplejson==3.18.0
- singledispatchmethod==1.0
- sniffio==1.3.0
- sortedcontainers==2.4.0
- soundfile==0.10.3.post1
- speechbrain==0.5.14
- sqlalchemy==1.4.45
- starlette==0.27.0
- stevedore==4.1.1
- sympy==1.11.1
- tabulate==0.9.0
- tenacity==8.2.2
- tensorboard==2.11.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- threadpoolctl==3.1.0
- tiktoken==0.3.1
- tokenizers==0.13.2
- tomli==2.0.1
- toolz==0.12.0
- torch-audiomentations==0.11.0
- torch-pitch-shift==1.2.2
- torchmetrics==0.11.0
- tqdm==4.64.1
- transformers==4.24.0
- triton==2.0.0
- typer==0.7.0
- typing-extensions==4.7.1
- uc-micro-py==1.0.2
- urllib3==1.26.12
- uvicorn==0.22.0
- wcwidth==0.2.5
- websockets==11.0.3
- werkzeug==2.2.2
- yarl==1.8.2
- zipp==3.11.0
+6 -6
View File
@@ -31,12 +31,12 @@ exclude =[
]
[tool.poetry.dependencies]
python = "^3.9"
tqdm = "^4.66.4"
tqdm = "^4.66.5"
numpy = "^1.26.4"
openai-whisper = "^20231117"
whisperx = "^3.1.3"
"pyannote.audio" = "^3.1.1"
torch = "^2.3.0"
openai-whisper = ">=20231117,<20240931"
faster-whisper = "^1.0.3"
"pyannote.audio" = "^3.3.1"
torch = "^2.1.2"
[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
@@ -57,7 +57,7 @@ format-jinja = """
[tool.poetry.group.docs.dependencies]
sphinx = "^7.3.7"
sphinx-rtd-theme = "^2.0.0"
sphinx-rtd-theme = ">=2,<4"
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
myst-parser = "^3.0.1"
mdit-py-plugins = "^0.4.1"
+4 -4
View File
@@ -1,14 +1,14 @@
tqdm>=4.65.0
tqdm>=4.66.5
numpy>=1.26.4
openai-whisper==20231117
whisperx~=3.1.3
faster-whisper~=1.0.3
pyannote.audio~=3.1.1
pyannote.audio~=3.3.1
pyannote.core~=5.0.0
pyannote.database~=5.0.1
pyannote.metrics~=3.2.1
pyannote.pipeline~=3.0.1
torch>=2.0.0
torchaudio>=2.1.2
+3 -9
View File
@@ -41,26 +41,20 @@ class AudioProcessor:
The sample rate of the audio.
"""
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
*args, **kwargs) -> None:
def __init__(self, waveform: torch.Tensor,
sr: int = SAMPLE_RATE) -> None:
"""
Initialize the AudioProcessor object.
Args:
waveform (torch.Tensor): The audio waveform tensor.
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
args: Additional arguments.
kwargs: Additional keyword arguments, e.g., device to use for processing.
If CUDA is available, it defaults to CUDA.
Raises:
ValueError: If the provided sample rate is not of type int.
"""
device = kwargs.get(
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.waveform = waveform
self.sr = sr
if not isinstance(self.sr, int):
+8 -7
View File
@@ -40,6 +40,7 @@ from .audio import AudioProcessor
from .diarisation import Diariser
from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript
from .misc import SCRAIBE_TORCH_DEVICE
DiarisationType = TypeVar('DiarisationType')
@@ -74,7 +75,7 @@ class Scraibe:
whisper_model (Union[bool, str, whisper], optional):
Path to whisper model or whisper model itself.
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
Type of whisper model to load. "whisper" or "faster-whisper".
diarisation_model (Union[bool, str, DiarisationType], optional):
Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper
@@ -116,6 +117,9 @@ class Scraibe:
else:
self.params = {}
self.device = kwargs.get(
"device", SCRAIBE_TORCH_DEVICE)
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False,
**kwargs) -> Transcript:
@@ -141,7 +145,7 @@ class Scraibe:
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr
}
@@ -165,8 +169,6 @@ class Scraibe:
if self.verbose:
print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results
final_transcript = dict()
@@ -213,7 +215,7 @@ class Scraibe:
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
"sample_rate": audio_file.sr
}
@@ -323,8 +325,7 @@ class Scraibe:
print(f"Audiofile {audio_file} removed.")
@staticmethod
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor:
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor.
Args:
+11 -4
View File
@@ -36,8 +36,8 @@ def cli():
help="List of audio files to transcribe.")
parser.add_argument("--whisper-type", type=str, default="whisper",
choices=["whisper", "whisperx"],
help="Type of Whisper model to use ('whisper' or 'whisperx').")
choices=["whisper", "faster-whisper"],
help="Type of Whisper model to use ('whisper' or 'faster-whisper').")
parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.")
@@ -79,6 +79,8 @@ def cli():
choices=sorted(
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.")
parser.add_argument("--num-speakers", type=int, default=2,
help="Number of speakers in the audio.")
args = parser.parse_args()
@@ -117,8 +119,13 @@ def cli():
else:
task = "transcribe"
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
"language"), verbose=arg_dict.pop("verbose_output"))
out = model.autotranscribe(
audio,
task=task,
language=arg_dict.pop("language"),
verbose=arg_dict.pop("verbose_output"),
num_speakers=arg_dict.pop("num_speakers")
)
basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(
+3 -8
View File
@@ -37,11 +37,11 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname(
@@ -190,8 +190,7 @@ class Diariser:
cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None,
device: str = None,
*args, **kwargs
device: str = SCRAIBE_TORCH_DEVICE,
) -> Pipeline:
"""
Loads a pretrained model from pyannote.audio,
@@ -283,10 +282,6 @@ class Diariser:
'or from huggingface.co models. Please check your token'
'or your local model path')
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
# torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device))
+7 -5
View File
@@ -1,23 +1,25 @@
import os
import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action
from ast import literal_eval
from torch.cuda import is_available
CACHE_DIR = os.getenv(
"AUTOT_CACHE",
os.path.expanduser("~/.cache/torch/models"),
)
if CACHE_DIR != PYANNOTE_CACHE_DIR:
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
os.environ["PYANNOTE_CACHE"] = os.getenv(
"PYANNOTE_CACHE",
os.path.join(CACHE_DIR, "pyannote"),
)
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
+58 -28
View File
@@ -26,17 +26,17 @@ Usage:
from whisper import Whisper
from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from whisper.tokenizer import TO_LANGUAGE_CODE
from faster_whisper import WhisperModel as FasterWhisperModel
from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
from typing import TypeVar, Union, Optional
from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray
from inspect import signature
from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH
from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
whisper = TypeVar('whisper')
@@ -123,7 +123,7 @@ class Transcriber:
model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> None:
@@ -145,7 +145,7 @@ class Transcriber:
- 'large-v3'
- 'large'
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
Type of whisper model to load. "whisper" or "faster-whisper".
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
@@ -205,7 +205,7 @@ class WhisperTranscriber(Transcriber):
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> 'WhisperTranscriber':
@@ -272,7 +272,7 @@ class WhisperTranscriber(Transcriber):
return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
class WhisperXTranscriber(Transcriber):
class FasterWhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
@@ -294,19 +294,19 @@ class WhisperXTranscriber(Transcriber):
if isinstance(audio, Tensor):
audio = audio.cpu().numpy()
result = self.model.transcribe(audio, *args, **kwargs)
result, _ = self.model.transcribe(audio, *args, **kwargs)
text = ""
for seg in result['segments']:
text += seg['text']
for seg in result:
text += seg.text
return text
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
*args, **kwargs
) -> 'WhisperXTranscriber':
) -> 'FasterWhisperModel':
"""
Load whisper model.
@@ -329,7 +329,7 @@ class WhisperXTranscriber(Transcriber):
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
@@ -338,17 +338,17 @@ class WhisperXTranscriber(Transcriber):
Returns:
Transcriber: A Transcriber object initialized with the specified model.
"""
if device is None:
device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str):
device = str(device)
compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with '
f'device {device}! Changing compute type to int8.')
compute_type = 'int8'
_model = whisperx_load_model(model, download_root=download_root,
device=device, compute_type=compute_type)
_model = FasterWhisperModel(model, download_root=download_root,
device=device, compute_type=compute_type)
return cls(_model, model_name=model)
@@ -361,7 +361,7 @@ class WhisperXTranscriber(Transcriber):
dict: Keyword arguments for whisper model.
"""
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
_possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys()
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
@@ -370,21 +370,51 @@ class WhisperXTranscriber(Transcriber):
whisper_kwargs["task"] = task
if (language := kwargs.get("language")):
language = FasterWhisperTranscriber.convert_to_language_code(language)
whisper_kwargs["language"] = language
return whisper_kwargs
@staticmethod
def convert_to_language_code(lang: str) -> str:
    """
    Convert a language code or language name into a language code.

    The lookup is case-insensitive: the value is first tried as given
    against FASTER_WHISPER_LANGUAGE_CODES, then in lower case against the
    same codes, and finally in lower case as a language name in
    TO_LANGUAGE_CODE.

    Args:
        lang (str): Language as a code or as a language name.

    Returns:
        str: The language code for `lang`.

    Raises:
        ValueError: If `lang` is neither a known language code nor a
            known language name.
    """
    # Fast path: the input already is a known language code.
    if lang in FASTER_WHISPER_LANGUAGE_CODES:
        return lang

    # Normalize so that inputs such as "EN" or "English" are accepted.
    normalized = lang.strip().lower()

    # A code written in a different case (e.g. upper case) is still a code.
    if normalized in FASTER_WHISPER_LANGUAGE_CODES:
        return normalized

    # Otherwise treat the input as a language name.
    if normalized in TO_LANGUAGE_CODE:
        return TO_LANGUAGE_CODE[normalized]

    # Unknown language: report the caller's original input and the options.
    available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES)
    raise ValueError(f"Language '{lang}' is not a valid language code or name. "
                     f"Available language codes are: {available_codes}.")
def __repr__(self) -> str:
return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})"
return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})"
def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
in_memory: bool = False,
*args, **kwargs
) -> Union[WhisperTranscriber, WhisperXTranscriber]:
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
"""
Load whisper model.
@@ -403,28 +433,28 @@ def load_transcriber(model: str = "medium",
- 'large-v3'
- 'large'
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
Type of whisper model to load. "whisper" or "faster-whisper".
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Union[WhisperTranscriber, WhisperXTranscriber]:
Union[WhisperTranscriber, FasterWhisperTranscriber]:
One of the Whisper variants as a Transcriber object initialized with the specified model.
"""
if whisper_type.lower() == 'whisper':
_model = WhisperTranscriber.load_model(
model, download_root, device, in_memory, *args, **kwargs)
return _model
elif whisper_type.lower() == 'whisperx':
_model = WhisperXTranscriber.load_model(
elif whisper_type.lower() == 'faster-whisper':
_model = FasterWhisperTranscriber.load_model(
model, download_root, device, *args, **kwargs)
return _model
else:
raise ValueError(f'Model type not recognized, exptected "whisper" '
f'or "whisperx", got {whisper_type}.')
f'or "faster-whisper", got {whisper_type}.')
@@ -6,7 +6,7 @@ import os
@pytest.fixture
def create_scraibe_instance():
if "HF_TOKEN" in os.environ:
return Scraibe(use_auth_token=os.environ["HF_TOKEN"])
return Scraibe(use_auth_token=os.environ["HF_TOKEN"], whisper_model= "tiny")
else:
return Scraibe()
@@ -19,19 +19,19 @@ def test_scraibe_init(create_scraibe_instance):
def test_scraibe_autotranscribe(create_scraibe_instance):
model = create_scraibe_instance
transcript = model.autotranscribe('test/audio_test_2.mp4')
transcript = model.autotranscribe('tests/audio_test_2.mp4')
assert isinstance(transcript, Transcript)
def test_scraibe_diarization(create_scraibe_instance):
model = create_scraibe_instance
diarisation_result = model.diarization('test/audio_test_2.mp4')
diarisation_result = model.diarization('tests/audio_test_2.mp4')
assert isinstance(diarisation_result, dict)
def test_scraibe_transcribe(create_scraibe_instance):
model = create_scraibe_instance
transcription_result = model.transcribe('test/audio_test_2.mp4')
transcription_result = model.transcribe('tests/audio_test_2.mp4')
assert isinstance(transcription_result, str)
@@ -1,6 +1,6 @@
import pytest
from scraibe import (Transcriber, WhisperTranscriber,
WhisperXTranscriber, load_transcriber)
FasterWhisperTranscriber, load_transcriber)
import torch
@@ -31,33 +31,33 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
@pytest.fixture
def whisper_instance():
return load_transcriber('medium', whisper_type='whisper')
return load_transcriber('tiny', whisper_type='whisper')
@pytest.fixture
def whisperx_instance():
return load_transcriber('medium', whisper_type='whisperx')
def faster_whisper_instance():
return load_transcriber('tiny', whisper_type='faster-whisper')
def test_whisper_base_initialization(whisper_instance):
assert isinstance(whisper_instance, Transcriber)
def test_whisperx_base_initialization(whisperx_instance):
assert isinstance(whisperx_instance, Transcriber)
def test_faster_whisper_base_initialization(faster_whisper_instance):
assert isinstance(faster_whisper_instance, Transcriber)
def test_whisper_transcriber_initialization(whisper_instance):
assert isinstance(whisper_instance, WhisperTranscriber)
def test_whisperx_transcriber_initialization(whisperx_instance):
assert isinstance(whisperx_instance, WhisperXTranscriber)
def test_faster_whisper_transcriber_initialization(faster_whisper_instance):
assert isinstance(faster_whisper_instance, FasterWhisperTranscriber)
def test_wrong_transcriber_initialization():
with pytest.raises(ValueError):
load_transcriber('medium', whisper_type='wrong_whisper')
load_transcriber('tiny', whisper_type='wrong_whisper')
def test_get_whisper_kwargs():
@@ -69,12 +69,12 @@ def test_get_whisper_kwargs():
def test_whisper_transcribe(whisper_instance):
model = whisper_instance
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4')
transcript = model.transcribe('tests/audio_test_2.mp4')
assert isinstance(transcript, str)
def test_whisperx_transcribe(whisperx_instance):
model = whisperx_instance
def test_faster_whisper_transcribe(faster_whisper_instance):
model = faster_whisper_instance
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4')
transcript = model.transcribe('tests/audio_test_2.mp4')
assert isinstance(transcript, str)