Merge branch 'roll-back-torch-verision' into dependabot/pip/openai-whisper-20240930
This commit is contained in:
+11
-33
@@ -1,18 +1,14 @@
|
|||||||
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
|
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request_target:
|
|
||||||
branches:
|
|
||||||
- develop
|
|
||||||
types:
|
|
||||||
- closed
|
|
||||||
paths:
|
|
||||||
- scraibe/**
|
|
||||||
- pyproject.toml
|
|
||||||
|
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- 'v*.*.*'
|
- 'v*.*.*'
|
||||||
|
branches:
|
||||||
|
- "develop"
|
||||||
|
paths:
|
||||||
|
- "scraibe/**"
|
||||||
|
- "pyproject.toml"
|
||||||
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
@@ -27,13 +23,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
Build-and-publish-to-Test-PyPI:
|
Build-and-publish-to-Test-PyPI:
|
||||||
if: |
|
if: github.event_name != 'workflow_dispatch' || github.event.inputs.test == 'true'
|
||||||
(github.event_name == 'workflow_dispatch' &&
|
|
||||||
github.event.inputs.test == 'true') ||
|
|
||||||
(github.event_name == 'pull_request_target' &&
|
|
||||||
github.event.pull_request.merged &&
|
|
||||||
contains(github.event.pull_request.labels.*.name, 'release')) ||
|
|
||||||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -72,28 +62,16 @@ jobs:
|
|||||||
needs: Test-PyPi-install
|
needs: Test-PyPi-install
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: |
|
if: |
|
||||||
always() &&
|
always() &&
|
||||||
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
|
(( needs.Build-and-publish-to-Test-PyPI.result != 'failure' &&
|
||||||
needs.Test-PyPi-install.result != 'failure' ) &&
|
needs.Test-PyPi-install.result != 'failure' ) ||
|
||||||
((github.event_name == 'workflow_dispatch' &&
|
((github.event_name == 'workflow_dispatch' &&
|
||||||
github.event.inputs.publish_to_pypi == 'true') ||
|
github.event.inputs.publish_to_pypi == 'true')))
|
||||||
(github.event_name == 'pull_request_target' &&
|
|
||||||
github.event.pull_request.merged &&
|
|
||||||
contains(github.event.pull_request.labels.*.name, 'release')) ||
|
|
||||||
(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/'))))
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Repository Tags
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
if: github.ref == 'refs/heads/main'
|
|
||||||
with:
|
|
||||||
fetch-depth: '0'
|
|
||||||
branch: 'main'
|
|
||||||
- name: Checkout Repository (Develop)
|
- name: Checkout Repository (Develop)
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
if: github.ref == 'refs/heads/develop'
|
|
||||||
with:
|
with:
|
||||||
fetch-depth: '0'
|
fetch-depth: '0'
|
||||||
branch: 'develop'
|
|
||||||
- name: Set up Poetry 📦
|
- name: Set up Poetry 📦
|
||||||
uses: JRubics/poetry-publish@v1.16
|
uses: JRubics/poetry-publish@v1.16
|
||||||
with:
|
with:
|
||||||
|
|||||||
-256
@@ -1,256 +0,0 @@
|
|||||||
channels:
|
|
||||||
- pytorch
|
|
||||||
- defaults
|
|
||||||
dependencies:
|
|
||||||
- _libgcc_mutex=0.1=main
|
|
||||||
- _openmp_mutex=5.1=1_gnu
|
|
||||||
- blas=1.0=mkl
|
|
||||||
- brotlipy=0.7.0=py39h27cfd23_1003
|
|
||||||
- bzip2=1.0.8=h7b6447c_0
|
|
||||||
- ca-certificates=2023.05.30=h06a4308_0
|
|
||||||
- certifi=2023.5.7=py39h06a4308_0
|
|
||||||
- cffi=1.15.1=py39h5eee18b_3
|
|
||||||
- cryptography=39.0.1=py39h9ce1e76_2
|
|
||||||
- cudatoolkit=11.3.1=h2bc3f7f_2
|
|
||||||
- ffmpeg=4.2.2=h20bf706_0
|
|
||||||
- flit-core=3.8.0=py39h06a4308_0
|
|
||||||
- freetype=2.12.1=h4a9f257_0
|
|
||||||
- giflib=5.2.1=h5eee18b_3
|
|
||||||
- gmp=6.2.1=h295c915_3
|
|
||||||
- gnutls=3.6.15=he1e5248_0
|
|
||||||
- idna=3.4=py39h06a4308_0
|
|
||||||
- intel-openmp=2021.4.0=h06a4308_3561
|
|
||||||
- jpeg=9e=h5eee18b_1
|
|
||||||
- lame=3.100=h7b6447c_0
|
|
||||||
- lcms2=2.12=h3be6417_0
|
|
||||||
- ld_impl_linux-64=2.38=h1181459_1
|
|
||||||
- lerc=3.0=h295c915_0
|
|
||||||
- libdeflate=1.17=h5eee18b_0
|
|
||||||
- libffi=3.4.2=h6a678d5_6
|
|
||||||
- libgcc-ng=11.2.0=h1234567_1
|
|
||||||
- libgomp=11.2.0=h1234567_1
|
|
||||||
- libidn2=2.3.2=h7f8727e_0
|
|
||||||
- libopus=1.3.1=h7b6447c_0
|
|
||||||
- libpng=1.6.39=h5eee18b_0
|
|
||||||
- libstdcxx-ng=11.2.0=h1234567_1
|
|
||||||
- libtasn1=4.16.0=h27cfd23_0
|
|
||||||
- libtiff=4.5.0=h6a678d5_2
|
|
||||||
- libunistring=0.9.10=h27cfd23_0
|
|
||||||
- libuv=1.44.2=h5eee18b_0
|
|
||||||
- libvpx=1.7.0=h439df22_0
|
|
||||||
- libwebp=1.2.4=h11a3e52_1
|
|
||||||
- libwebp-base=1.2.4=h5eee18b_1
|
|
||||||
- lz4-c=1.9.4=h6a678d5_0
|
|
||||||
- mkl=2021.4.0=h06a4308_640
|
|
||||||
- mkl-service=2.4.0=py39h7f8727e_0
|
|
||||||
- mkl_fft=1.3.1=py39hd3c417c_0
|
|
||||||
- mkl_random=1.2.2=py39h51133e4_0
|
|
||||||
- ncurses=6.4=h6a678d5_0
|
|
||||||
- nettle=3.7.3=hbbd107a_1
|
|
||||||
- numpy=1.23.5=py39h14f4228_0
|
|
||||||
- numpy-base=1.23.5=py39h31eccc5_0
|
|
||||||
- openh264=2.1.1=h4ff587b_0
|
|
||||||
- openssl=3.0.9=h7f8727e_0
|
|
||||||
- pillow=9.4.0=py39h6a678d5_0
|
|
||||||
- pip=23.0.1=py39h06a4308_0
|
|
||||||
- pycparser=2.21=pyhd3eb1b0_0
|
|
||||||
- pyopenssl=23.0.0=py39h06a4308_0
|
|
||||||
- pysocks=1.7.1=py39h06a4308_0
|
|
||||||
- python=3.9.16=h955ad1f_3
|
|
||||||
- pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
|
|
||||||
- pytorch-mutex=1.0=cuda
|
|
||||||
- readline=8.2=h5eee18b_0
|
|
||||||
- requests=2.28.1=py39h06a4308_1
|
|
||||||
- setuptools=65.6.3=py39h06a4308_0
|
|
||||||
- six=1.16.0=pyhd3eb1b0_1
|
|
||||||
- sqlite=3.41.2=h5eee18b_0
|
|
||||||
- tk=8.6.12=h1ccaba5_0
|
|
||||||
- torchaudio=0.11.0=py39_cu113
|
|
||||||
- torchvision=0.12.0=py39_cu113
|
|
||||||
- tzdata=2023c=h04d1e81_0
|
|
||||||
- wheel=0.38.4=py39h06a4308_0
|
|
||||||
- x264=1!157.20191217=h7b6447c_0
|
|
||||||
- xz=5.4.2=h5eee18b_0
|
|
||||||
- zlib=1.2.13=h5eee18b_0
|
|
||||||
- zstd=1.5.4=hc292b87_0
|
|
||||||
- pip:
|
|
||||||
- absl-py==1.3.0
|
|
||||||
- aiofiles==23.1.0
|
|
||||||
- aiohttp==3.8.3
|
|
||||||
- aiosignal==1.3.1
|
|
||||||
- alembic==1.9.1
|
|
||||||
- altair==5.0.1
|
|
||||||
- annotated-types==0.5.0
|
|
||||||
- ansi2html==1.8.0
|
|
||||||
- antlr4-python3-runtime==4.9.3
|
|
||||||
- anyio==3.7.1
|
|
||||||
- appdirs==1.4.4
|
|
||||||
- asteroid-filterbanks==0.4.0
|
|
||||||
- async-timeout==4.0.2
|
|
||||||
- attrs==22.2.0
|
|
||||||
- audioread==3.0.0
|
|
||||||
- autopage==0.5.1
|
|
||||||
- backports-cached-property==1.0.2
|
|
||||||
- cachetools==5.2.0
|
|
||||||
- charset-normalizer==2.1.1
|
|
||||||
- click==8.1.3
|
|
||||||
- cliff==4.1.0
|
|
||||||
- cmaes==0.9.0
|
|
||||||
- cmake==3.26.4
|
|
||||||
- cmd2==2.4.2
|
|
||||||
- colorama==0.4.6
|
|
||||||
- colorlog==6.7.0
|
|
||||||
- commonmark==0.9.1
|
|
||||||
- contourpy==1.0.6
|
|
||||||
- cycler==0.11.0
|
|
||||||
- dash==2.12.1
|
|
||||||
- dash-core-components==2.0.0
|
|
||||||
- dash-html-components==2.0.0
|
|
||||||
- dash-table==5.0.0
|
|
||||||
- decorator==4.4.2
|
|
||||||
- docopt==0.6.2
|
|
||||||
- einops==0.3.2
|
|
||||||
- exceptiongroup==1.1.1
|
|
||||||
- fastapi==0.100.0
|
|
||||||
- ffmpeg-python==0.2.0
|
|
||||||
- ffmpy==0.3.0
|
|
||||||
- filelock==3.8.0
|
|
||||||
- flask==2.2.5
|
|
||||||
- fonttools==4.38.0
|
|
||||||
- frozenlist==1.3.3
|
|
||||||
- fsspec==2022.11.0
|
|
||||||
- future==0.18.2
|
|
||||||
- google-auth==2.15.0
|
|
||||||
- google-auth-oauthlib==0.4.6
|
|
||||||
- gradio==3.36.1
|
|
||||||
- gradio-client==0.2.7
|
|
||||||
- greenlet==2.0.1
|
|
||||||
- grpcio==1.51.1
|
|
||||||
- h11==0.14.0
|
|
||||||
- hmmlearn==0.2.8
|
|
||||||
- httpcore==0.17.3
|
|
||||||
- httpx==0.24.1
|
|
||||||
- huggingface-hub==0.16.4
|
|
||||||
- humanize==4.7.0
|
|
||||||
- hyperpyyaml==1.1.0
|
|
||||||
- imageio==2.23.0
|
|
||||||
- imageio-ffmpeg==0.4.7
|
|
||||||
- importlib-metadata==4.13.0
|
|
||||||
- importlib-resources==5.12.0
|
|
||||||
- iniconfig==2.0.0
|
|
||||||
- itsdangerous==2.1.2
|
|
||||||
- jinja2==3.1.2
|
|
||||||
- joblib==1.2.0
|
|
||||||
- jsonschema==4.18.0
|
|
||||||
- jsonschema-specifications==2023.6.1
|
|
||||||
- julius==0.2.7
|
|
||||||
- kiwisolver==1.4.4
|
|
||||||
- librosa==0.9.2
|
|
||||||
- linkify-it-py==2.0.2
|
|
||||||
- lit==16.0.5.post0
|
|
||||||
- llvmlite==0.39.1
|
|
||||||
- mako==1.2.4
|
|
||||||
- markdown==3.4.1
|
|
||||||
- markdown-it-py==2.2.0
|
|
||||||
- markupsafe==2.1.1
|
|
||||||
- matplotlib==3.7.1
|
|
||||||
- mdit-py-plugins==0.3.3
|
|
||||||
- mdurl==0.1.2
|
|
||||||
- more-itertools==9.0.0
|
|
||||||
- moviepy==1.0.3
|
|
||||||
- mpmath==1.2.1
|
|
||||||
- multidict==6.0.4
|
|
||||||
- nest-asyncio==1.5.7
|
|
||||||
- networkx==2.8.8
|
|
||||||
- numba==0.56.4
|
|
||||||
- oauthlib==3.2.2
|
|
||||||
- omegaconf==2.3.0
|
|
||||||
- openai-whisper==20230314
|
|
||||||
- optuna==3.0.5
|
|
||||||
- orjson==3.9.2
|
|
||||||
- packaging==21.3
|
|
||||||
- pandas==1.5.2
|
|
||||||
- pbr==5.11.0
|
|
||||||
- plotly==5.15.0
|
|
||||||
- pluggy==1.0.0
|
|
||||||
- pooch==1.6.0
|
|
||||||
- prettytable==3.5.0
|
|
||||||
- primepy==1.3
|
|
||||||
- proglog==0.1.10
|
|
||||||
- protobuf==3.20.1
|
|
||||||
- pyannote-audio==2.1.1
|
|
||||||
- pyannote-core==4.5
|
|
||||||
- pyannote-database==4.1.3
|
|
||||||
- pyannote-metrics==3.2.1
|
|
||||||
- pyannote-pipeline==2.3
|
|
||||||
- pyasn1==0.4.8
|
|
||||||
- pyasn1-modules==0.2.8
|
|
||||||
- pydantic==2.0.2
|
|
||||||
- pydantic-core==2.1.2
|
|
||||||
- pydeprecate==0.3.2
|
|
||||||
- pydub==0.25.1
|
|
||||||
- pygments==2.13.0
|
|
||||||
- pyparsing==3.0.9
|
|
||||||
- pyperclip==1.8.2
|
|
||||||
- pytest==7.3.1
|
|
||||||
- python-dateutil==2.8.2
|
|
||||||
- python-multipart==0.0.6
|
|
||||||
- pytorch-lightning==1.6.5
|
|
||||||
- pytorch-metric-learning==1.6.3
|
|
||||||
- pytz==2022.7
|
|
||||||
- pyyaml==6.0
|
|
||||||
- qtfaststart==1.8
|
|
||||||
- referencing==0.29.1
|
|
||||||
- regex==2022.10.31
|
|
||||||
- requests-oauthlib==1.3.1
|
|
||||||
- resampy==0.4.2
|
|
||||||
- retrying==1.3.4
|
|
||||||
- rich==12.6.0
|
|
||||||
- rpds-py==0.8.10
|
|
||||||
- rsa==4.9
|
|
||||||
- ruamel-yaml==0.17.21
|
|
||||||
- ruamel-yaml-clib==0.2.7
|
|
||||||
- ruff==0.0.272
|
|
||||||
- scikit-learn==1.2.0
|
|
||||||
- scipy==1.8.1
|
|
||||||
- semantic-version==2.10.0
|
|
||||||
- semver==2.13.0
|
|
||||||
- sentencepiece==0.1.97
|
|
||||||
- setuptools-rust==1.5.2
|
|
||||||
- shellingham==1.5.0
|
|
||||||
- simplejson==3.18.0
|
|
||||||
- singledispatchmethod==1.0
|
|
||||||
- sniffio==1.3.0
|
|
||||||
- sortedcontainers==2.4.0
|
|
||||||
- soundfile==0.10.3.post1
|
|
||||||
- speechbrain==0.5.14
|
|
||||||
- sqlalchemy==1.4.45
|
|
||||||
- starlette==0.27.0
|
|
||||||
- stevedore==4.1.1
|
|
||||||
- sympy==1.11.1
|
|
||||||
- tabulate==0.9.0
|
|
||||||
- tenacity==8.2.2
|
|
||||||
- tensorboard==2.11.0
|
|
||||||
- tensorboard-data-server==0.6.1
|
|
||||||
- tensorboard-plugin-wit==1.8.1
|
|
||||||
- threadpoolctl==3.1.0
|
|
||||||
- tiktoken==0.3.1
|
|
||||||
- tokenizers==0.13.2
|
|
||||||
- tomli==2.0.1
|
|
||||||
- toolz==0.12.0
|
|
||||||
- torch-audiomentations==0.11.0
|
|
||||||
- torch-pitch-shift==1.2.2
|
|
||||||
- torchmetrics==0.11.0
|
|
||||||
- tqdm==4.64.1
|
|
||||||
- transformers==4.24.0
|
|
||||||
- triton==2.0.0
|
|
||||||
- typer==0.7.0
|
|
||||||
- typing-extensions==4.7.1
|
|
||||||
- uc-micro-py==1.0.2
|
|
||||||
- urllib3==1.26.12
|
|
||||||
- uvicorn==0.22.0
|
|
||||||
- wcwidth==0.2.5
|
|
||||||
- websockets==11.0.3
|
|
||||||
- werkzeug==2.2.2
|
|
||||||
- yarl==1.8.2
|
|
||||||
- zipp==3.11.0
|
|
||||||
+5
-5
@@ -31,12 +31,12 @@ exclude =[
|
|||||||
]
|
]
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
tqdm = "^4.66.4"
|
tqdm = "^4.66.5"
|
||||||
numpy = "^1.26.4"
|
numpy = "^1.26.4"
|
||||||
openai-whisper = ">=20231117,<20240931"
|
openai-whisper = ">=20231117,<20240931"
|
||||||
whisperx = "^3.1.3"
|
faster-whisper = "^1.0.3"
|
||||||
"pyannote.audio" = "^3.1.1"
|
"pyannote.audio" = "^3.3.1"
|
||||||
torch = "^2.3.0"
|
torch = "^2.1.2"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^8.1.1"
|
pytest = "^8.1.1"
|
||||||
@@ -57,7 +57,7 @@ format-jinja = """
|
|||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
sphinx = "^7.3.7"
|
sphinx = "^7.3.7"
|
||||||
sphinx-rtd-theme = "^2.0.0"
|
sphinx-rtd-theme = ">=2,<4"
|
||||||
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
|
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
|
||||||
myst-parser = "^3.0.1"
|
myst-parser = "^3.0.1"
|
||||||
mdit-py-plugins = "^0.4.1"
|
mdit-py-plugins = "^0.4.1"
|
||||||
|
|||||||
+4
-4
@@ -1,14 +1,14 @@
|
|||||||
tqdm>=4.65.0
|
tqdm>=4.66.5
|
||||||
numpy>=1.26.4
|
numpy>=1.26.4
|
||||||
|
|
||||||
openai-whisper==20231117
|
openai-whisper==20231117
|
||||||
whisperx~=3.1.3
|
faster-whisper~=1.0.3
|
||||||
|
|
||||||
pyannote.audio~=3.1.1
|
pyannote.audio~=3.3.1
|
||||||
pyannote.core~=5.0.0
|
pyannote.core~=5.0.0
|
||||||
pyannote.database~=5.0.1
|
pyannote.database~=5.0.1
|
||||||
pyannote.metrics~=3.2.1
|
pyannote.metrics~=3.2.1
|
||||||
pyannote.pipeline~=3.0.1
|
pyannote.pipeline~=3.0.1
|
||||||
|
|
||||||
torch>=2.0.0
|
torchaudio>=2.1.2
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class Scraibe:
|
|||||||
whisper_model (Union[bool, str, whisper], optional):
|
whisper_model (Union[bool, str, whisper], optional):
|
||||||
Path to whisper model or whisper model itself.
|
Path to whisper model or whisper model itself.
|
||||||
whisper_type (str):
|
whisper_type (str):
|
||||||
Type of whisper model to load. "whisper" or "whisperx".
|
Type of whisper model to load. "whisper" or "faster-whisper".
|
||||||
diarisation_model (Union[bool, str, DiarisationType], optional):
|
diarisation_model (Union[bool, str, DiarisationType], optional):
|
||||||
Path to pyannote diarization model or model itself.
|
Path to pyannote diarization model or model itself.
|
||||||
**kwargs: Additional keyword arguments for whisper
|
**kwargs: Additional keyword arguments for whisper
|
||||||
|
|||||||
+11
-4
@@ -36,8 +36,8 @@ def cli():
|
|||||||
help="List of audio files to transcribe.")
|
help="List of audio files to transcribe.")
|
||||||
|
|
||||||
parser.add_argument("--whisper-type", type=str, default="whisper",
|
parser.add_argument("--whisper-type", type=str, default="whisper",
|
||||||
choices=["whisper", "whisperx"],
|
choices=["whisper", "faster-whisper"],
|
||||||
help="Type of Whisper model to use ('whisper' or 'whisperx').")
|
help="Type of Whisper model to use ('whisper' or 'faster-whisper').")
|
||||||
|
|
||||||
parser.add_argument("--whisper-model-name", default="medium",
|
parser.add_argument("--whisper-model-name", default="medium",
|
||||||
help="Name of the Whisper model to use.")
|
help="Name of the Whisper model to use.")
|
||||||
@@ -79,6 +79,8 @@ def cli():
|
|||||||
choices=sorted(
|
choices=sorted(
|
||||||
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||||
help="Language spoken in the audio. Specify None to perform language detection.")
|
help="Language spoken in the audio. Specify None to perform language detection.")
|
||||||
|
parser.add_argument("--num-speakers", type=int, default=2,
|
||||||
|
help="Number of speakers in the audio.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -117,8 +119,13 @@ def cli():
|
|||||||
else:
|
else:
|
||||||
task = "transcribe"
|
task = "transcribe"
|
||||||
|
|
||||||
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
|
out = model.autotranscribe(
|
||||||
"language"), verbose=arg_dict.pop("verbose_output"))
|
audio,
|
||||||
|
task=task,
|
||||||
|
language=arg_dict.pop("language"),
|
||||||
|
verbose=arg_dict.pop("verbose_output"),
|
||||||
|
num_speakers=arg_dict.pop("num_speakers")
|
||||||
|
)
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
out.save(os.path.join(
|
out.save(os.path.join(
|
||||||
|
|||||||
+5
-5
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
|
|
||||||
from argparse import Action
|
from argparse import Action
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
|
|
||||||
@@ -8,15 +7,16 @@ CACHE_DIR = os.getenv(
|
|||||||
"AUTOT_CACHE",
|
"AUTOT_CACHE",
|
||||||
os.path.expanduser("~/.cache/torch/models"),
|
os.path.expanduser("~/.cache/torch/models"),
|
||||||
)
|
)
|
||||||
|
os.getenv(
|
||||||
if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
"PYANNOTE_CACHE",
|
||||||
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
|
os.path.join(CACHE_DIR, "pyannote"),
|
||||||
|
)
|
||||||
|
|
||||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
|
||||||
|
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
|
|||||||
+49
-18
@@ -26,8 +26,9 @@ Usage:
|
|||||||
|
|
||||||
from whisper import Whisper
|
from whisper import Whisper
|
||||||
from whisper import load_model as whisper_load_model
|
from whisper import load_model as whisper_load_model
|
||||||
from whisperx.asr import WhisperModel
|
from whisper.tokenizer import TO_LANGUAGE_CODE
|
||||||
from whisperx import load_model as whisperx_load_model
|
from faster_whisper import WhisperModel as FasterWhisperModel
|
||||||
|
from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
|
||||||
from typing import TypeVar, Union, Optional
|
from typing import TypeVar, Union, Optional
|
||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
from torch.cuda import is_available as cuda_is_available
|
from torch.cuda import is_available as cuda_is_available
|
||||||
@@ -145,7 +146,7 @@ class Transcriber:
|
|||||||
- 'large-v3'
|
- 'large-v3'
|
||||||
- 'large'
|
- 'large'
|
||||||
whisper_type (str):
|
whisper_type (str):
|
||||||
Type of whisper model to load. "whisper" or "whisperx".
|
Type of whisper model to load. "whisper" or "faster-whisper".
|
||||||
download_root (str, optional): Path to download the model.
|
download_root (str, optional): Path to download the model.
|
||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
@@ -272,7 +273,7 @@ class WhisperTranscriber(Transcriber):
|
|||||||
return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
class WhisperXTranscriber(Transcriber):
|
class FasterWhisperTranscriber(Transcriber):
|
||||||
def __init__(self, model: whisper, model_name: str) -> None:
|
def __init__(self, model: whisper, model_name: str) -> None:
|
||||||
super().__init__(model, model_name)
|
super().__init__(model, model_name)
|
||||||
|
|
||||||
@@ -294,10 +295,10 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
|
|
||||||
if isinstance(audio, Tensor):
|
if isinstance(audio, Tensor):
|
||||||
audio = audio.cpu().numpy()
|
audio = audio.cpu().numpy()
|
||||||
result = self.model.transcribe(audio, *args, **kwargs)
|
result, _ = self.model.transcribe(audio, *args, **kwargs)
|
||||||
text = ""
|
text = ""
|
||||||
for seg in result['segments']:
|
for seg in result:
|
||||||
text += seg['text']
|
text += seg.text
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -306,7 +307,7 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = None,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> 'WhisperXTranscriber':
|
) -> 'FasterWhisperModel':
|
||||||
"""
|
"""
|
||||||
Load whisper model.
|
Load whisper model.
|
||||||
|
|
||||||
@@ -347,8 +348,8 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
warnings.warn(f'Compute type {compute_type} not compatible with '
|
warnings.warn(f'Compute type {compute_type} not compatible with '
|
||||||
f'device {device}! Changing compute type to int8.')
|
f'device {device}! Changing compute type to int8.')
|
||||||
compute_type = 'int8'
|
compute_type = 'int8'
|
||||||
_model = whisperx_load_model(model, download_root=download_root,
|
_model = FasterWhisperModel(model, download_root=download_root,
|
||||||
device=device, compute_type=compute_type)
|
device=device, compute_type=compute_type)
|
||||||
|
|
||||||
return cls(_model, model_name=model)
|
return cls(_model, model_name=model)
|
||||||
|
|
||||||
@@ -361,7 +362,7 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
dict: Keyword arguments for whisper model.
|
dict: Keyword arguments for whisper model.
|
||||||
"""
|
"""
|
||||||
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||||
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
|
_possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys()
|
||||||
|
|
||||||
whisper_kwargs = {k: v for k,
|
whisper_kwargs = {k: v for k,
|
||||||
v in kwargs.items() if k in _possible_kwargs}
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
@@ -370,12 +371,42 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
whisper_kwargs["task"] = task
|
whisper_kwargs["task"] = task
|
||||||
|
|
||||||
if (language := kwargs.get("language")):
|
if (language := kwargs.get("language")):
|
||||||
|
language = FasterWhisperTranscriber.convert_to_language_code(language)
|
||||||
whisper_kwargs["language"] = language
|
whisper_kwargs["language"] = language
|
||||||
|
|
||||||
return whisper_kwargs
|
return whisper_kwargs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_to_language_code(lang : str) -> str:
|
||||||
|
"""
|
||||||
|
Load whisper model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lang (str): language as code or language name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
language (str) code of language
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly
|
||||||
|
if lang in FASTER_WHISPER_LANGUAGE_CODES:
|
||||||
|
return lang
|
||||||
|
|
||||||
|
# Normalize the input to lowercase
|
||||||
|
lang = lang.lower()
|
||||||
|
|
||||||
|
# Check if the language name is in the TO_LANGUAGE_CODE mapping
|
||||||
|
if lang in TO_LANGUAGE_CODE:
|
||||||
|
return TO_LANGUAGE_CODE[lang]
|
||||||
|
|
||||||
|
# If the language is not recognized, raise a ValueError with the available options
|
||||||
|
available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES)
|
||||||
|
raise ValueError(f"Language '{lang}' is not a valid language code or name. "
|
||||||
|
f"Available language codes are: {available_codes}.")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})"
|
return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_transcriber(model: str = "medium",
|
def load_transcriber(model: str = "medium",
|
||||||
@@ -384,7 +415,7 @@ def load_transcriber(model: str = "medium",
|
|||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = None,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> Union[WhisperTranscriber, WhisperXTranscriber]:
|
) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||||
"""
|
"""
|
||||||
Load whisper model.
|
Load whisper model.
|
||||||
|
|
||||||
@@ -403,7 +434,7 @@ def load_transcriber(model: str = "medium",
|
|||||||
- 'large-v3'
|
- 'large-v3'
|
||||||
- 'large'
|
- 'large'
|
||||||
whisper_type (str):
|
whisper_type (str):
|
||||||
Type of whisper model to load. "whisper" or "whisperx".
|
Type of whisper model to load. "whisper" or "faster-whisper".
|
||||||
download_root (str, optional): Path to download the model.
|
download_root (str, optional): Path to download the model.
|
||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
@@ -414,17 +445,17 @@ def load_transcriber(model: str = "medium",
|
|||||||
kwargs: Additional keyword arguments only to avoid errors.
|
kwargs: Additional keyword arguments only to avoid errors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[WhisperTranscriber, WhisperXTranscriber]:
|
Union[WhisperTranscriber, FasterWhisperTranscriber]:
|
||||||
One of the Whisper variants as Transcrbier object initialized with the specified model.
|
One of the Whisper variants as Transcrbier object initialized with the specified model.
|
||||||
"""
|
"""
|
||||||
if whisper_type.lower() == 'whisper':
|
if whisper_type.lower() == 'whisper':
|
||||||
_model = WhisperTranscriber.load_model(
|
_model = WhisperTranscriber.load_model(
|
||||||
model, download_root, device, in_memory, *args, **kwargs)
|
model, download_root, device, in_memory, *args, **kwargs)
|
||||||
return _model
|
return _model
|
||||||
elif whisper_type.lower() == 'whisperx':
|
elif whisper_type.lower() == 'faster-whisper':
|
||||||
_model = WhisperXTranscriber.load_model(
|
_model = FasterWhisperTranscriber.load_model(
|
||||||
model, download_root, device, *args, **kwargs)
|
model, download_root, device, *args, **kwargs)
|
||||||
return _model
|
return _model
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Model type not recognized, exptected "whisper" '
|
raise ValueError(f'Model type not recognized, exptected "whisper" '
|
||||||
f'or "whisperx", got {whisper_type}.')
|
f'or "faster-whisper", got {whisper_type}.')
|
||||||
|
|||||||
+11
-11
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from scraibe import (Transcriber, WhisperTranscriber,
|
from scraibe import (Transcriber, WhisperTranscriber,
|
||||||
WhisperXTranscriber, load_transcriber)
|
FasterWhisperTranscriber, load_transcriber)
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -31,33 +31,33 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def whisper_instance():
|
def whisper_instance():
|
||||||
return load_transcriber('medium', whisper_type='whisper')
|
return load_transcriber('tiny', whisper_type='whisper')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def whisperx_instance():
|
def faster_whisper_instance():
|
||||||
return load_transcriber('medium', whisper_type='whisperx')
|
return load_transcriber('tiny', whisper_type='faster-whisper')
|
||||||
|
|
||||||
|
|
||||||
def test_whisper_base_initialization(whisper_instance):
|
def test_whisper_base_initialization(whisper_instance):
|
||||||
assert isinstance(whisper_instance, Transcriber)
|
assert isinstance(whisper_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
def test_whisperx_base_initialization(whisperx_instance):
|
def test_faster_whisper_base_initialization(faster_whisper_instance):
|
||||||
assert isinstance(whisperx_instance, Transcriber)
|
assert isinstance(faster_whisper_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
def test_whisper_transcriber_initialization(whisper_instance):
|
def test_whisper_transcriber_initialization(whisper_instance):
|
||||||
assert isinstance(whisper_instance, WhisperTranscriber)
|
assert isinstance(whisper_instance, WhisperTranscriber)
|
||||||
|
|
||||||
|
|
||||||
def test_whisperx_transcriber_initialization(whisperx_instance):
|
def test_faster_whisper_transcriber_initialization(faster_whisper_instance):
|
||||||
assert isinstance(whisperx_instance, WhisperXTranscriber)
|
assert isinstance(faster_whisper_instance, FasterWhisperTranscriber)
|
||||||
|
|
||||||
|
|
||||||
def test_wrong_transcriber_initialization():
|
def test_wrong_transcriber_initialization():
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
load_transcriber('medium', whisper_type='wrong_whisper')
|
load_transcriber('tiny', whisper_type='wrong_whisper')
|
||||||
|
|
||||||
|
|
||||||
def test_get_whisper_kwargs():
|
def test_get_whisper_kwargs():
|
||||||
@@ -73,8 +73,8 @@ def test_whisper_transcribe(whisper_instance):
|
|||||||
assert isinstance(transcript, str)
|
assert isinstance(transcript, str)
|
||||||
|
|
||||||
|
|
||||||
def test_whisperx_transcribe(whisperx_instance):
|
def test_faster_whisper_transcribe(faster_whisper_instance):
|
||||||
model = whisperx_instance
|
model = faster_whisper_instance
|
||||||
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
||||||
transcript = model.transcribe('test/audio_test_2.mp4')
|
transcript = model.transcribe('test/audio_test_2.mp4')
|
||||||
assert isinstance(transcript, str)
|
assert isinstance(transcript, str)
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from os import environ
|
||||||
|
|
||||||
|
environ["AUTOT_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests"
|
||||||
|
# environ["PYANNOTE_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests/pyannote"
|
||||||
|
# environ["TORCH_HOME"] = "/mnt/disk1/Projekte/ScrAIbe/tests/torch"
|
||||||
|
|
||||||
|
from scraibe import Scraibe
|
||||||
|
|
||||||
|
scraibe = Scraibe(whisper_type = "faster-whisper", whisper_model = "tiny")
|
||||||
|
print(scraibe.autotranscribe('/mnt/disk1/Projekte/ScrAIbe/test/audio_test_1.mp4'))
|
||||||
Reference in New Issue
Block a user