Merge pull request #99 from JSchmie/develop

Update Major Release v0.2.0
This commit is contained in:
Jacob Schmieder
2024-05-31 16:01:56 +02:00
committed by GitHub
54 changed files with 2178 additions and 1810 deletions
+6
View File
@@ -0,0 +1,6 @@
scraibe/*__pycache__
scraibe/app/*__pycache__
scraibe/.pyannotetoken
.git
.gitignore
.github
+90
View File
@@ -0,0 +1,90 @@
# Verifies, for every pull request into main or develop, that the version the
# PR targets is not tagged yet and has a matching section in CHANGELOG.md.
# Pull requests coming from the `docs` branch are exempt from all checks.
name: Check and Add Version in Changelog

on:
  pull_request:
    branches:
      - main
      - develop

jobs:
  check-and-add-version:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4
        with:
          # Full history is required so that existing tags can be inspected.
          fetch-depth: 0

      - name: Check if Source Branch is docs
        id: check_docs_branch
        env:
          # Passed through the environment instead of being interpolated into
          # the script: branch names are attacker-controlled on pull requests.
          PR_HEAD_REF: ${{ github.event.pull_request.head.ref }}
        run: |
          if [[ "$PR_HEAD_REF" == "docs" ]]; then
            echo "This is a docs branch merge. Exiting without creating a tag."
            echo "is_docs_branch=true" >> "$GITHUB_ENV"
          else
            echo "is_docs_branch=false" >> "$GITHUB_ENV"
          fi

      - name: Extract and Determine Version
        if: env.is_docs_branch != 'true'
        id: extract_version
        env:
          # PR title and body are untrusted free text. Interpolating them with
          # ${{ }} inside `run:` would let a PR author inject shell commands,
          # so they are handed over as environment variables instead.
          PR_BODY: ${{ github.event.pull_request.body }}
          PR_TITLE: ${{ github.event.pull_request.title }}
        run: |
          # Fetch the latest tags from the remote
          git fetch --tags

          # Get the latest tag, or initialize to v0.0.0 if no tags are found
          latest_tag=$(git describe --tags "$(git rev-list --tags --max-count=1)" 2>/dev/null || echo "v0.0.0")

          # Prefer an explicit vX.Y.Z from the PR body, then from the PR title;
          # otherwise bump the patch component of the latest tag.
          version_regex="v([0-9]+)\.([0-9]+)\.([0-9]+)"
          if [[ $PR_BODY =~ $version_regex ]] || [[ $PR_TITLE =~ $version_regex ]]; then
            major=${BASH_REMATCH[1]}
            minor=${BASH_REMATCH[2]}
            patch=${BASH_REMATCH[3]}
          else
            # Split the latest tag into parts
            IFS='.' read -r -a parts <<< "${latest_tag#v}"
            major=${parts[0]}
            minor=${parts[1]}
            patch=$(( ${parts[2]} + 1 ))
          fi
          new_tag="v$major.$minor.$patch"

          # Later steps work with the version without the leading "v".
          clean_version="${new_tag#v}"
          echo "version=$clean_version" >> "$GITHUB_ENV"
          echo "Version determined: $clean_version"

      - name: Check if Version Already Exists in Tags
        if: env.is_docs_branch != 'true'
        run: |
          # `version` was exported via GITHUB_ENV without the "v" prefix, while
          # the derived tag name is "v<version>", so the prefix has to be
          # re-added before looking the tag up. An exact ref lookup also avoids
          # treating the dots of the version as regex wildcards.
          if git rev-parse -q --verify "refs/tags/v${version}" >/dev/null; then
            echo "Version $version already exists in tags."
            exit 1
          else
            echo "Version $version does not exist in tags."
          fi

      - name: Check Version in CHANGELOG
        if: env.is_docs_branch != 'true'
        id: check_version
        run: |
          # Escape the dots so that e.g. 0.2.0 does not also match "0x2y0".
          escaped_version="${version//./\\.}"
          if ! grep -q "^## \[${escaped_version}\]" CHANGELOG.md; then
            echo "Version $version not found in CHANGELOG.md."
            exit 1
          else
            echo "Version $version found in CHANGELOG.md."
          fi
+44
View File
@@ -0,0 +1,44 @@
# Builds the Sphinx documentation and publishes it to the gh-pages branch.
name: documentation

on:
  push:
    branches:
      - main
  workflow_dispatch:

permissions:
  contents: write

jobs:
  docs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v3
        with:
          # Quoted so the version stays a string (3.10 would otherwise be 3.1).
          python-version: '3.9'

      - name: Install dependencies
        run: |
          # Refresh the package index and never wait for a confirmation
          # prompt: there is no terminal attached in CI.
          sudo apt-get update
          sudo apt-get install -y libsndfile1-dev
          pip install --upgrade pip
          pip install -r requirements.txt
          pip install --upgrade sphinx sphinx_rtd_theme myst-parser
          # Quoted so the shell does not treat [plugins] as a glob pattern.
          pip install --upgrade 'markdown-it-py[plugins]'
          pip install --upgrade mdit-py-plugins

      - name: Sphinx build
        run: |
          # Stage the files that the Sphinx sources reference.
          cp README.md ./source/README.md
          cp LICENSE ./source/LICENSE
          cp -r Pictures ./source/Pictures
          sphinx-apidoc -o source scraibe/
          # Output that gets published: ./docs/html
          sphinx-build -M html source docs
          make html

      - name: Deploy to GitHub Pages
        uses: peaceiris/actions-gh-pages@v3
        # The only push trigger of this workflow is `main`; the previous check
        # against refs/heads/sphinx_action could never be true, so the docs
        # were built but never deployed.
        if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
        with:
          publish_branch: gh-pages
          github_token: ${{ secrets.TOKEN_GH }}
          publish_dir: ./docs/html
          force_orphan: true
+23
View File
@@ -0,0 +1,23 @@
# Mirrors this repository to the GitLab instance and triggers its CI pipeline
# on every push and on every branch/tag deletion.
name: Mirror and run GitLab CI

on:
  - push
  - delete

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          # Complete history, so the mirror receives every commit.
          fetch-depth: 0

      - name: Mirror + trigger CI
        uses: SvanBoxel/gitlab-mirror-and-ci-action@master
        with:
          args: 'https://git-dmz.thuenen.de/kida/i2-skills-beratungsstelle/active-service-requests/scraibe/scraibe'
        env:
          # Flags are passed as strings on purpose: environment variables
          # carry text, not YAML booleans.
          FOLLOW_TAGS: 'true'
          FORCE_PUSH: 'true'
          GITLAB_HOSTNAME: 'git-dmz.thuenen.de'
          GITLAB_USERNAME: ${{ secrets.GITLAB_USERNAME }}
          GITLAB_PASSWORD: ${{ secrets.GITLAB_PASSWORD }}
          GITLAB_PROJECT_ID: ${{ secrets.GITLAB_PROJECT_ID }}
          GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
+75
View File
@@ -0,0 +1,75 @@
# Release pipeline: publish to TestPyPI, verify that the uploaded package
# installs on all supported Python versions, then publish to PyPI.
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI

on:
  push:
    tags:
      - v*  # Push tags to trigger the workflow
  pull_request:
    types: [closed]
    branches:
      - develop
  workflow_dispatch:
    inputs:
      test:
        description: "Push to TestPyPI not PyPI"
        default: true
        type: boolean

jobs:
  Build-and-publish-to-Test-PyPI:
    # `closed` also fires for pull requests that were closed WITHOUT being
    # merged; those must not publish anything. Jobs that `needs:` this one
    # are skipped along with it.
    if: ${{ github.event_name != 'pull_request' || github.event.pull_request.merged == true }}
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          # Full history incl. tags; the poetry-dynamic-versioning plugin is
          # enabled below.
          fetch-depth: '0'

      - name: Set up Poetry 📦
        uses: JRubics/poetry-publish@v1.16
        with:
          pypi_token: ${{ secrets.TEST_PYPI_API_TOKEN }}
          plugins: "poetry-dynamic-versioning"
          repository_name: "scraibe"
          repository_url: "https://test.pypi.org/legacy/"

  test-install:
    name: Test Installation from TestPyPI
    needs: Build-and-publish-to-Test-PyPI
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so versions stay strings (an unquoted 3.10 would become 3.1).
        python-version: ['3.9', '3.11', '3.12']
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install package
        run: |
          pip install -U setuptools
          pip install -r requirements.txt
          # The requirement MUST be quoted: unquoted, the shell parses
          # `scraibe>=0.1.3` as "install scraibe, redirect stdout to a file
          # named =0.1.3", silently dropping the version constraint.
          python3 -m pip install --no-deps --pre --index-url https://test.pypi.org/simple/ "scraibe>=0.1.3"
          python3 -c "import scraibe; print(scraibe.__version__)"

  publish-to-pypi:
    name: Publish to PyPI
    needs: test-install
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Repository Tags
        uses: actions/checkout@v4
        # `==` compares literally and does not expand `v*`, so the former
        # check against 'refs/tags/v*' never matched; use startsWith instead.
        if: startsWith(github.ref, 'refs/tags/v')
        with:
          # NOTE: the former `branch:` input was dropped here. actions/checkout
          # has no such input (the supported one is `ref`), so it was ignored
          # and the triggering ref was checked out - which is what happens now.
          fetch-depth: '0'

      - name: Checkout Repository (Develop)
        uses: actions/checkout@v4
        if: github.ref == 'refs/heads/develop'
        with:
          fetch-depth: '0'

      - name: Set up Poetry 📦
        uses: JRubics/poetry-publish@v1.16
        with:
          pypi_token: ${{ secrets.PYPI_API_TOKEN }}
          plugins: "poetry-dynamic-versioning"
          repository_name: "scraibe"
+39
View File
@@ -0,0 +1,39 @@
# Runs the pytest suite for pull requests into main/develop (or on demand).
name: Run Tests

on:
  pull_request:
    branches: ['main', 'develop']
  workflow_dispatch:

jobs:
  pytest:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
        with:
          fetch-depth: 0

      - name: Setup Python
        uses: actions/setup-python@v3
        with:
          # Quoted so the version stays a string (3.10 would otherwise be 3.1).
          python-version: '3.9'

      - name: Install Dependencies
        run: |
          # -y everywhere: without it apt aborts as soon as it wants a
          # confirmation, because no terminal is attached in CI.
          sudo apt-get update && sudo apt-get upgrade -y
          # System libraries first, so they are present when the Python
          # packages that depend on them get installed.
          sudo apt-get install -y libsndfile1-dev
          sudo apt-get install -y ffmpeg
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install .
          pip install pytest

      - name: Run pytest
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          pytest
-39
View File
@@ -1,39 +0,0 @@
# Uploads the Python package once a GitHub release has been published.
# Background: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
#
# Some of the actions used here are not certified by GitHub. They are provided
# by a third party and are governed by separate terms of service, privacy
# policy, and support documentation.
name: Upload Python Package

on:
  release:
    types:
      - published

permissions:
  contents: read

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: "3.x"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build

      - name: Build package
        run: |
          python -m build

      - name: Publish package
        # Third-party action pinned to a full commit SHA.
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
+9
View File
@@ -0,0 +1,9 @@
# Lints the code base with Ruff on every push.
name: Ruff

on: [push]

jobs:
  ruff:
    runs-on: ubuntu-latest
    # With `push` as the only trigger this condition always holds; the
    # pull_request half only matters if that trigger is ever added.
    if: github.event_name == 'push' || github.event_name == 'pull_request'
    steps:
      - uses: actions/checkout@v4
      - uses: chartboost/ruff-action@v1
+111
View File
@@ -0,0 +1,111 @@
# After a pull request is merged into main: derive the next version, push the
# matching tag and create a GitHub release whose body is the corresponding
# CHANGELOG.md section. Merges coming from the `docs` branch are skipped.
name: Semantic Versioning for Tags

on:
  pull_request:
    types: [closed]
    branches:
      - main

jobs:
  bump-version:
    # `closed` also fires for unmerged pull requests; only act on real merges.
    if: ${{ github.event.pull_request.merged == true && github.event.pull_request.base.ref == 'main' }}
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4
        with:
          # Full history is required so that existing tags can be inspected.
          fetch-depth: 0

      - name: Check if Source Branch is docs
        id: check_docs_branch
        env:
          # Passed through the environment instead of being interpolated into
          # the script: branch names are attacker-controlled on pull requests.
          PR_HEAD_REF: ${{ github.event.pull_request.head.ref }}
        run: |
          if [[ "$PR_HEAD_REF" == "docs" ]]; then
            echo "is_docs_branch=true" >> "$GITHUB_ENV"
            echo "This is a docs branch merge. Exiting without creating a tag."
          else
            echo "is_docs_branch=false" >> "$GITHUB_ENV"
          fi

      - name: Bump Version and Tag
        if: env.is_docs_branch != 'true'
        id: bump_version
        env:
          GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
          # PR title and body are untrusted free text. Interpolating them with
          # ${{ }} inside `run:` would let a PR author inject shell commands
          # into a job that holds a write token, so they are handed over as
          # environment variables instead.
          PR_BODY: ${{ github.event.pull_request.body }}
          PR_TITLE: ${{ github.event.pull_request.title }}
        run: |
          # Fetch the latest tags from the remote
          git fetch --tags

          # Get the latest tag, or initialize to v0.0.0 if no tags are found
          latest_tag=$(git describe --tags "$(git rev-list --tags --max-count=1)" 2>/dev/null || echo "v0.0.0")

          # Prefer an explicit vX.Y.Z from the PR body, then from the PR title;
          # otherwise bump the patch component of the latest tag.
          version_regex="v([0-9]+)\.([0-9]+)\.([0-9]+)"
          if [[ $PR_BODY =~ $version_regex ]] || [[ $PR_TITLE =~ $version_regex ]]; then
            major=${BASH_REMATCH[1]}
            minor=${BASH_REMATCH[2]}
            patch=${BASH_REMATCH[3]}
          else
            # Split the latest tag into parts
            IFS='.' read -r -a parts <<< "${latest_tag#v}"
            major=${parts[0]}
            minor=${parts[1]}
            patch=$(( ${parts[2]} + 1 ))
          fi
          new_tag="v$major.$minor.$patch"

          echo "Bumping version from $latest_tag to $new_tag"

          # Set the new tag as an environment variable
          echo "new_tag=$new_tag" >> "$GITHUB_ENV"

          # Tag the new version
          git tag "$new_tag"

          # Configure GitHub token authentication. The token is read from the
          # step environment rather than being pasted into the script text.
          git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"

          # Push the new tag to the remote repository
          git push origin "$new_tag"

      - name: Extract Release Notes
        if: env.is_docs_branch != 'true'
        id: extract_notes
        run: |
          # `new_tag` was exported via GITHUB_ENV by the previous step.
          clean_version="${new_tag#v}"

          # Print everything between the "## [<version>]" header and the next
          # "## [" header. The version is compared as a plain string (index),
          # so its dots are not interpreted as regex wildcards.
          release_notes=$(awk -v version="$clean_version" '
            /^## \[.*\]/ {
              if (found) exit                      # next section starts
              if (index($0, "## [" version "]") == 1) { found = 1; next }
            }
            found { print }
          ' CHANGELOG.md)

          # Multi-line value: use a randomised delimiter so that a line in the
          # release notes can never terminate the block prematurely.
          delimiter="EOF_${RANDOM}_${RANDOM}"
          {
            echo "RELEASE_NOTES<<${delimiter}"
            echo "$release_notes"
            echo "${delimiter}"
          } >> "$GITHUB_ENV"

      - name: Create Release
        if: env.is_docs_branch != 'true'
        # NOTE(review): actions/create-release appears to be archived upstream;
        # consider migrating to a maintained alternative.
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
        with:
          tag_name: ${{ env.new_tag }}
          release_name: Release ${{ env.new_tag }}
          body: ${{ env.RELEASE_NOTES }}
          draft: false
          prerelease: false
+242
View File
@@ -0,0 +1,242 @@
transcibe.py
scraibe/*__pycache__
scraibe/app/*__pycache__
scraibe/.pyannotetoken
# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,linux,windows
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,linux,windows
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,linux,windows
+34
View File
@@ -0,0 +1,34 @@
# Changelog
All notable changes to this project will be documented in this file.
## [0.2.0] - 2024-05-28
### Added
- **Python Usage Section**: Detailed instructions on how to use ScrAIbe with Python, including examples for Whisper models, WhisperX, and keyword arguments.
- **Command-line Usage Section**: Enhanced instructions for using ScrAIbe via the command-line interface, including examples and key options.
- **Documentation Section**: Expanded the documentation section with highlights on installation guides, usage examples, API reference, troubleshooting tips, and advanced configuration.
- **Getting Started Section**: Added detailed prerequisites and installation instructions for both stable and development versions of ScrAIbe.
- **WhisperX Support**: Added support for the WhisperX backend.
### Changed
- **Model Customization**: Clarified the use of various keywords to customize Whisper models, Pyannote diarization models, and WhisperX.
- **Example Enhancements**: Improved examples to illustrate the usage of different features and options in ScrAIbe.
- **Formatting and Clarity**: Improved formatting and clarity across all sections to enhance readability and user experience.
- **Backend Robustness**: Enhanced the backend to be more robust, removing the need for a HuggingFace token for basic usage.
- **CLI**: Now works without Gradio.
### Removed
- **Docker Build**: Removed Docker build support.
- **Gradio App**: Removed the Gradio App integration.
Both the Docker build and the Gradio app are now available under [ScrAIbe-WebUI](https://github.com/JSchmie/ScrAIbe-WebUI).
### Documentation
- **Documentation Page Link**: Updated the documentation section with a direct link to the [ScrAIbe documentation page](https://jschmie.github.io/ScrAIbe/).
**Note**: This changelog might be incomplete, but we promise to improve it in the future. Thank you for your understanding and support.
+59
View File
@@ -0,0 +1,59 @@
# Contributing to ScrAIbe
Thank you for your interest in contributing to ScrAIbe! We appreciate your efforts to improve the project. Before making any changes, please discuss them with the project maintainers via an issue, email, or any other method.
Please note that we have a code of conduct, and we ask you to adhere to it in all your interactions with the project.
## Pull Request Process
1. **Dependency Management**: Ensure any install or build dependencies are removed before the end of the layer when doing a build.
2. **Documentation Updates**: Update the `README.md` with details of changes to the interface, including new environment variables, exposed ports, useful file locations, and container parameters.
3. **Versioning**: Increase the version numbers in any example files and the `README.md` to the new version that this Pull Request would represent. We use the [SemVer](http://semver.org/) versioning scheme.
4. **Review and Merge**: You may merge the Pull Request once you have the sign-off of two other developers. If you do not have permission to merge, request a second reviewer to merge it for you.
## Code of Conduct
### Our Pledge
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
### Our Standards
Examples of behavior that contributes to creating a positive environment include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
### Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
### Scope
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples include using an official project email address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
### Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at [INSERT EMAIL ADDRESS]. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
### Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version].
[homepage]: http://contributor-covenant.org
[version]: http://contributor-covenant.org/version/1/4/
+25 -5
View File
@@ -1,8 +1,20 @@
#pytorch Image #pytorch Image
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
# Labels
LABEL maintainer="Jacob Schmieder"
LABEL email="Jacob.Schmieder@dbfz.de"
LABEL version="0.1.1.dev"
LABEL description="Scraibe is a tool for automatic speech recognition and speaker diarization. \
It is based on the Hugging Face Transformers library and the Pyannote library. \
It is designed to be used with the Whisper model, a lightweight model for automatic \
speech recognition and speaker diarization."
LABEL url="https://github.com/JSchmie/ScrAIbe"
# Install dependencies # Install dependencies
WORKDIR /app WORKDIR /app
ARG hf_token ARG model_name=medium
#Enviorment Dependncies #Enviorment Dependncies
ENV TRANSFORMERS_CACHE /app/models ENV TRANSFORMERS_CACHE /app/models
ENV HF_HOME /app/models ENV HF_HOME /app/models
@@ -10,17 +22,25 @@ ENV AUTOT_CACHE /app/models
ENV PYANNOTE_CACHE /app/models/pyannote ENV PYANNOTE_CACHE /app/models/pyannote
#Copy all necessary files #Copy all necessary files
COPY requirements.txt /app/requirements.txt COPY requirements.txt /app/requirements.txt
COPY scraibe /app/Scraibe COPY README.md /app/README.md
COPY models /app/models
COPY scraibe /app/scraibe
COPY setup.py /app/setup.py COPY setup.py /app/setup.py
#Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token #Installing all necessary Dependencies and Running the Application with a personalised Hugging-Face-Token
RUN apt update && apt-get install -y libsm6 libxrender1 libfontconfig1
RUN conda update --all
RUN conda install pip RUN conda install pip
RUN conda install -y ffmpeg RUN conda install -y ffmpeg
RUN conda install -c conda-forge libsndfile RUN conda install -c conda-forge libsndfile
RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html RUN pip install torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install /app/ RUN pip install -r requirements.txt
RUN pip install markupsafe==2.0.1 --force-reinstall RUN pip install markupsafe==2.0.1 --force-reinstall
RUN Scraibe --hf_token $hf_token
RUN python3 -m 'scraibe.cli' --whisper-model-name $model_name
# Expose port # Expose port
EXPOSE 7860 EXPOSE 7860
# Run the application # Run the application
ENTRYPOINT ["scraibe"]
ENTRYPOINT ["python3", "-m", "scraibe.cli" ,"--whisper-model-name", "$model_name"]
+20
View File
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
# Passed to sphinx-build as the source directory.
SOURCEDIR = source
# Passed to sphinx-build as the output directory.
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
# Example: `make html` runs `sphinx-build -M html "source" "build"`.
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 131 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

+103 -129
View File
@@ -1,199 +1,173 @@
# `ScrAIbe: Streamlined Conversation Recording with Automated Intelligence Based Environment` 🎙️🧠
# `ScrAIbe: Streamlined Conversation Recording with Automated Intelligence Based Environment` Welcome to `ScrAIbe`, a state-of-the-art, [PyTorch](https://pytorch.org/) based multilingual speech-to-text framework designed to generate fully automated transcriptions.
`ScrAIbe` is a state-of-the-art, [PyTorch](https://pytorch.org/) based multilingual speech-to-text framework to generate fully automated transcriptions. Beyond transcription, ScrAIbe supports advanced functions such as speaker diarization and speaker recognition. 🚀
Beyond transcription, ScrAIbe supports advanced functions, such as speaker diarization and speaker recognition. Designed as a comprehensive AI toolkit, it uses multiple powerful AI models:
Designed as a comprehensive AI toolkit, it uses multiple AI models: - **[Whisper](https://github.com/openai/whisper)**: A general-purpose speech recognition model.
- **[WhisperX](https://github.com/m-bain/whisperX)**: A faster, quantized version of Whisper for enhanced performance on CPU. ⚡
- [whisper](https://github.com/openai/whisper): A general-purpose speech recognition model. - **[Pyannote-Audio](https://github.com/pyannote/pyannote-audio)**: An open-source toolkit for speaker diarization. 🗣️
- [pyannote-audio](https://github.com/pyannote/pyannote-audio): An open-source toolkit for speaker diarization.
The framework utilizes a PyanNet-inspired pipeline, with the `Pyannote` library for speaker diarization and `VoxCeleb` for speaker embedding. The framework utilizes a PyanNet-inspired pipeline, with the `Pyannote` library for speaker diarization and `VoxCeleb` for speaker embedding.
During post-diarization, each audio segment is processed by the OpenAI `Whisper` model, in a transformer encoder-decoder structure. Initially, a CNN mitigates noise and enhances speech. Before transcription, `VoxLingua` identifies the language segment, facilitating Whisper's role in both transcription and text translation. During post-diarization, each audio segment is processed by the OpenAI `Whisper` model in a transformer encoder-decoder structure. Initially, a CNN mitigates noise and enhances speech. Before transcription, `VoxLingua` identifies the language segment, facilitating Whisper's role in both transcription and text translation. 🌍✨
The following graphic illustrates the whole pipeline: The following graphic illustrates the whole pipeline:
![Pipeline](Pictures/pipeline.png#gh-dark-mode-only) <div style="text-align:center;">
![Pipeline](Pictures/pipeline_light.png#gh-light-mode-only) <img src="./Pictures/pipeline.png#gh-dark-mode-only" style="width: 60%;" />
<img src="./Pictures/pipeline_light.png#gh-light-mode-only" style="width: 60%;" />
</div>
## Install `ScrAIbe` : ## Getting Started 🚀
The following command will pull and install the latest commit from this repository, along with its Python dependencies. ### Prerequisites
pip install scraibe Before installing ScrAIbe, ensure you have the following prerequisites:
- **Python version**: Python 3.8 - **Python**: Version 3.9 or later.
- **PyTorch version**: Python 1.11.0 - **PyTorch**: Version 2.0 or later.
- **CUDA version**: Cuda-toolkit 11.3.1 - **CUDA**: A version compatible with your PyTorch version, if you want to use GPU acceleration.
**Note:** PyTorch should be automatically installed with the pip installer. However, if you encounter any issues, you should consider installing it manually by following the instructions on the [PyTorch website](https://pytorch.org/get-started/locally/).
Important: For the `Pyannote` model, you need to be granted access to Hugging Face. ### Install ScrAIbe
Check the [Pyannote model page](https://huggingface.co/pyannote/speaker-diarization) to get access to the model.
Additionally, you need to generate a [Hugging Face token](https://huggingface.co/docs/hub/security-tokens). Install ScrAIbe on your local machine with ease using PyPI.
## Usage ```bash
pip install scraibe
```
If you want to install the development version, you can do so by installing it from GitHub:
```bash
pip install git+https://github.com/JSchmie/ScrAIbe.git@develop
```
or from PyPI using our latest pre-release:
```bash
pip install --pre scraibe
```
Get started with ScrAIbe today and experience seamless, automated transcription and diarization.
## Usage
We've developed ScrAIbe with several access points to cater to diverse user needs. We've developed ScrAIbe with several access points to cater to diverse user needs.
### Python usage ### Python Usage
It enables full control over the functionalities as well as process customization. Gain full control over the functionalities as well as process customization.
```python ```python
from scraibe import Scraibe from scraibe import Scraibe
model = Scraibe(use_auth_token = "hf_yourhftoken") model = Scraibe()
text = model.autotranscribe("audio.wav") text = model.autotranscribe("audio.wav")
print(f"Transcription: \n{text}") print(f"Transcription: \n{text}")
``` ```
The `Scraibe` Class is taking care of the models being properly loaded. Therefore, you can choose the other [whisper](https://github.com/openai/whisper/blob/main/model-card.md) models using the `whisper_model` keyword.
You can also change the `pyannote` diarization model using the `dia_model` keyword.
The `Scraibe` class ensures the models are properly loaded. You can customize the models with various keywords:
As input, `autotranscribe` accepts every format which is compatible with [FFmpeg](https://ffmpeg.org/ffmpeg-formats.html). Examples include `.mp4 .mp3 .wav .ogg .flac` and many more. - **Whisper Models**: Use the `whisper_model` keyword to specify models like `tiny`, `base`, `small`, `medium`, or `large` (`large-v2`, `large-v3`) depending on your accuracy and speed needs.
- **Pyannote Diarization Model**: Use the `dia_model` keyword to change the diarization model.
- **WhisperX**: Set the `whisper_type` to `"whisperX"` for enhanced performance on CPU and use their enhanced models. (Model names are the same)
- **Keyword Arguments**: A variety of different `kwargs` are available:
- `use_auth_token`: Pass a Hugging Face token to the Pyannote backend if you want to use one of the models hosted on their Hugging Face.
- `verbose`: Enable this to add an additional level of verbosity.
In general, you should be able to input any `kwargs` that you can input in the original Whisper (WhisperX) and Pyannote Python APIs.
To further control the pipeline of `ScrAIbe`, you can pass almost any keyword that you could also pass to `whisper` or `pyannote`. If you need more options, check the documentation of those two frameworks; there is a good chance that their keywords will work here as well. As input, `autotranscribe` accepts every format compatible with [FFmpeg](https://ffmpeg.org/ffmpeg-formats.html). Examples include `.mp4`, `.mp3`, `.wav`, `.ogg`, `.flac`, and many more.
Here are some examples regarding the `diarization` (which relies on the `pyannote` pipeline):
- `num_speakers` Number of speakers in the audio file To further control the pipeline of `ScrAIbe`, you can pass almost any keyword argument that is accepted by `Whisper` or `Pyannote`. For more options, refer to the documentation of these frameworks, as their keywords are likely to work here as well.
- `min_speakers` Minimal Number of speakers in the audio file
- `max_speakers` maximal Number of speakers in the audio file
Then there are arguments about the transcription process, which uses the "whisper" model. Here are some examples regarding `diarization` (which relies on the `pyannote` pipeline):
- `language` Specify the language ([list to supported languages](https://github.com/openai/whisper/blob/main/language-breakdown.svg)) - `num_speakers`: Number of speakers in the audio file
- `task` can be just `transcribe` or `translate`. If `translate` is selected, the transcribed audio will be translated to English. - `min_speakers`: Minimum number of speakers in the audio file
- `max_speakers`: Maximum number of speakers in the audio file
Then there are arguments for the transcription process, which uses the "Whisper" model:
- `language`: Specify the language ([list of supported languages](https://github.com/openai/whisper/blob/main/language-breakdown.svg))
- `task`: Can be either `transcribe` or `translate`. If `translate` is selected, the transcribed audio will be translated to English.
For example: For example:
``` ```python
text = model.autotranscribe("audio.wav", language="german", num_speakers = 2) text = model.autotranscribe("audio.wav", language="german", num_speakers = 2)
``` ```
`Scraibe` also contains the option to just do a transcription `Scraibe` also contains the option to just do a transcription:
```python ```python
transcription = model.transcribe("audio.wav") transcription = model.transcribe("audio.wav")
``` ```
or just do a diarization:
or just do a diarization:
```python ```python
diarization = model.diarization("audio.wav") diarization = model.diarization("audio.wav")
``` ```
Start exploring the powerful features of ScrAIbe and customize it to fit your specific transcription and diarization needs!
### Command-line usage ### Command-line usage
Next to the Python interface, you can also run ScrAIbe using the command-line interface: Next to the Python interface, you can also run ScrAIbe using the command-line interface:
scraibe -f "audio.wav" --hf-token "hf_yourhftoken" --language "german" --num_speakers 2 ```bash
scraibe -f "audio.wav" --language "german" --num_speakers 2
```
For the full list of options, run: For the full list of options, run:
scraibe -h ```bash
scraibe -h
### Gradio App
The Gradio App is a user-friendly interface for ScrAIbe. It enables you to run the model without any coding knowledge. Therefore, you can run the app in your browser and upload your audio file, or you can make the Framework avail on your network and run it on your local machine.
#### Running the Gradio App on your local machine
To run the Gradio App on your local machine, just use the following command:
```
scraibe --start_server --port 7860 --hf_token hf_yourhftoken
``` ```
- `--start_server`: Command to start the Gradio App. This will display a comprehensive list of all command-line options, allowing you to tailor ScrAIbe's functionality to your specific needs.
- `--port`: Flag for connecting the container internal port to the port on your local machine.
- `--hf_token`: Flag for entering your personal HuggingFace token in the container.
When the app is running, it will show you at which address you can access it. ## Gradio App 🌐
The default address is: http://127.0.0.1:7860 or http://0.0.0.0:7860
After the app is running, you can upload your audio file and select the desired options. The Gradio App is now part of ScrAIbe-WebUI! This user-friendly interface enables you to run the model without any coding knowledge. You can easily run the app in your browser and upload your audio files, or make the framework available on your network and run it on your local machine. 🚀
An example is shown below:
![Gradio App](Pictures/gradio_app.png) All functionalities previously available in the Gradio App are now part of the ScrAIbe-WebUI. For more information and detailed instructions, visit the [ScrAIbe-WebUI GitHub repository](https://github.com/JSchmie/ScrAIbe-WebUI).
## Docker Container 🐳
ScrAIbe's Docker containers have also moved to ScrAIbe-WebUI! This option is especially useful if you want to run the model on a server or if you would like to use the GPU without dealing with CUDA.
All Docker container functionalities are now part of ScrAIbe-WebUI. For more information and detailed instructions on how to use the Docker containers, please visit the [ScrAIbe-WebUI GitHub repository](https://github.com/JSchmie/ScrAIbe-WebUI).
---
With these changes, ScrAIbe focuses on its core functionalities while the enhanced Gradio App and related Docker containers are now part of ScrAIbe-WebUI. Enjoy a more streamlined and powerful transcription experience! 🎉
## Documentation 📚
For comprehensive guides, detailed instructions, and advanced usage tips, visit our [documentation page](https://jschmie.github.io/ScrAIbe/). Here, you will find everything you need to make the most out of ScrAIbe.
### Contributions 🤝
We warmly welcome contributions from the community! Whether you're fixing bugs, adding new features, or improving documentation, your help is invaluable. Please see our [Contributing Guidelines](./CONTRIBUTING.md) for more information on how to get involved and make your mark on ScrAIbe-WebUI.
### Running a Docker container ### License 📜
Another option to run ScrAIbe is to use a Docker container. This option is especially useful if you want to run the model on a server or if you would like to use the GPU without dealing with CUDA. ScrAIbe-WebUI is proudly open source and licensed under the GPL-3.0 license. This promotes a collaborative and transparent development process. For more details, see the [LICENSE](./LICENSE) file in this repository.
After you have installed Docker, you can execute the following commands in the terminal.
First, you need to build the Docker image. Therefore, you need to enter your HuggingFace token and the image name.
```
docker build . --build-arg="hf_token=[enter your HuggingFace token] " -t scraibe
```
After the image is built, you can run the container with the following command:
```
sudo docker run -it -p 7860:7860 --name [container name][image name] --hf_token [enter your HuggingFace token] --start_server
```
- `-p`: Flag for connecting the container internal port to the port on your local machine.
- `--hf_token`: Flag for entering your personal HuggingFace token in the container.
- `--start_server`: Command to start the Gradio App.
Inside the container, the `cli` is used. Therefore, you can use the same commands as in the command-line interface.
#### Enabling GPU usage
To use the GPU, ensure your Docker installation supports GPU usage.
For further information, check: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker
To enable GPU usage, you need to add the following flag to the `docker run` command:
```
docker run -it -p 7860:7860 --gpus 'all,capabilities=utility' --name [container name][image name] --hf_token [enter your HuggingFace token] --start_server
```
For further guidance, check: https://blog.roboflow.com/use-the-gpu-in-docker/
## Documentation
For further insights, check the [documentation page](https://jschmie.github.io/ScrAIbe/).
## Contributions
We are happy to have any interest in contributing and about feedback: In order to do that, create an issue with your feedback or feel free to contact us.
## Roadmap
The following milestones are planned for further releases of ScrAIbe:
- Model quantization
Quantization to empower memory and computational efficiency.
- Model fine-tuning
In order to be able to cover a variety of linguistic phenomena.
For example, currently ScrAIbe is able to transcribe word by word, but ignores filler words or speech pauses.
These phenomena can be addressed by fine-tuning with the corresponding data.
- Implementation of LLMs
One example is the implementation of a summarization or extraction model, which enables ScrAIbe to automatically summarize or retrieve the key information out of a generated transcription, which could be the minutes of a meeting.
- Executable for Windows
## Contact
For queries, contact [Jacob Schmieder](mailto:Jacob.Schmieder@dbfz.de).
## License
ScrAIbe is licensed under GNU General Public License.
## Acknowledgments ## Acknowledgments
Special thanks go to the KIDA project and the BMEL (Bundesministerium für Ernährung und Landwirtschaft), especially to the AI Consultancy Team. Special thanks go to the [KIDA](https://www.kida-bmel.de/) project and the [BMEL (Bundesministerium für Ernährung und Landwirtschaft)](https://www.bmel.de/EN/Home/home_node.html), especially to the AI Consultancy Team.
![KIDA](Pictures/kida_dark.png#gh-dark-mode-only) &nbsp; ![BMEL](Pictures/BMEL_dark.png#gh-dark-mode-only) &nbsp;&nbsp;&nbsp;&nbsp; ![DBFZ](Pictures/DBFZ_dark.png#gh-dark-mode-only) &nbsp; &nbsp;&nbsp;&nbsp; ![MRI](Pictures/MRI.png#gh-dark-mode-only) ---
![KIDA](Pictures/kida.png#gh-light-mode-only) &nbsp; ![BMEL](Pictures/BMEL.jpg#gh-light-mode-only) &nbsp;&nbsp;&nbsp;&nbsp; ![DBFZ](Pictures/DBFZ.png#gh-light-mode-only) &nbsp; &nbsp;&nbsp;&nbsp; ![MRI](Pictures/MRI.png#gh-light-mode-only) Join us in making ScrAIbe even better! 🚀
+73
View File
@@ -0,0 +1,73 @@
[build-system]
requires = ["poetry-core>=1.8.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
build-backend = "poetry_dynamic_versioning.backend"
[tool.poetry]
# Package metadata for the "scraibe" distribution.
name = "scraibe"
# Static placeholder; [tool.poetry-dynamic-versioning] is enabled in this file
# and derives the real version from git.
version = "0.0.0"
description = "Transcription tool for audio files based on Whisper and Pyannote"
authors = ["Schmieder, Jacob <jacob.schmieder@dbfz.de>"]
license = "GPL-3.0-or-later"
readme = ["README.md", "LICENSE"]
# Fixed typo: the project lives at .../ScrAIbe (was ".../ScAIbe").
repository = "https://github.com/JSchmie/ScrAIbe"
documentation = "https://jschmie.github.io/ScrAIbe/"
keywords = ["transcription", "audio", "whisper", "pyannote", "speech-to-text", "speech-recognition"]
# The Python 3.8 classifier was dropped: [tool.poetry.dependencies] requires
# python = "^3.9", so advertising 3.8 support was inconsistent.
classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
    'Programming Language :: Python :: 3.9',
    'Programming Language :: Python :: 3.10',
    'Programming Language :: Python :: 3.11',
    'Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1',
    'Topic :: Scientific/Engineering :: Artificial Intelligence'
]
packages = [{include = "scraibe"}]
exclude = [
    "__pycache__",
    "*.pyc",
    "test"
]
[tool.poetry.dependencies]
python = "^3.9"
tqdm = "^4.66.4"
numpy = "^1.26.4"
openai-whisper = "^20231117"
whisperx = "^3.1.3"
"pyannote.audio" = "^3.1.1"
torch = "^2.3.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
[tool.poetry-dynamic-versioning]
# Version is derived from git metadata rather than the static "0.0.0"
# placeholder in [tool.poetry].
enable = true
vcs = "git"
strict = true
# Version template, by case:
#   - distance == 0            -> the base version as-is
#   - branch is 'develop'      -> bumped base version with a dev segment (dev = distance)
#   - any other branch         -> same as develop, plus the commit as metadata
# NOTE(review): "distance" is presumably the number of commits since the last
# version tag -- confirm against the poetry-dynamic-versioning documentation.
format-jinja = """
{%- if distance == 0 -%}
{{ serialize_pep440(base) }}
{%- elif branch == 'develop' -%}
{{ serialize_pep440(bump_version(base), dev = distance) }}
{%- else -%}
{{ serialize_pep440(bump_version(base), dev=distance, metadata=[commit]) }}
{%- endif -%}
"""
[tool.poetry.group.docs.dependencies]
sphinx = "^7.3.7"
sphinx-rtd-theme = "^2.0.0"
markdown-it-py = {version = "~3.0.0", extras = ["plugins"]}
myst-parser = "^3.0.1"
mdit-py-plugins = "^0.4.1"
[tool.poetry.scripts]
scraibe = "scraibe.cli:cli"
[tool.poetry.extras]
# Optional "app" extra (pip install scraibe[app]).
# NOTE(review): "scraibe-webui" is not declared in [tool.poetry.dependencies];
# Poetry expects extras to reference optional dependencies declared there --
# confirm this extra actually resolves.
app = ["scraibe-webui"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["E402","F403",'F401']
"scraibe/misc.py" = ["E722"]
+9 -20
View File
@@ -1,25 +1,14 @@
openai-whisper==20230314
pyannote.audio~=2.1.1
pyannote.core~=4.5
pyannote.database~=4.1.3
pyannote.metrics~=3.2.1
pyannote.pipeline~=2.3
setuptools~=69.0.3
setuptools-rust~=1.8.1
tqdm>=4.65.0 tqdm>=4.65.0
numpy>=1.26.4
gradio~=3.36.1 openai-whisper==20231117
gradio-client~=0.2.7 whisperx~=3.1.3
# add pytorch to override the one installed by pyannote.audio pyannote.audio~=3.1.1
pyannote.core~=5.0.0
torch~=1.11.0 pyannote.database~=5.0.1
torchvision~=0.12.0 pyannote.metrics~=3.2.1
torchaudio~=0.11.0 pyannote.pipeline~=3.0.1
#optional:
#sphinx~=5.0.2
torch>=2.0.0
-1
View File
@@ -1 +0,0 @@
+2 -6
View File
@@ -4,12 +4,8 @@ from .audio import *
from .transcript_exporter import * from .transcript_exporter import *
from .diarisation import * from .diarisation import *
from .version import get_version as _get_version
from .misc import * from .misc import *
from .app.gradio_app import *
from .app.qtfaststart import *
from .cli import * from .cli import *
__version__ = _get_version() from ._version import __version__
+1
View File
@@ -0,0 +1 @@
__version__ = "0.0.0"
File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 38 KiB

-2
View File
@@ -1,2 +0,0 @@
from .qtfaststart import *
from .gradio_app import *
-438
View File
@@ -1,438 +0,0 @@
"""
Gradio Audio Transcription App.
--------------------------------
This module provides an interface to transcribe audio files using the
Scraibe model. Users can either upload an audio file or record their speech
live for transcription. The application supports multiple languages and provides
options to specify the number of speakers and the language of the audio.
Attributes:
LANGUAGES (list): A list of supported languages for transcription.
Usage:
Run this script to start the Gradio web interface for audio transcription.
"""
"""
Gradio Audio Transcription App.
--------------------------------
This module provides an interface to transcribe audio files using the
Scraibe model. Users can either upload an audio file or record their speech
live for transcription. The application supports multiple languages and provides
options to specify the number of speakers and the language of the audio.
Attributes:
LANGUAGES (list): A list of supported languages for transcription.
Usage:
Run this script to start the Gradio web interface for audio transcription.
"""
import json
import os
import gradio as gr
from tqdm import tqdm
from scraibe import Scraibe, Transcript
# Colour theme handed to gr.Blocks(theme=theme) when the interface is built.
theme = gr.themes.Soft(
    primary_hue="green",
    secondary_hue='orange',
    neutral_hue="gray",
)
# Choices offered in the "Language (optional)" dropdown.
# NOTE(review): the dropdown's default value is the string "None", which is
# not an entry of this list -- confirm that Gradio accepts it.
LANGUAGES = [
    "Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian",
    "Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian",
    "Czech", "Danish", "Dutch", "English", "Estonian",
    "Finnish", "French", "Galician", "German", "Greek",
    "Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian",
    "Italian", "Japanese", "Kannada", "Kazakh", "Korean",
    "Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi",
    "Maori", "Nepali", "Norwegian", "Persian", "Polish",
    "Portuguese", "Romanian", "Russian", "Serbian", "Slovak",
    "Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog",
    "Tamil", "Thai", "Turkish", "Ukrainian", "Urdu",
    "Vietnamese", "Welsh"
]
# Directory containing this module; used to locate header.html.
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
class GradioTranscriptionInterface:
    """
    Adapter between the Gradio UI callbacks and a ``Scraibe`` model.

    Each task method accepts either a single file path (``str``) or a list of
    file paths and returns text / JSON suitable for Gradio output components.
    """

    def __init__(self, model: Scraibe):
        """
        Store the model used by all task methods.

        Args:
            model (Scraibe): Model responsible for audio transcription tasks.
        """
        self.model = model

    @staticmethod
    def _wants_translation(translation) -> bool:
        """Return True when the UI value asks for translation.

        The checkbox delivers a bool; the legacy string value "Yes" is still
        honoured so existing callers keep working.
        """
        if isinstance(translation, str):
            return translation == "Yes"
        return bool(translation)

    def auto_transcribe(self, source,
                        num_speakers: int,
                        translation: bool,
                        language: str):
        """
        Run the combined diarisation + transcription task.

        Args:
            source: A file path, or a list of file paths.
            num_speakers (int): Number of speakers; 0 means "unknown".
            translation (bool): Truthy to request the 'translate' task.
            language (str): Language name, or the string "None" for unknown.

        Returns:
            tuple: (text output, JSON output). For a list of sources the JSON
            is a mapping from file name to transcript.

        Raises:
            gr.Error: If a single source contains no detectable speech, or if
                ``source`` is neither a str nor a list.
        """
        kwargs = {
            "num_speakers": num_speakers if num_speakers != 0 else None,
            "language": language if language != "None" else None,
            "task": 'translate' if translation else None
        }

        if isinstance(source, str):
            try:
                result = self.model.autotranscribe(source, **kwargs)
            except ValueError as err:
                raise gr.Error("Couldn't detect any speech in the provided audio. "
                               "Please try again!") from err
            return str(result), result.get_json()

        if isinstance(source, list):
            # basename instead of split("/") so Windows-style paths also work.
            source_names = [os.path.basename(s) for s in source]
            result = []
            for s, name in tqdm(zip(source, source_names), total=len(source),
                                desc="Transcribing audio files"):
                try:
                    res = self.model.autotranscribe(s, **kwargs)
                except ValueError:
                    # Keep going with the remaining files; record a placeholder.
                    res = f"NO TRANSCRIPT FOUND FOR {name}"
                    gr.Warning(f"Couldn't detect any speech in {name} will skip this file.")
                result.append(res)

            out = ''
            out_dict = {}
            for i, r in enumerate(result):
                out += f"TRANSCRIPT {i} FOR ({source_names[i]}):\n\n"
                out += str(r)
                out += "\n\n"
                # Placeholder strings are stored as-is; transcripts as dicts.
                if isinstance(r, str):
                    out_dict[source_names[i]] = r
                else:
                    out_dict[source_names[i]] = r.get_dict()
            return out, json.dumps(out_dict, indent=4)

        raise gr.Error("Please provide a valid audio file.")

    def transcribe(self, source, translation, language):
        """
        Run the transcription-only task.

        Args:
            source: A file path, or a list of file paths.
            translation: Checkbox bool (or legacy "Yes") requesting translation.
            language (str): Language name, or the string "None" for unknown.

        Returns:
            str: Transcribed text (one labelled section per file for lists).

        Raises:
            gr.Error: If ``source`` is neither a str nor a list.
        """
        kwargs = {
            "language": language if language != "None" else None,
            # Fixed: the UI passes a bool, so comparing against "Yes" alone
            # meant translation was never enabled.
            "task": 'translate' if self._wants_translation(translation) else None
        }

        if isinstance(source, str):
            result = self.model.transcribe(source, **kwargs)
            return str(result)

        if isinstance(source, list):
            source_names = [os.path.basename(s) for s in source]
            result = []
            for s in tqdm(source, total=len(source), desc="Transcribing audio files"):
                result.append(self.model.transcribe(s, **kwargs))

            out = ''
            for i, res in enumerate(result):
                out += f"TRANSCRIPT {i} FOR ({source_names[i]}):\n\n"
                out += str(res)
                out += "\n\n"
            return out

        raise gr.Error("Please provide a valid audio file.")

    def perform_diarisation(self, source, num_speakers):
        """
        Run the diarisation-only task.

        Args:
            source: A file path, or a list of file paths.
            num_speakers (int): Number of speakers; 0 means "unknown".

        Returns:
            str: JSON text of the diarisation result (for lists, a mapping
            from file name to result).

        Raises:
            gr.Error: If a single source contains no detectable speech, or if
                ``source`` is neither a str nor a list.
        """
        kwargs = {
            "num_speakers": num_speakers if num_speakers != 0 else None,
        }

        if isinstance(source, str):
            try:
                result = self.model.diarization(source, **kwargs)
            except ValueError as err:
                raise gr.Error("Couldn't detect any speech in the provided audio. "
                               "Please try again!") from err
            return json.dumps(result, indent=2)

        if isinstance(source, list):
            source_names = [os.path.basename(s) for s in source]
            result = []
            for s in tqdm(source, total=len(source), desc="Performing diarisation"):
                try:
                    res = self.model.diarization(s, **kwargs)
                except ValueError:
                    res = f"NO DIARISATION FOUND FOR {s}"
                    gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
                result.append(res)

            out = {}
            for i, res in enumerate(result):
                out[source_names[i]] = res
            return json.dumps(out, indent=4)

        # Fixed: the error was constructed but never raised, so invalid input
        # silently returned None.
        raise gr.Error("Please provide a valid audio file.")
####
# Gradio Interface
####
def gradio_Interface(model : Scraibe = None):
    """
    Build the Gradio Blocks application for ScrAIbe.

    Args:
        model (Scraibe, optional): Model to use. A default ``Scraibe()`` is
            created when omitted.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    if model is None:
        model = Scraibe()

    pipe = GradioTranscriptionInterface(model)

    # Visibility of (num_speakers, translate, language) for each task.
    task_visibility = {
        'Auto Transcribe': (True, True, True),
        'Transcribe': (False, True, True),
        'Diarisation': (True, False, False),
    }

    # Input types, in the same order as the output components
    # [audio1, audio2, video1, video2, file_in].
    origins = ("Upload Audio", "Record Audio", "Upload Video",
               "Record Video", "File or Files")

    def select_task(choice):
        """Show only the option widgets relevant to the chosen task."""
        flags = task_visibility.get(choice)
        if flags is None:
            return None
        return tuple(gr.update(visible = flag) for flag in flags)

    def select_origin(choice):
        """Show the chosen input widget; hide and clear the others."""
        if choice not in origins:
            return None
        return tuple(gr.update(visible = True) if name == choice
                     else gr.update(visible = False, value = None)
                     for name in origins)

    def run_scribe(task,
                   num_speakers,
                   translate,
                   language,
                   audio1,
                   audio2,
                   video1,
                   video2,
                   file_in,
                   progress = gr.Progress(track_tqdm= True)):
        """Dispatch the submitted inputs to the selected task.

        Returns updates for (out_txt, out_json, annoation, annotate).
        """
        progress(0, desc='Starting task...')

        # Only one input widget holds a value at a time; take the first set.
        source = audio1 or audio2 or video1 or video2 or file_in

        if isinstance(source, list):
            source = [s.name for s in source]
            if len(source) == 1:
                source = source[0]

        if task == 'Auto Transcribe':
            out_str , out_json = pipe.auto_transcribe(source = source,
                                                      num_speakers = num_speakers,
                                                      translation = translate,
                                                      language = language)
            # Speaker annotation is only offered for a single transcript.
            single = isinstance(source, str)
            return (gr.update(value = out_str, visible = True),
                    gr.update(value = out_json, visible = True),
                    gr.update(visible = single),
                    gr.update(visible = single))

        elif task == 'Transcribe':
            out = pipe.transcribe(source = source,
                                  translation = translate,
                                  language = language)
            return (gr.update(value = out, visible = True),
                    gr.update(value = None, visible = False),
                    gr.update(visible = False),
                    gr.update(visible = False))

        elif task == 'Diarisation':
            out = pipe.perform_diarisation(source = source,
                                           num_speakers = num_speakers)
            return (gr.update(value = None, visible = False),
                    gr.update(value = out, visible = True),
                    gr.update(visible = False),
                    gr.update(visible = False))

    def annotate_output(annoation : str, out_json : dict):
        """Rename speakers in the transcript using a comma-separated list."""
        trans = Transcript.from_json(out_json)
        trans = trans.annotate(*annoation.split(","))
        return gr.update(value = str(trans)),gr.update(value = trans.get_json())

    with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:

        # Define components
        hname = os.path.join(CURRENT_PATH, "header.html")
        # Fixed: read through a context manager so the file handle is closed.
        with open(hname, "r") as header_file:
            header = header_file.read()

        gr.HTML(header, visible= True, show_label=False)

        with gr.Row():

            with gr.Column():

                task = gr.Radio(["Auto Transcribe", "Transcribe", "Diarisation"], label="Task",
                                value= 'Auto Transcribe')

                # Fixed: the help texts used backslash line continuations
                # inside the literals, which garbled the displayed text at
                # each line break; they are now joined with a single space.
                num_speakers = gr.Number(value=0, label= "Number of speakers (optional)",
                                         info = "Number of speakers in the audio file. If you don't know, "
                                                "leave it at 0.", visible= True)
                translate = gr.Checkbox(label="Translation", choices=[True, False], value = False,
                                        info="Select 'Yes' to have the output translated into English.",
                                        visible= True)
                language = gr.Dropdown(LANGUAGES,
                                       label="Language (optional)", value = "None",
                                       info="Language of the audio file. If you don't know, "
                                            "leave it at None.", visible= True)

                # Renamed from `input` to avoid shadowing the builtin.
                input_type = gr.Radio(["Upload Audio", "Record Audio", "Upload Video","Record Video"
                                       ,"File or Files"], label="Input Type", value="Upload Audio")

                audio1 = gr.Audio(source="upload", type="filepath", label="Upload Audio",
                                  interactive= True, visible= True)
                audio2 = gr.Audio(source="microphone", label="Record Audio", type="filepath",
                                  interactive= True, visible= False)
                video1 = gr.Video(source="upload", type="filepath", label="Upload Video",
                                  interactive= True, visible= False)
                video2 = gr.Video(source="webcam", label="Record Video", type="filepath",
                                  interactive= True, visible= False)
                file_in = gr.Files(label="Upload File or Files", interactive= True, visible= False)

                submit = gr.Button()

            with gr.Column():

                out_txt = gr.Textbox(label="Output",
                                     visible= True, show_copy_button=True)
                out_json = gr.JSON(label="JSON Output",
                                   visible= False, show_copy_button=True)

                annoation = gr.Textbox(label="Name your speaker's",
                                       info= "Please provide a list of the speakers arranged "
                                             "in the order in which they appear in the input. Use comma ',' "
                                             "as a seperator. Be aware that the first name is given "
                                             "to SPEAKER_00 the second to SPEAKER_01 and so on.",
                                       visible= False, interactive= True)

                annotate = gr.Button(value="Annotate", visible= False, interactive= True)

        # Define usage of components
        input_type.change(fn=select_origin, inputs=[input_type],
                          outputs=[audio1, audio2, video1, video2, file_in])

        task.change(fn=select_task, inputs=[task],
                    outputs=[num_speakers, translate, language])

        translate.change(fn= lambda x : gr.update(value = x),
                         inputs=[translate], outputs=[translate])
        num_speakers.change(fn= lambda x : gr.update(value = x),
                            inputs=[num_speakers], outputs=[num_speakers])
        language.change(fn= lambda x : gr.update(value = x),
                        inputs=[language], outputs=[language])

        submit.click(fn = run_scribe,
                     inputs=[task, num_speakers, translate, language, audio1,
                             audio2, video1, video2, file_in],
                     outputs=[out_txt, out_json, annoation, annotate])

        annotate.click(fn = annotate_output, inputs=[annoation, out_json],
                       outputs=[out_txt, out_json])

    return demo
# Script entry point: build the interface, enable queuing, and launch it.
if __name__ == "__main__":
    gradio_Interface().queue().launch()
-66
View File
@@ -1,66 +0,0 @@
<!-- Importing Cormorant Garamond font from Google Fonts -->
<link href="https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@400;700&display=swap" rel="stylesheet">
<style>
.header-container {
display: flex;
align-items: center;
justify-content: center;
position: relative;
padding-top: 30px;
}
.logo-container {
position: absolute;
top: 50%;
right: 20px;
transform: translateY(-50%);
width: 300px;
}
.logo {
width: 100%;
height: auto;
}
h1 {
font-family: 'Cormorant Garamond', serif;
font-size: 50px !important; /* Increased font size */
font-weight: bold;
color: #50AF31;
margin: 0;
position: relative;
padding: 0.5em 0;
}
h1::before, h1::after {
content: "";
position: absolute;
height: 2px;
width: 80%;
background-color: #50AF31;
left: 10%;
}
h1::before {
top: 0.5em;
}
h1::after {
bottom: 0.5em;
}
p, h2 {
font-size: 16px;
margin: 10px 0;
line-height: 1.4;
}
</style>
<div class="header-container">
<h1>ScrAIbe</h1>
<div class="logo-container">
<a href="https://www.kida-bmel.de/"> <!-- Link to the KIDA project website -->
<img src="file/Logo_KIDA_bmel_green.svg" alt="KIDA Logo" class="logo">
</a>
</div>
</div>
<div style="text-align: center; padding: 20px 10%;">
<p>
Upload, record, or provide a video with audio for transcription. Our toolkit is designed to transcribe content from multiple languages accurately. The integrated speaker diarisation feature identifies different speakers, ensuring a smooth transcription experience. For optimal results, indicate the number of speakers and the original language of the content.
</p>
<h2 style="font-weight: bold; color: #50AF31;">What would you like to do next?</h2>
</div>
-319
View File
@@ -1,319 +0,0 @@
"""
This file contains a modified version of qtfaststart by qtfaststart
https://github.com/danielgtaylor/qtfaststart/tree/master
All credit goes to the original author.
Copyright (C) 2008 - 2013 Daniel G. Taylor <dan@programmer-art.org>
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
"""
import logging
import os
import struct
import collections
import io
# define error classes
class FastStartException(Exception):
    """Base error for any failure while processing a file."""
class FastStartSetupError(FastStartException):
    """Raised when asked to process a file that does not need processing."""
class MalformedFileError(FastStartException):
    """Raised when the input file is laid out in an unexpected way."""
class UnsupportedFormatError(FastStartException):
    """Raised when a movie file is recognized as an unsupported format."""
# define constants
# NOTE(review): presumably the copy buffer size in bytes; its use is not
# visible in this part of the file -- confirm.
CHUNK_SIZE = 8192
# Module-level logger shared by all functions in this file.
log = logging.getLogger("qtfaststart")
# Older versions of Python require this to be defined
if not hasattr(os, 'SEEK_CUR'):
    os.SEEK_CUR = 1
# Lightweight record for one atom: its fourcc name, absolute byte position
# in the stream, and size in bytes.
Atom = collections.namedtuple('Atom', 'name position size')
def read_atom(datastream):
    """
    Read one 8-byte atom header from ``datastream``.

    Returns:
        tuple: ``(size, fourcc)`` where ``size`` is the big-endian 32-bit
        size field and ``fourcc`` is the 4-character type code decoded as
        ASCII (e.g. "ftyp" or "moov").

    Raises:
        struct.error: If fewer than 8 bytes could be read.
        UnicodeDecodeError: If the type code is not ASCII.
    """
    header = datastream.read(8)
    size, fourcc = struct.unpack(">L4s", header)
    return size, fourcc.decode('ascii')
def _read_atom_ex(datastream):
    """
    Read the atom header at the current stream position as an ``Atom``.

    The returned position is where the header starts. When the 32-bit size
    field equals 1, the following 8 bytes are read as a big-endian unsigned
    64-bit value and used as the size instead.
    """
    start = datastream.tell()
    size, name = read_atom(datastream)
    if size == 1:
        (size,) = struct.unpack(">Q", datastream.read(8))
    return Atom(name, start, size)
def get_index(datastream):
    """
    Build the list of top-level atoms found in ``datastream``.

    Each entry is an ``Atom(name, position, size)``, in file order, e.g.::

        [
            Atom("ftyp", 0, 24),
            Atom("moov", 24, 2658),
            ...
        ]

    Raises:
        FastStartException: If the index lacks a "moov" or "mdat" atom.
    """
    log.debug("Getting index of top level atoms...")
    atoms = [atom for atom in _read_atoms(datastream)]
    _ensure_valid_index(atoms)
    return atoms
def _read_atoms(datastream):
    """
    Yield consecutive ``Atom`` records from ``datastream`` until a read fails.

    Reaching the end of the stream surfaces as a failed header read, so any
    ordinary exception simply ends the iteration.
    """
    while datastream:
        try:
            atom = _read_atom_ex(datastream)
            log.debug("%s: %s", atom.name, atom.size)
        except Exception:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit. Ordinary read/parse errors still end the scan.
            break

        yield atom

        if atom.size == 0:
            if atom.name == "mdat":
                # Some files may end in mdat with no size set, which generally
                # means to seek to the end of the file. We can just stop indexing
                # as no more entries will be found!
                break
            # Weird, but just continue to try to find more atoms
            continue

        # Jump to the start of the next sibling atom.
        datastream.seek(atom.position + atom.size)
def _ensure_valid_index(index):
    """
    Check that the index contains the atoms this module cannot work
    without ("moov" and "mdat").

    Logs an error and raises FastStartException for the first one that is
    missing.
    """
    present = {entry.name for entry in index}
    for required in ("moov", "mdat"):
        if required in present:
            continue
        log.error("%s atom not found, is this a valid MOV/MP4 file?" % required)
        raise FastStartException()
def find_atoms(size, datastream):
    """
    Compatibility interface for _find_atoms_ex.

    Wraps the region of ``datastream`` in a synthetic parent atom that
    starts 8 bytes before the current position and is ``size + 8`` bytes
    long, then yields the *name* of every atom _find_atoms_ex reports
    inside it.
    """
    parent_start = datastream.tell() - 8
    wrapper = Atom('fake', parent_start, size + 8)
    for found in _find_atoms_ex(wrapper, datastream):
        yield found.name
def _find_atoms_ex(parent_atom, datastream):
    """
    Yield either "stco" or "co64" Atoms from datastream.

    datastream will be 8 bytes into the stco or co64 atom when the value
    is yielded. It is assumed that datastream will be at the end of the
    atom after the value has been yielded and processed.

    parent_atom is the parent atom, a 'moov' or other ancestor of CO
    atoms in the datastream.

    Raises FastStartException when an atom header cannot be read, or when
    an atom that has to be skipped reports a size of zero (seeking past
    it would not move the stream, so scanning could never finish).
    """
    stop = parent_atom.position + parent_atom.size

    while datastream.tell() < stop:
        try:
            atom = _read_atom_ex(datastream)
        except Exception:
            # Only ordinary errors are converted; KeyboardInterrupt and
            # SystemExit must not be disguised as a parse failure.
            log.exception("Error reading next atom!")
            raise FastStartException()

        if atom.name in ["trak", "mdia", "minf", "stbl"]:
            # Known ancestor atom of stco or co64, search within it!
            for res in _find_atoms_ex(atom, datastream):
                yield res
        elif atom.name in ["stco", "co64"]:
            yield atom
        else:
            if atom.size == 0:
                # position + 0 would rewind to this same header and the
                # loop would read it again forever.
                log.error("Zero-sized %s atom at %d, cannot skip past it"
                          % (atom.name, atom.position))
                raise FastStartException()
            # Ignore this atom, seek to the end of it.
            datastream.seek(atom.position + atom.size)
def process(infilename, limit=float('inf')):
    """
    Convert a Quicktime/MP4 file for streaming by moving the metadata to
    the front of the file, and return the converted file as ``bytes``.

    Args:
        infilename: either a path (``str``) to the input file or the raw
            file content as ``bytes``.
        limit: if set to something other than zero, the maximum number of
            bytes written for each atom that follows the moov atom. This
            is very useful to create a small sample of a file with full
            headers, which can then be used in bug reports and such.
            Zero means no limit.

    Returns:
        The rewritten file content. When the file is already set up for
        streaming (moov before mdat and no free atoms to drop), the
        unmodified input content is returned instead.

    Raises:
        TypeError: if ``infilename`` is neither ``str`` nor ``bytes``.
        FastStartException: if the file lacks a moov or mdat atom, or the
            moov atom cannot be parsed.
    """
    if isinstance(infilename, str):
        datastream = open(infilename, "rb")
    elif isinstance(infilename, bytes):
        datastream = io.BytesIO(infilename)
    else:
        raise TypeError("infilename must be a filename (str) or bytes")

    # Everything below reads from datastream; close it on every exit path
    # so a file opened from a path is never leaked.
    try:
        # Get the top level atom index
        index = get_index(datastream)

        # No mdat seen yet: every position compares as "before mdat".
        mdat_pos = float('inf')
        free_size = 0

        # Make sure moov occurs AFTER mdat, otherwise no need to run!
        for atom in index:
            # The atoms are guaranteed to exist from get_index above!
            if atom.name == "moov":
                moov_atom = atom
                moov_pos = atom.position
            elif atom.name == "mdat":
                mdat_pos = atom.position
            elif atom.name == "free" and atom.position < mdat_pos:
                # This free atom is before the mdat!
                free_size += atom.size
                log.info("Removing free atom at %d (%d bytes)" % (atom.position, atom.size))
            elif atom.name == "\x00\x00\x00\x00" and atom.position < mdat_pos:
                # This is some strange zero atom with incorrect size
                free_size += 8
                log.info("Removing strange zero atom at %s (8 bytes)" % atom.position)

        # Offset to shift positions
        offset = moov_atom.size - free_size

        if moov_pos < mdat_pos:
            # moov appears to be in the proper place, don't shift by moov size
            offset -= moov_atom.size
            if not free_size:
                # No free atoms and moov is correct, we are done!
                log.error("This file appears to already be setup for streaming!")
                # Hand back the unprocessed content from the stream that
                # is already open.
                datastream.seek(0)
                return datastream.read()

        # Read and fix moov
        moov = _patch_moov(datastream, moov_atom, offset)

        log.info("Writing output...")
        # Collect the pieces and join once; repeated bytes concatenation
        # copies the whole output for every chunk.
        parts = []

        # Write ftyp
        for atom in index:
            if atom.name == "ftyp":
                log.debug("Writing ftyp... (%d bytes)" % atom.size)
                datastream.seek(atom.position)
                parts.append(datastream.read(atom.size))

        # Write moov
        _bytes = moov.getvalue()
        log.debug("Writing moov... (%d bytes)" % len(_bytes))
        parts.append(_bytes)

        # Write the rest
        for atom in index:
            if atom.name in ("ftyp", "moov", "free"):
                continue
            log.debug("Writing %s... (%d bytes)" % (atom.name, atom.size))
            datastream.seek(atom.position)

            # for compatibility, allow '0' to mean no limit
            cur_limit = limit or float('inf')
            if not (atom.name == "mdat" and atom.size == 0):
                cur_limit = min(cur_limit, atom.size)
            # else: an mdat with no size set runs to the end of the file,
            # so only the caller's limit applies; capping at its size of 0
            # would drop all of its data.

            parts.extend(get_chunks(datastream, CHUNK_SIZE, cur_limit))

        return b"".join(parts)
    finally:
        datastream.close()
def _patch_moov(datastream, atom, offset):
    """
    Load the moov atom described by ``atom`` into memory and add
    ``offset`` to every entry of each stco/co64 table found inside it.

    Returns the patched moov as an ``io.BytesIO``.
    """
    datastream.seek(atom.position)
    moov = io.BytesIO(datastream.read(atom.size))

    # Re-read the header from the in-memory copy so that positions are
    # relative to the start of ``moov``.
    root = _read_atom_ex(moov)

    # struct type code and byte width of one entry, per table kind
    # (32-bit for stco, 64-bit for co64).
    layouts = {"stco": ('L', 4), "co64": ('Q', 8)}

    for table in _find_atoms_ex(root, moov):
        ctype, csize = layouts[table.name]

        # Two 32-bit words precede the entries; the second one is the
        # number of entries, the first one is read but not used.
        _unused, entry_count = struct.unpack(">2L", moov.read(8))
        log.info("Patching %s with %d entries" % (table.name, entry_count))

        entries_pos = moov.tell()
        struct_fmt = ">%s%s" % (entry_count, ctype)

        # Read entries
        raw = moov.read(csize * entry_count)
        entries = struct.unpack(struct_fmt, raw)

        # Patch the entries and write them back over the originals
        shifted = [entry + offset for entry in entries]
        moov.seek(entries_pos)
        moov.write(struct.pack(struct_fmt, *shifted))

    return moov
def get_chunks(stream, chunk_size, limit):
    """
    Yield data read from ``stream`` in pieces of at most ``chunk_size``
    bytes, stopping once ``limit`` bytes have been produced or the stream
    returns no more data, whichever comes first.
    """
    left = limit
    while left:
        wanted = min(left, chunk_size)
        data = stream.read(wanted)
        if not data:
            # Stream exhausted before the limit was reached.
            break
        left -= len(data)
        yield data
+23 -21
View File
@@ -28,6 +28,7 @@ import torch
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768.0 NORMALIZATION_FACTOR = 32768.0
class AudioProcessor: class AudioProcessor:
""" """
Audio Processor class that leverages PyTorchaudio to provide functionalities Audio Processor class that leverages PyTorchaudio to provide functionalities
@@ -39,10 +40,9 @@ class AudioProcessor:
sr: int sr: int
The sample rate of the audio. The sample rate of the audio.
""" """
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
*args, **kwargs) -> None: *args, **kwargs) -> None:
""" """
Initialize the AudioProcessor object. Initialize the AudioProcessor object.
@@ -56,16 +56,17 @@ class AudioProcessor:
Raises: Raises:
ValueError: If the provided sample rate is not of type int. ValueError: If the provided sample rate is not of type int.
""" """
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") device = kwargs.get(
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device) self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
if not isinstance(self.sr, int): if not isinstance(self.sr, int):
raise ValueError("Sample rate should be a single value of type int," \ raise ValueError("Sample rate should be a single value of type int,"
f"not {len(self.sr)} and type {type(self.sr)}") f"not {len(self.sr)} and type {type(self.sr)}")
@classmethod @classmethod
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor': def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
""" """
@@ -77,14 +78,13 @@ class AudioProcessor:
Returns: Returns:
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio. AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
""" """
audio, sr = cls.load_audio(file , *args, **kwargs) audio, sr = cls.load_audio(file, *args, **kwargs)
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio)
return cls(audio, sr) return cls(audio, sr)
def cut(self, start: float, end: float) -> torch.Tensor: def cut(self, start: float, end: float) -> torch.Tensor:
""" """
Cut a segment from the audio waveform between the specified start and end times. Cut a segment from the audio waveform between the specified start and end times.
@@ -96,7 +96,7 @@ class AudioProcessor:
Returns: Returns:
torch.Tensor: The cut waveform segment. torch.Tensor: The cut waveform segment.
""" """
start = int(start * self.sr) start = int(start * self.sr)
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int): if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
end = int(np.ceil(end * self.sr)) end = int(np.ceil(end * self.sr))
@@ -140,11 +140,13 @@ class AudioProcessor:
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e: except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e raise RuntimeError(
f"Failed to load audio: {e.stderr.decode()}") from e
out = np.frombuffer(out, np.int16).flatten().astype(
np.float32) / NORMALIZATION_FACTOR
return out, sr
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
return out , sr
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
+146 -79
View File
@@ -38,7 +38,7 @@ from tqdm import trange
# Application-Specific Imports # Application-Specific Imports
from .audio import AudioProcessor from .audio import AudioProcessor
from .diarisation import Diariser from .diarisation import Diariser
from .transcriber import Transcriber, whisper from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript from .transcript_exporter import Transcript
@@ -55,33 +55,43 @@ class Scraibe:
Attributes: Attributes:
transcriber (Transcriber): The transcriber object to handle transcription. transcriber (Transcriber): The transcriber object to handle transcription.
diariser (Diariser): The diariser object to handle diarization. diariser (Diariser): The diariser object to handle diarization.
Methods: Methods:
__init__: Initializes the Scraibe class with appropriate models. __init__: Initializes the Scraibe class with appropriate models.
transcribe: Transcribes an audio file using the whisper model and pyannote diarization model. transcribe: Transcribes an audio file using the whisper model and pyannote diarization model.
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
get_audio_file: Gets an audio file as an AudioProcessor object. get_audio_file: Gets an audio file as an AudioProcessor object.
""" """
def __init__(self, def __init__(self,
whisper_model: Union[bool, str, whisper] = None, whisper_model: Union[bool, str, whisper] = None,
dia_model : Union[bool, str, DiarisationType] = None, whisper_type: str = "whisper",
**kwargs) -> None: dia_model: Union[bool, str, DiarisationType] = None,
**kwargs) -> None:
"""Initializes the Scraibe class. """Initializes the Scraibe class.
Args: Args:
whisper_model (Union[bool, str, whisper], optional): whisper_model (Union[bool, str, whisper], optional):
Path to whisper model or whisper model itself. Path to whisper model or whisper model itself.
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
diarisation_model (Union[bool, str, DiarisationType], optional): diarisation_model (Union[bool, str, DiarisationType], optional):
Path to pyannote diarization model or model itself. Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper **kwargs: Additional keyword arguments for whisper
and pyannote diarization models. and pyannote diarization models.
e.g.:
- verbose: If True, the class will print additional information.
- save_kwargs: If True, the keyword arguments will be saved
for autotranscribe. So you can unload the class and reload it again.
""" """
if whisper_model is None: if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", **kwargs) self.transcriber = load_transcriber(
"medium", whisper_type, **kwargs)
elif isinstance(whisper_model, str): elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = load_transcriber(
whisper_model, whisper_type, **kwargs)
else: else:
self.transcriber = whisper_model self.transcriber = whisper_model
@@ -90,17 +100,25 @@ class Scraibe:
elif isinstance(dia_model, str): elif isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs) self.diariser = Diariser.load_model(dia_model, **kwargs)
else: else:
self.diariser = dia_model self.diariser: Diariser = dia_model
if kwargs.get("verbose"): if kwargs.get("verbose"):
print("Scraibe initialized all models successfully loaded.") print("Scraibe initialized all models successfully loaded.")
self.verbose = True self.verbose = True
else: else:
self.verbose = False self.verbose = False
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], # Save kwargs for autotranscribe if you want to unload the class and load it again.
remove_original : bool = False, if kwargs.get('save_setup'):
**kwargs) -> Transcript: self.params = dict(whisper_model=whisper_model,
dia_model=dia_model,
**kwargs)
else:
self.params = {}
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False,
**kwargs) -> Transcript:
""" """
Transcribes an audio file using the whisper model and pyannote diarization model. Transcribes an audio file using the whisper model and pyannote diarization model.
@@ -119,60 +137,62 @@ class Scraibe:
if kwargs.get("verbose"): if kwargs.get("verbose"):
self.verbose = kwargs.get("verbose") self.verbose = kwargs.get("verbose")
# Get audio file as an AudioProcessor object # Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file) audio_file: AudioProcessor = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
if self.verbose: if self.verbose:
print("Starting diarisation.") print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs) diarisation = self.diariser.diarization(dia_audio, **kwargs)
if not diarisation["segments"]: if not diarisation["segments"]:
print("No segments found. Try to run transcription without diarisation.") print("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) transcript = self.transcriber.transcribe(
audio_file.waveform, **kwargs)
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
"segments" : [0, len(audio_file.waveform)], final_transcript = {0: {"speakers": 'SPEAKER_01',
"text" : transcript}} "segments": [0, len(audio_file.waveform)],
"text": transcript}}
return Transcript(final_transcript) return Transcript(final_transcript)
if self.verbose: if self.verbose:
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results # Transcribe each segment and store the results
final_transcript = dict() final_transcript = dict()
for i in trange(len(diarisation["segments"]), desc= "Transcribing", disable = not self.verbose): for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose):
seg = diarisation["segments"][i] seg = diarisation["segments"][i]
audio = audio_file.cut(seg[0], seg[1]) audio = audio_file.cut(seg[0], seg[1])
transcript = self.transcriber.transcribe(audio, **kwargs) transcript = self.transcriber.transcribe(audio, **kwargs)
final_transcript[i] = {"speakers" : diarisation["speakers"][i], final_transcript[i] = {"speakers": diarisation["speakers"][i],
"segments" : seg, "segments": seg,
"text" : transcript} "text": transcript}
# Remove original file if needed # Remove original file if needed
if remove_original: if remove_original:
if kwargs.get("shred") is True: if kwargs.get("shred") is True:
self.remove_audio_file(audio_file, shred=True) self.remove_audio_file(audio_file, shred=True)
else: else:
self.remove_audio_file(audio_file, shred=False) self.remove_audio_file(audio_file, shred=False)
return Transcript(final_transcript) return Transcript(final_transcript)
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray], def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
**kwargs) -> dict: **kwargs) -> dict:
""" """
Perform diarization on an audio file using the pyannote diarization model. Perform diarization on an audio file using the pyannote diarization model.
@@ -187,24 +207,24 @@ class Scraibe:
dict: dict:
A dictionary containing the results of the diarization process. A dictionary containing the results of the diarization process.
""" """
# Get audio file as an AudioProcessor object # Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file) audio_file: AudioProcessor = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
print("Starting diarisation.") print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs) diarisation = self.diariser.diarization(dia_audio, **kwargs)
return diarisation return diarisation
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
**kwargs): **kwargs):
""" """
Transcribe the provided audio file. Transcribe the provided audio file.
@@ -218,12 +238,60 @@ class Scraibe:
str: str:
The transcribed text from the audio source. The transcribed text from the audio source.
""" """
audio_file = self.get_audio_file(audio_file) audio_file: AudioProcessor = self.get_audio_file(audio_file)
return self.transcriber.transcribe(audio_file.waveform, **kwargs) return self.transcriber.transcribe(audio_file.waveform, **kwargs)
def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None:
"""
Update the transcriber model.
Args:
whisper_model (Union[str, whisper]):
The new whisper model to use for transcription.
**kwargs:
Additional keyword arguments for the transcriber model.
Returns:
None
"""
_old_model = self.transcriber.model_name
if isinstance(whisper_model, str):
self.transcriber = load_transcriber(whisper_model, **kwargs)
elif isinstance(whisper_model, Transcriber):
self.transcriber = whisper_model
else:
warn(
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
return None
def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None:
"""
Update the diariser model.
Args:
dia_model (Union[str, DiarisationType]):
The new diariser model to use for diarization.
**kwargs:
Additional keyword arguments for the diariser model.
Returns:
None
"""
if isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs)
elif isinstance(dia_model, Diariser):
self.diariser = dia_model
else:
warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
return None
@staticmethod @staticmethod
def remove_audio_file(audio_file : str, def remove_audio_file(audio_file: str,
shred : bool = False) -> None: shred: bool = False) -> None:
""" """
Removes the original audio file to avoid disk space issues or ensure data privacy. Removes the original audio file to avoid disk space issues or ensure data privacy.
@@ -234,31 +302,29 @@ class Scraibe:
""" """
if not os.path.exists(audio_file): if not os.path.exists(audio_file):
raise ValueError(f"Audiofile {audio_file} does not exist.") raise ValueError(f"Audiofile {audio_file} does not exist.")
if shred: if shred:
warn("Shredding audiofile can take a long time.", RuntimeWarning) warn("Shredding audiofile can take a long time.", RuntimeWarning)
gen = iglob(f'{audio_file}', recursive=True) gen = iglob(f'{audio_file}', recursive=True)
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}'] cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
if os.path.isdir(audio_file): if os.path.isdir(audio_file):
raise ValueError(f"Audiofile {audio_file} is a directory.") raise ValueError(f"Audiofile {audio_file} is a directory.")
for file in gen: for file in gen:
print(f'shredding {file} now\n') print(f'shredding {file} now\n')
run(cmd , check=True) run(cmd, check=True)
else: else:
os.remove(audio_file) os.remove(audio_file)
print(f"Audiofile {audio_file} removed.") print(f"Audiofile {audio_file} removed.")
@staticmethod @staticmethod
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray], def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor: *args, **kwargs) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor. """Gets an audio file as TorchAudioProcessor.
Args: Args:
@@ -271,19 +337,20 @@ class Scraibe:
AudioProcessor: An object containing the waveform and sample rate in AudioProcessor: An object containing the waveform and sample rate in
torch.Tensor format. torch.Tensor format.
""" """
if isinstance(audio_file, str): if isinstance(audio_file, str):
audio_file = AudioProcessor.from_file(audio_file) audio_file = AudioProcessor.from_file(audio_file)
elif isinstance(audio_file, torch.Tensor): elif isinstance(audio_file, torch.Tensor):
audio_file = AudioProcessor(audio_file[0], audio_file[1]) audio_file = AudioProcessor(audio_file[0], audio_file[1])
elif isinstance(audio_file, ndarray): elif isinstance(audio_file, ndarray):
audio_file = AudioProcessor(torch.Tensor(audio_file[0]), audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
audio_file[1]) audio_file[1])
if not isinstance(audio_file, AudioProcessor): if not isinstance(audio_file, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \ raise ValueError(f'Audiofile must be of type AudioProcessor,'
f'not {type(audio_file)}') f'not {type(audio_file)}')
return audio_file return audio_file
def __repr__(self): def __repr__(self):
+51 -64
View File
@@ -4,17 +4,13 @@ allowing for user interaction to transcribe and diarize audio files.
The function includes arguments for specifying the audio files, model paths, The function includes arguments for specifying the audio files, model paths,
output formats, and other options necessary for transcription. output formats, and other options necessary for transcription.
""" """
import os import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import json import json
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from .autotranscript import Scraibe from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from .app.gradio_app import gradio_Interface
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
from torch.cuda import is_available from torch.cuda import is_available
from torch import set_num_threads from torch import set_num_threads
from .autotranscript import Scraibe
def cli(): def cli():
""" """
@@ -25,40 +21,34 @@ def cli():
This function can be executed from the command line to perform transcription tasks, providing a This function can be executed from the command line to perform transcription tasks, providing a
user-friendly way to access the Scraibe class functionalities. user-friendly way to access the Scraibe class functionalities.
""" """
def str2bool(string): def str2bool(string):
str2val = {"True": True, "False": False} str2val = {"True": True, "False": False}
if string in str2val: if string in str2val:
return str2val[string] return str2val[string]
else: else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") raise ValueError(
f"Expected one of {set(str2val.keys())}, got {string}")
parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter) parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
group = parser.add_mutually_exclusive_group() parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None,
parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None,
help="List of audio files to transcribe.") help="List of audio files to transcribe.")
parser.add_argument("--whisper-type", type=str, default="whisper",
choices=["whisper", "whisperx"],
help="Type of Whisper model to use ('whisper' or 'whisperx').")
group.add_argument('--start-server', action='store_true',
help='Start the Gradio app.')
parser.add_argument("--port", type=int, default= None,
help="Port to run the Gradio app on. Defaults to 7860.")
parser.add_argument("--server-name", type=str, default= None,
help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.")
parser.add_argument("--whisper-model-name", default="medium", parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.") help="Name of the Whisper model to use.")
parser.add_argument("--whisper-model-directory", type=str, default= None, parser.add_argument("--whisper-model-directory", type=str, default=None,
help="Path to save Whisper model files; defaults to ./models/whisper.") help="Path to save Whisper model files; defaults to ./models/whisper.")
parser.add_argument("--diarization-directory", type=str, default= None, parser.add_argument("--diarization-directory", type=str, default=None,
help="Path to the diarization model directory.") help="Path to the diarization model directory.")
parser.add_argument("--hf-token", default= None, type=str, parser.add_argument("--hf-token", default=None, type=str,
help="HuggingFace token for private model download.") help="HuggingFace token for private model download.")
parser.add_argument("--inference-device", parser.add_argument("--inference-device",
@@ -66,7 +56,8 @@ def cli():
help="Device to use for PyTorch inference.") help="Device to use for PyTorch inference.")
parser.add_argument("--num-threads", type=int, default=0, parser.add_argument("--num-threads", type=int, default=0,
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") help="Number of threads used by torch for CPU inference; '\
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
parser.add_argument("--output-directory", "-o", type=str, default=".", parser.add_argument("--output-directory", "-o", type=str, default=".",
help="Directory to save the transcription outputs.") help="Directory to save the transcription outputs.")
@@ -78,90 +69,86 @@ def cli():
parser.add_argument("--verbose-output", type=str2bool, default=True, parser.add_argument("--verbose-output", type=str2bool, default=True,
help="Enable or disable progress and debug messages.") help="Enable or disable progress and debug messages.")
parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code parser.add_argument("--task", type=str, default='autotranscribe',
choices=["autotranscribe", "diarization", choices=["autotranscribe", "diarization",
"autotranscribe+translate", "translate", 'transcribe'], "autotranscribe+translate", "translate", 'transcribe'],
help="Choose to perform transcription, diarization, or translation. \ help="Choose to perform transcription, diarization, or translation. \
If set to translate, the output will be translated to English.") If set to translate, the output will be translated to English.")
parser.add_argument("--language", type=str, default=None, parser.add_argument("--language", type=str, default=None,
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), choices=sorted(
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.") help="Language spoken in the audio. Specify None to perform language detection.")
args = parser.parse_args() args = parser.parse_args()
arg_dict = vars(args) arg_dict = vars(args)
# configure output # configure output
out_folder = arg_dict.pop("output_directory") out_folder = arg_dict.pop("output_directory")
os.makedirs(out_folder, exist_ok=True) os.makedirs(out_folder, exist_ok=True)
out_format = arg_dict.pop("output_format") out_format = arg_dict.pop("output_format")
# seup server arg:
start_server = arg_dict.pop("start_server")
task = arg_dict.pop("task") task = arg_dict.pop("task")
if args.num_threads > 0: if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads")) set_num_threads(arg_dict.pop("num_threads"))
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"), class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
'whisper_type':arg_dict.pop("whisper_type"),
'dia_model': arg_dict.pop("diarization_directory"), 'dia_model': arg_dict.pop("diarization_directory"),
'use_auth_token' : arg_dict.pop("hf_token")} 'use_auth_token': arg_dict.pop("hf_token"),
}
if arg_dict["whisper_model_directory"]: if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
model = Scraibe(**class_kwargs) model = Scraibe(**class_kwargs)
if arg_dict["audio_files"]: if arg_dict["audio_files"]:
audio_files = arg_dict.pop("audio_files") audio_files = arg_dict.pop("audio_files")
if task == "autotranscribe" or task == "autotranscribe+translate": if task == "autotranscribe" or task == "autotranscribe+translate":
for audio in audio_files: for audio in audio_files:
if task == "autotranscribe+translate": if task == "autotranscribe+translate":
task = "translate" task = "translate"
else: else:
task = "transcribe" task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
"language"), verbose=arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}') print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}")) out.save(os.path.join(
out_folder, f"{basename}.{out_format}"))
elif task == "diarization": elif task == "diarization":
for audio in audio_files: for audio in audio_files:
if arg_dict.pop("verbose_output"): if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.") print("Verbose not implemented for diarization.")
out = model.diarization(audio) out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}") path = os.path.join(out_folder, f"{basename}.{out_format}")
print(f'Saving {basename}.{out_format} to {out_folder}') print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f: with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f) json.dump(json.dumps(out, indent=1), f)
elif task == "transcribe" or task == "translate": elif task == "transcribe" or task == "translate":
for audio in audio_files: for audio in audio_files:
out = model.transcribe(audio, task = task, out = model.transcribe(audio, task=task,
language= arg_dict.pop("language"), language=arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output")) verbose=arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}") path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f: with open(path, "w") as f:
f.write(out) f.write(out)
if start_server: # unfinished code
gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()
+128 -61
View File
@@ -27,19 +27,26 @@ Usage:
diarisation_output = model.diarization("path/to/audiofile.wav") diarisation_output = model.diarization("path/to/audiofile.wav")
""" """
import warnings
import os import os
import yaml
from pathlib import Path from pathlib import Path
from typing import TypeVar, Union from typing import TypeVar, Union
from pyannote.audio import Pipeline from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
os.path.realpath(__file__)), '.pyannotetoken') os.path.realpath(__file__)), '.pyannotetoken')
class Diariser: class Diariser:
""" """
@@ -49,12 +56,12 @@ class Diariser:
Args: Args:
model: The pretrained model to use for diarization. model: The pretrained model to use for diarization.
""" """
def __init__(self, model) -> None: def __init__(self, model) -> None:
self.model = model self.model = model
def diarization(self, audiofile : Union[str, Tensor, dict] , def diarization(self, audiofile: Union[str, Tensor, dict],
*args, **kwargs) -> Annotation: *args, **kwargs) -> Annotation:
""" """
Perform speaker diarization on the provided audio file, Perform speaker diarization on the provided audio file,
@@ -73,15 +80,15 @@ class Diariser:
to the diarization process. to the diarization process.
""" """
kwargs = self._get_diarisation_kwargs(**kwargs) kwargs = self._get_diarisation_kwargs(**kwargs)
diarization = self.model(audiofile,*args, **kwargs) diarization = self.model(audiofile, *args, **kwargs)
out = self.format_diarization_output(diarization) out = self.format_diarization_output(diarization)
return out return out
@staticmethod @staticmethod
def format_diarization_output(dia : Annotation) -> dict: def format_diarization_output(dia: Annotation) -> dict:
""" """
Formats the raw diarization output into a more usable structure for this project. Formats the raw diarization output into a more usable structure for this project.
@@ -93,14 +100,14 @@ class Diariser:
as keys and a list of tuples representing segments as values. as keys and a list of tuples representing segments as values.
""" """
dia_list = list(dia.itertracks(yield_label=True)) dia_list = list(dia.itertracks(yield_label=True))
diarization_output = {"speakers": [], "segments": []} diarization_output = {"speakers": [], "segments": []}
normalized_output = [] normalized_output = []
index_start_speaker = 0 index_start_speaker = 0
index_end_speaker = 0 index_end_speaker = 0
current_speaker = str() current_speaker = str()
### ###
# Sometimes two consecutive speakers are the same # Sometimes two consecutive speakers are the same
# This loop removes these duplicates # This loop removes these duplicates
@@ -109,40 +116,39 @@ class Diariser:
if len(dia_list) == 1: if len(dia_list) == 1:
normalized_output.append([0, 0, dia_list[0][2]]) normalized_output.append([0, 0, dia_list[0][2]])
else: else:
for i, (_, _, speaker) in enumerate(dia_list): for i, (_, _, speaker) in enumerate(dia_list):
if i == 0: if i == 0:
current_speaker = speaker current_speaker = speaker
if speaker != current_speaker: if speaker != current_speaker:
index_end_speaker = i - 1 index_end_speaker = i - 1
normalized_output.append([index_start_speaker, normalized_output.append([index_start_speaker,
index_end_speaker, index_end_speaker,
current_speaker]) current_speaker])
index_start_speaker = i index_start_speaker = i
current_speaker = speaker current_speaker = speaker
if i == len(dia_list) - 1: if i == len(dia_list) - 1:
index_end_speaker = i index_end_speaker = i
normalized_output.append([index_start_speaker, normalized_output.append([index_start_speaker,
index_end_speaker, index_end_speaker,
current_speaker]) current_speaker])
for outp in normalized_output: for outp in normalized_output:
start = dia_list[outp[0]][0].start start = dia_list[outp[0]][0].start
end = dia_list[outp[1]][0].end end = dia_list[outp[1]][0].end
diarization_output["segments"].append([start, end]) diarization_output["segments"].append([start, end])
diarization_output["speakers"].append(outp[2]) diarization_output["speakers"].append(outp[2])
return diarization_output return diarization_output
@staticmethod @staticmethod
def _get_token(): def _get_token():
""" """
@@ -155,14 +161,14 @@ class Diariser:
Returns: Returns:
str: The Huggingface token. str: The Huggingface token.
""" """
if os.path.exists(TOKEN_PATH): if os.path.exists(TOKEN_PATH):
with open(TOKEN_PATH, 'r', encoding="utf-8") as file: with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
token = file.read() token = file.read()
else: else:
raise ValueError('No token found.' \ raise ValueError('No token found.'
'Please create a token at https://huggingface.co/settings/token' \ 'Please create a token at https://huggingface.co/settings/token'
f'and save it in a file called {TOKEN_PATH}') f'and save it in a file called {TOKEN_PATH}')
return token return token
@staticmethod @staticmethod
@@ -176,54 +182,114 @@ class Diariser:
""" """
with open(TOKEN_PATH, 'w', encoding="utf-8") as file: with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
file.write(token) file.write(token)
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG, model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None, use_auth_token: str = None,
cache_token: bool = True, cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None, hparams_file: Union[str, Path] = None,
*args, **kwargs device: str = None,
) -> Pipeline: *args, **kwargs
) -> Pipeline:
""" """
Loads a pretrained model from pyannote.audio, Loads a pretrained model from pyannote.audio,
either from a local cache or online repository. either from a local cache or some online repository.
Args: Args:
model: Path or identifier for the pyannote model. model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml default: '/home/[user]/.cache/torch/models/pyannote/config.yaml'
or one of 'jaikinator/scraibe', 'pyannote/speaker-diarization-3.1'
token: Optional HUGGINGFACE_TOKEN for authenticated access. token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use. cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models. cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters. hparams_file: Path to a YAML file containing hyperparameters.
device: Device to load the model on.
args: Additional arguments only to avoid errors. args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors.
Returns: Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
""" """
if isinstance(model, str) and os.path.exists(model):
if cache_token and use_auth_token is not None: # check if model can be found locally nearby the config file
cls._save_token(use_auth_token) with open(model, 'r') as file:
config = yaml.safe_load(file)
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token() path_to_model = config['pipeline']['params']['segmentation']
model = 'pyannote/speaker-diarization'
elif not os.path.exists(model) and use_auth_token is not None: if not os.path.exists(path_to_model):
model = 'pyannote/speaker-diarization' warnings.warn(f"Model not found at {path_to_model}. "
"Trying to find it nearby the config file.")
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token, pwd = model.split("/")[:-1]
cache_dir = cache_dir, pwd = "/".join(pwd)
hparams_file = hparams_file,)
path_to_model = os.path.join(pwd, "pytorch_model.bin")
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
warnings.warn(
'Searching for nearby files in a folder path is '
'deprecated and will be removed in future versions.',
category=DeprecationWarning)
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(
pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
else:
warnings.warn("Found more than one .bin file. "
"or none. Please specify the path to the model "
"or setup a huggingface token.")
raise FileNotFoundError
warnings.warn(
f"Found model at {path_to_model} overwriting config file.")
config['pipeline']['params']['segmentation'] = path_to_model
with open(model, 'w') as file:
yaml.dump(config, file)
elif isinstance(model, tuple):
try:
_model = model[0]
HfApi().model_info(_model)
model = _model
use_auth_token = None
except RepositoryNotFoundError:
print(f'{model[0]} not found on Huggingface, \
trying {model[1]}')
_model = model[1]
HfApi().model_info(_model)
model = _model
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)
if use_auth_token is None:
use_auth_token = cls._get_token()
else:
raise FileNotFoundError(
f'No local model or directory found at {model}.')
_model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
hparams_file=hparams_file,)
if _model is None: if _model is None:
raise ValueError('Unable to load model either from local cache' \ raise ValueError('Unable to load model either from local cache'
'or from huggingface.co models. Please check your token' \ 'or from huggingface.co models. Please check your token'
'or your local model path') 'or your local model path')
# try to move the model to the device
if device is None:
device = "cuda" if is_available() else "cpu"
# torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device))
return cls(_model) return cls(_model)
@staticmethod @staticmethod
@@ -239,9 +305,10 @@ class Diariser:
""" """
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} diarisation_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
return diarisation_kwargs return diarisation_kwargs
def __repr__(self): def __repr__(self):
return f"Diarisation(model={self.model})" return f"Diarisation(model={self.model})"
+95
View File
@@ -0,0 +1,95 @@
# List of known hallucinations - adapted from:
# https://github.com/openai/whisper/discussions/928
#
# Every entry MUST end with a comma. Python silently concatenates adjacent
# string literals, so a missing comma merges two phrases into one long string
# that never matches real transcript text.
KNOWN_HALLUCINATIONS = [
    # en
    " www.mooji.org",
    # nl
    " Ondertitels ingediend door de Amara.org gemeenschap",
    " Ondertiteld door de Amara.org gemeenschap",
    " Ondertiteling door de Amara.org gemeenschap",
    # de
    " Untertitelung aufgrund der Amara.org-Community",
    " Untertitelung im Auftrag des ZDF für funk, 2016",
    " Untertitelung im Auftrag des ZDF f\u00fcr funk, 2016",
    " Untertitel im Auftrag des ZDF für funk, 2017",
    " Untertitel im Auftrag des ZDF f\u00fcr funk, 2017",
    " Untertitel im Auftrag des ZDF für funk, 2018",
    " Untertitel von Stephanie Geiges",
    " Untertitel der Amara.org-Community",
    " Untertitel im Auftrag des ZDF, 2017",
    " Untertitel im Auftrag des ZDF, 2018",
    " Untertitel im Auftrag des ZDF, 2019",
    " Untertitel im Auftrag des ZDF, 2020",
    " Untertitel im Auftrag des ZDF, 2021",
    " Untertitelung im Auftrag des ZDF, 2021",
    " Copyright WDR 2021",
    " Copyright WDR 2020",
    " Copyright WDR 2019",
    " SWR 2021",
    " SWR 2020",
    # fr
    " Sous-titres réalisés para la communauté d'Amara.org",
    " Sous-titres réalisés par la communauté d'Amara.org",
    " Sous-titres fait par Sous-titres par Amara.org",
    " Sous-titres réalisés par les SousTitres d'Amara.org",
    " Sous-titres par Amara.org",
    " Sous-titres par la communauté d'Amara.org",
    " Sous-titres réalisés pour la communauté d'Amara.org",
    " Sous-titres réalisés par la communauté de l'Amara.org",
    " Sous-Titres faits par la communauté d'Amara.org",
    " Sous-titres par l'Amara.org",
    " Sous-titres fait par la communauté d'Amara.org",
    " Sous-titrage ST' 501",
    " Sous-titrage ST'501",
    " Cliquez-vous sur les sous-titres et abonnez-vous à la chaîne d'Amara.org",
    " ❤️ par SousTitreur.com",
    # it
    " Sottotitoli creati dalla comunità Amara.org",
    " Sottotitoli di Sottotitoli di Amara.org",
    " Sottotitoli e revisione al canale di Amara.org",
    " Sottotitoli e revisione a cura di Amara.org",
    " Sottotitoli e revisione a cura di QTSS",
    " Sottotitoli e revisione a cura di QTSS.",
    " Sottotitoli a cura di QTSS",
    # es
    " Subtítulos realizados por la comunidad de Amara.org",
    " Subtitulado por la comunidad de Amara.org",
    " Subtítulos por la comunidad de Amara.org",
    " Subtítulos creados por la comunidad de Amara.org",
    " Subtítulos en español de Amara.org",
    " Subtítulos hechos por la comunidad de Amara.org",
    " Subtitulos por la comunidad de Amara.org",
    " Más información www.alimmenta.com",
    " www.mooji.org",
    # gl
    " Subtítulos realizados por la comunidad de Amara.org",
    # pt
    " Legendas pela comunidade Amara.org",
    " Legendas pela comunidade de Amara.org",
    " Legendas pela comunidade do Amara.org",
    " Legendas pela comunidade das Amara.org",
    " Transcrição e Legendas pela comunidade de Amara.org",
    # la
    " Sottotitoli creati dalla comunità Amara.org",
    " Sous-titres réalisés para la communauté d'Amara.org",
    # ln
    " Sous-titres réalisés para la communauté d'Amara.org",
    # pl
    " Napisy stworzone przez społeczność Amara.org",
    " Napisy wykonane przez społeczność Amara.org",
    " Zdjęcia i napisy stworzone przez społeczność Amara.org",
    " napisy stworzone przez społeczność Amara.org",
    " Tłumaczenie i napisy stworzone przez społeczność Amara.org",
    " Napisy stworzone przez społeczności Amara.org",
    " Tłumaczenie stworzone przez społeczność Amara.org",
    " Napisy robione przez społeczność Amara.org",
    " www.multi-moto.eu",
    # ru
    " Редактор субтитров А.Синецкая Корректор А.Егорова",
    # tr
    " Yorumlarınızıza abone olmayı unutmayın.",
    # su
    " Sottotitoli creati dalla comunità Amara.org",
    # zh
    "字幕由Amara.org社区提供",
    "小編字幕由Amara.org社區提供",
]
+26 -3
View File
@@ -1,6 +1,8 @@
import os import os
import yaml import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action
from ast import literal_eval
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
@@ -12,7 +14,10 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.
@@ -30,11 +35,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
with open(file_path, "r") as stream: with open(file_path, "r") as stream:
yml = yaml.safe_load(stream) yml = yaml.safe_load(stream)
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") segmentation_path = path_to_segmentation or os.path.join(
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
yml["pipeline"]["params"]["segmentation"] = segmentation_path yml["pipeline"]["params"]["segmentation"] = segmentation_path
if not os.path.exists(segmentation_path): if not os.path.exists(segmentation_path):
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}") raise FileNotFoundError(
f"Segmentation model not found at {segmentation_path}")
with open(file_path, "w") as stream: with open(file_path, "w") as stream:
yaml.dump(yml, stream) yaml.dump(yml, stream)
class ParseKwargs(Action):
    """
    Custom argparse action that collects ``KEY=VALUE`` tokens into a dict.

    Each token is split on its first ``=`` only, so the value part may
    itself contain ``=`` characters. The value is passed through
    ``ast.literal_eval``; if it is not a valid Python literal it is kept
    as the raw string.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Parse ``values`` and store the resulting dict on ``namespace``.

        Args:
            parser: The parser invoking this action (unused).
            namespace: Object that receives the dict under ``self.dest``.
            values: Iterable of ``KEY=VALUE`` strings, or a single such string.
            option_string: The option string that triggered the action (unused).

        Raises:
            ValueError: If a token does not contain ``=``.
        """
        # A lone string would otherwise be iterated character by character.
        if isinstance(values, str):
            values = [values]

        parsed = dict()
        for item in values:
            # Split on the first '=' only; the rest belongs to the value.
            key, separator, raw_value = item.partition('=')
            if not separator:
                raise ValueError(
                    f"Expected KEY=VALUE, got {item!r}")
            try:
                value = literal_eval(raw_value)
            except (ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError):
                # Not a Python literal: keep the raw text as a string.
                value = raw_value
            parsed[key] = value

        setattr(namespace, self.dest, parsed)
+286 -38
View File
@@ -24,16 +24,20 @@ Usage:
>>> transcriber.save_transcript(transcript, "path/to/save.txt") >>> transcriber.save_transcript(transcript, "path/to/save.txt")
""" """
from whisper import Whisper, load_model from whisper import Whisper
from typing import TypeVar , Union , Optional from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from typing import TypeVar, Union, Optional
from torch import Tensor, device from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray from numpy import ndarray
from inspect import signature
from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
class Transcriber: class Transcriber:
@@ -64,16 +68,22 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options. the load_model method for available options.
""" """
def __init__(self, model: whisper ) -> None:
def __init__(self, model: whisper, model_name: str) -> None:
""" """
Initialize the Transcriber class with a Whisper model. Initialize the Transcriber class with a Whisper model.
Args: Args:
model (whisper): The Whisper model to use for transcription. model (whisper): The Whisper model to use for transcription.
model_name (str): The name of the model.
""" """
self.model = model self.model = model
def transcribe(self, audio : Union[str, Tensor, ndarray] , self.model_name = model_name
@abstractmethod
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str: *args, **kwargs) -> str:
""" """
Transcribe an audio file. Transcribe an audio file.
@@ -87,17 +97,10 @@ class Transcriber:
Returns: Returns:
str: The transcript as a string. str: The transcript as a string.
""" """
pass
kwargs = self._get_whisper_kwargs(**kwargs)
if not kwargs.get("verbose"):
kwargs["verbose"] = None
result = self.model.transcribe(audio, *args, **kwargs)
return result["text"]
@staticmethod @staticmethod
def save_transcript(transcript : str , save_path : str) -> None: def save_transcript(transcript: str, save_path: str) -> None:
""" """
Save a transcript to a file. Save a transcript to a file.
@@ -111,17 +114,19 @@ class Transcriber:
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
f.write(transcript) f.write(transcript)
print(f'Transcript saved to {save_path}') print(f'Transcript saved to {save_path}')
@classmethod @classmethod
@abstractmethod
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, whisper_type: str = 'whisper',
device: Optional[Union[str, device]] = None, download_root: str = WHISPER_DEFAULT_PATH,
in_memory: bool = False, device: Optional[Union[str, device]] = None,
*args, **kwargs in_memory: bool = False,
) -> 'Transcriber': *args, **kwargs
) -> None:
""" """
Load whisper model. Load whisper model.
@@ -137,11 +142,94 @@ class Transcriber:
- 'medium' - 'medium'
- 'large-v1' - 'large-v1'
- 'large-v2' - 'large-v2'
- 'large-v3'
- 'large' - 'large'
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
download_root (str, optional): Path to download the model. download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH. Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
None: abscract method.
"""
pass
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
pass
def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})"
class WhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
kwargs = self._get_whisper_kwargs(**kwargs)
if not kwargs.get("verbose"):
kwargs["verbose"] = None
result = self.model.transcribe(audio, *args, **kwargs)
return result["text"]
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
in_memory: bool = False,
*args, **kwargs
) -> 'WhisperTranscriber':
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large-v3'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional): device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None. Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory. in_memory (bool, optional): Whether to load model in memory.
@@ -153,10 +241,10 @@ class Transcriber:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.
""" """
_model = load_model(model, download_root=download_root, _model = whisper_load_model(model, download_root=download_root,
device=device, in_memory=in_memory) device=device, in_memory=in_memory)
return cls(_model) return cls(_model, model_name=model)
@staticmethod @staticmethod
def _get_whisper_kwargs(**kwargs) -> dict: def _get_whisper_kwargs(**kwargs) -> dict:
@@ -166,17 +254,177 @@ class Transcriber:
Returns: Returns:
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
_possible_kwargs = Whisper.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")): if (task := kwargs.get("task")):
whisper_kwargs["task"] = task whisper_kwargs["task"] = task
if (language := kwargs.get("language")): if (language := kwargs.get("language")):
whisper_kwargs["language"] = language whisper_kwargs["language"] = language
return whisper_kwargs return whisper_kwargs
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Transcriber(model={self.model})" return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
class WhisperXTranscriber(Transcriber):
    """Transcriber whose model is loaded through ``whisperx.load_model``."""

    def __init__(self, model: whisper, model_name: str) -> None:
        super().__init__(model, model_name)

    def transcribe(self, audio: Union[str, Tensor, ndarray],
                   *args, **kwargs) -> str:
        """
        Transcribe audio and return the concatenated segment texts.

        Args:
            audio (Union[str, Tensor, ndarray]): The audio to transcribe.
                A Tensor is moved to the CPU and converted to a numpy
                array before it is handed to the model.
            *args: Forwarded positionally to the model's ``transcribe``.
            **kwargs: Keyword arguments; filtered through
                ``_get_whisper_kwargs`` before being forwarded.

        Returns:
            str: The ``text`` of every entry in the result's ``segments``,
                joined in order without a separator.
        """
        accepted_kwargs = self._get_whisper_kwargs(**kwargs)

        if isinstance(audio, Tensor):
            audio = audio.cpu().numpy()

        result = self.model.transcribe(audio, *args, **accepted_kwargs)
        return "".join(segment['text'] for segment in result['segments'])

    @classmethod
    def load_model(cls,
                   model: str = "medium",
                   download_root: str = WHISPER_DEFAULT_PATH,
                   device: Optional[Union[str, device]] = None,
                   *args, **kwargs
                   ) -> 'WhisperXTranscriber':
        """
        Load a model via whisperx and wrap it in a WhisperXTranscriber.

        Args:
            model (str): Whisper model name, e.g. 'tiny', 'base', 'small',
                'medium', 'large-v1', 'large-v2', 'large-v3' or 'large'
                (the '.en' variants are listed for the smaller sizes).
            download_root (str, optional): Path to download the model.
                Defaults to WHISPER_DEFAULT_PATH.
            device (Optional[Union[str, torch.device]], optional):
                Device to load the model on. When None, "cuda" is chosen
                if available, otherwise "cpu". Non-string values are
                converted with ``str``.
            args: Additional arguments only to avoid errors.
            kwargs: Only ``compute_type`` is read (default 'float16').

        Returns:
            WhisperXTranscriber: Instance wrapping the loaded model.
        """
        if device is None:
            target_device = "cuda" if cuda_is_available() else "cpu"
        else:
            target_device = device
        if not isinstance(target_device, str):
            target_device = str(target_device)

        precision = kwargs.get('compute_type', 'float16')
        # The float16 default is swapped for int8 when running on the CPU.
        if target_device == 'cpu' and precision == 'float16':
            warnings.warn(f'Compute type {precision} not compatible with '
                          f'device {target_device}! Changing compute type to int8.')
            precision = 'int8'

        loaded = whisperx_load_model(model, download_root=download_root,
                                     device=target_device, compute_type=precision)
        return cls(loaded, model_name=model)

    @staticmethod
    def _get_whisper_kwargs(**kwargs) -> dict:
        """
        Keep only the keyword arguments the whisperx model understands.

        Keys are kept when they appear in the signature of
        ``WhisperModel.transcribe``. Truthy ``task`` and ``language``
        values are always passed through.

        Returns:
            dict: Keyword arguments for whisper model.
        """
        accepted_names = signature(WhisperModel.transcribe).parameters.keys()

        whisper_kwargs = {}
        for name, value in kwargs.items():
            if name in accepted_names:
                whisper_kwargs[name] = value

        for forced in ("task", "language"):
            forced_value = kwargs.get(forced)
            if forced_value:
                whisper_kwargs[forced] = forced_value

        return whisper_kwargs

    def __repr__(self) -> str:
        return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})"
def load_transcriber(model: str = "medium",
                     whisper_type: str = 'whisper',
                     download_root: str = WHISPER_DEFAULT_PATH,
                     device: Optional[Union[str, device]] = None,
                     in_memory: bool = False,
                     *args, **kwargs
                     ) -> Union[WhisperTranscriber, WhisperXTranscriber]:
    """
    Load a whisper model through the backend selected by ``whisper_type``.

    Args:
        model (str): Whisper model. Available models include:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large-v3'
            - 'large'
        whisper_type (str):
            Type of whisper model to load, "whisper" or "whisperx".
            Matched case-insensitively.
        download_root (str, optional): Path to download the model.
            Defaults to WHISPER_DEFAULT_PATH.
        device (Optional[Union[str, torch.device]], optional):
            Device to load model on. Defaults to None.
        in_memory (bool, optional): Whether to load model in memory.
            Only forwarded to the "whisper" backend. Defaults to False.
        args: Additional arguments, forwarded to the backend loader.
        kwargs: Additional keyword arguments, forwarded to the backend loader.

    Returns:
        Union[WhisperTranscriber, WhisperXTranscriber]:
            One of the Whisper variants as Transcriber object initialized
            with the specified model.

    Raises:
        ValueError: If ``whisper_type`` is neither "whisper" nor "whisperx".
    """
    # Normalise once so both comparisons see the same value.
    backend = whisper_type.lower()

    if backend == 'whisper':
        return WhisperTranscriber.load_model(
            model, download_root, device, in_memory, *args, **kwargs)

    if backend == 'whisperx':
        # The whisperx loader takes no in_memory argument.
        return WhisperXTranscriber.load_model(
            model, download_root, device, *args, **kwargs)

    raise ValueError(f'Model type not recognized, expected "whisper" '
                     f'or "whisperx", got {whisper_type}.')
+86 -68
View File
@@ -1,10 +1,11 @@
import json import json
import time import time
from traceback import print_stack from json.decoder import JSONDecodeError
from typing import Union from typing import Union
from .hallucinations import KNOWN_HALLUCINATIONS
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
@@ -13,7 +14,7 @@ class Transcript:
Class for storing transcript data, including speaker information and text segments, Class for storing transcript data, including speaker information and text segments,
and exporting it to various file formats such as JSON, HTML, and LaTeX. and exporting it to various file formats such as JSON, HTML, and LaTeX.
""" """
def __init__(self, transcript: dict) -> None: def __init__(self, transcript: dict) -> None:
""" """
Initializes the Transcript object with the given transcript data. Initializes the Transcript object with the given transcript data.
@@ -25,10 +26,11 @@ class Transcript:
""" """
self.transcript = transcript self.transcript = transcript
self._remove_hallucinations()
self.speakers = self._extract_speakers() self.speakers = self._extract_speakers()
self.segments = self._extract_segments() self.segments = self._extract_segments()
self.annotation = {} self.annotation = {}
def annotate(self, *args, **kwargs) -> dict: def annotate(self, *args, **kwargs) -> dict:
""" """
Annotates the transcript to associate specific names with speakers. Annotates the transcript to associate specific names with speakers.
@@ -44,26 +46,45 @@ class Transcript:
ValueError: If the number of speaker names does not match the number ValueError: If the number of speaker names does not match the number
of speakers, or if an unknown speaker is found. of speakers, or if an unknown speaker is found.
""" """
annotations = {} annotations = {}
if args and len(args) != len(self.speakers): if args and len(args) != len(self.speakers):
raise ValueError("Number of speaker names does not match number of speakers") raise ValueError(
"Number of speaker names does not match number of speakers")
if args: if args:
for arg, speaker in zip(args, sorted(self.speakers)): for arg, speaker in zip(args, sorted(self.speakers)):
annotations[speaker] = arg annotations[speaker] = arg
invalid_speakers = set(kwargs.keys()) - set(self.speakers) invalid_speakers = set(kwargs.keys()) - set(self.speakers)
if invalid_speakers: if invalid_speakers:
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}") raise ValueError(
f"These keys are not speakers: {', '.join(invalid_speakers)}")
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) annotations.update({key: kwargs[key]
for key in self.speakers if key in kwargs})
self.annotation = annotations self.annotation = annotations
return self return self
def _remove_hallucinations(self) -> None:
    """
    Strip every known hallucination phrase from each segment's text.

    Each phrase in KNOWN_HALLUCINATIONS is replaced by an empty string,
    in list order. A segment whose text is exactly the empty string
    afterwards is deleted from the transcript.
    """
    emptied_ids = []

    for seg_id in self.transcript:
        segment = self.transcript[seg_id]
        for phrase in KNOWN_HALLUCINATIONS:
            segment['text'] = segment['text'].replace(phrase, '')

        if segment['text'] == '':
            emptied_ids.append(seg_id)

    # Deletion happens after the scan so the dict is not mutated mid-iteration.
    for seg_id in emptied_ids:
        del self.transcript[seg_id]
def _extract_speakers(self) -> list: def _extract_speakers(self) -> list:
""" """
Extracts the unique speaker names from the transcript. Extracts the unique speaker names from the transcript.
@@ -71,9 +92,9 @@ class Transcript:
Returns: Returns:
list: List of unique speaker names in the transcript. list: List of unique speaker names in the transcript.
""" """
return list(set([self.transcript[id]["speakers"] for id in self.transcript])) return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
def _extract_segments(self) -> list: def _extract_segments(self) -> list:
""" """
Extracts all the text segments from the transcript. Extracts all the text segments from the transcript.
@@ -93,23 +114,23 @@ class Transcript:
time stamps for each segment. time stamps for each segment.
""" """
fstring = "" fstring = ""
for _id in self.transcript: for _id in self.transcript:
seq = self.transcript[_id] seq = self.transcript[_id]
if self.annotation: if self.annotation:
speaker = self.annotation[seq["speakers"]] speaker = self.annotation[seq["speakers"]]
else: else:
speaker = seq["speakers"] speaker = seq["speakers"]
segm = seq["segments"] segm = seq["segments"]
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0])) sseg = time.strftime("%H:%M:%S", time.gmtime(segm[0]))
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) eseg = time.strftime("%H:%M:%S", time.gmtime(segm[1]))
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n" fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
return fstring return fstring
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return a string representation of the Transcript object. """Return a string representation of the Transcript object.
@@ -117,8 +138,8 @@ class Transcript:
str: A string that provides an informative description of the object. str: A string that provides an informative description of the object.
""" """
return f"Transcript(speakers = {self.speakers},"\ return f"Transcript(speakers = {self.speakers},"\
f"segments = {self.segments}, annotation = {self.annotation})" f"segments = {self.segments}, annotation = {self.annotation})"
def get_dict(self) -> dict: def get_dict(self) -> dict:
""" """
Get transcript as dict Get transcript as dict
@@ -126,10 +147,10 @@ class Transcript:
:return: transcript as dict :return: transcript as dict
:rtype: dict :rtype: dict
""" """
return self.transcript return self.transcript
def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str: def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str:
""" """
Get transcript as json string Get transcript as json string
:return: transcript as json string :return: transcript as json string
@@ -137,14 +158,14 @@ class Transcript:
""" """
if "indent" not in kwargs: if "indent" not in kwargs:
kwargs["indent"] = 3 kwargs["indent"] = 3
if use_annotation and self.annotation: if use_annotation and self.annotation:
for _id in self.transcript: for _id in self.transcript:
seq = self.transcript[_id] seq = self.transcript[_id]
seq["speakers"] = self.annotation[seq["speakers"]] seq["speakers"] = self.annotation[seq["speakers"]]
return json.dumps(self.transcript, *args, **kwargs) return json.dumps(self.transcript, *args, **kwargs)
def get_html(self) -> str: def get_html(self) -> str:
""" """
Get transcript as html string Get transcript as html string
@@ -155,9 +176,9 @@ class Transcript:
html = "<p>" + self.__str__().replace("\n", "<br>") + "</p>" html = "<p>" + self.__str__().replace("\n", "<br>") + "</p>"
html = "<html><body>" + html + "</body></html>" html = "<html><body>" + html + "</body></html>"
html = html.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;") html = html.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
return html return html
def get_md(self) -> str: def get_md(self) -> str:
"""Get transcript as Markdown string, using HTML formatting. """Get transcript as Markdown string, using HTML formatting.
@@ -165,7 +186,7 @@ class Transcript:
str: Transcript as a Markdown string. str: Transcript as a Markdown string.
""" """
return self.get_html() return self.get_html()
def get_tex(self) -> str: def get_tex(self) -> str:
"""Get transcript as LaTeX string. If no annotations are present, the speakers will """Get transcript as LaTeX string. If no annotations are present, the speakers will
be annotated with the first letters of the alphabet. be annotated with the first letters of the alphabet.
@@ -176,43 +197,42 @@ class Transcript:
if not self.annotation: if not self.annotation:
self.annotate(*ALPHABET[:len(self.speakers)]) self.annotate(*ALPHABET[:len(self.speakers)])
fstring ="\\begin{drama}" fstring = "\\begin{drama}"
for speaker in self.speakers: for speaker in self.speakers:
fstring += "\n\t\\Character{"+ str(self.annotation[speaker]) + "}" \ fstring += "\n\t\\Character{" + str(self.annotation[speaker]) + "}" \
"{"+ str(self.annotation[speaker]) + "}" "{" + str(self.annotation[speaker]) + "}"
for id in self.transcript: for id in self.transcript:
seq = self.transcript[id] seq = self.transcript[id]
speaker = self.annotation[seq["speakers"]] speaker = self.annotation[seq["speakers"]]
fstring += f"\n\\{speaker}speaks:\n{seq['text']}" fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
fstring += "\n\\end{drama}" fstring += "\n\\end{drama}"
return fstring return fstring
def to_json(self, path, *args, **kwargs) -> None:
def to_json(self,path, *args, **kwargs) -> None:
"""Save transcript as json file """Save transcript as json file
Args: Args:
path (str): path to save file path (str): path to save file
""" """
with open(path, "w") as f: with open(path, "w") as f:
json.dump(self.transcript, f, *args, **kwargs) json.dump(self.transcript, f, *args, **kwargs)
def to_txt(self, path: str) -> None: def to_txt(self, path: str) -> None:
"""Save transcript as a LaTeX file (placeholder function, implementation needed). """Save transcript as a LaTeX file (placeholder function, implementation needed).
Args: Args:
path (str): Path to save the LaTeX file. path (str): Path to save the LaTeX file.
""" """
with open(path, "w") as f: with open(path, "w") as f:
f.write(self.__str__()) f.write(self.__str__())
def to_md(self, path: str) -> None: def to_md(self, path: str) -> None:
"""Get transcript as Markdown string, using HTML formatting. """Get transcript as Markdown string, using HTML formatting.
@@ -220,7 +240,7 @@ class Transcript:
str: Transcript as a Markdown string. str: Transcript as a Markdown string.
""" """
return self.to_html(path) return self.to_html(path)
def to_html(self, path: str) -> None: def to_html(self, path: str) -> None:
""" """
Save transcript as html file Save transcript as html file
@@ -228,10 +248,10 @@ class Transcript:
:param path: path to save file :param path: path to save file
:type path: str :type path: str
""" """
with open(path, "w") as file: with open(path, "w") as file:
file.write(self.get_html()) file.write(self.get_html())
def to_tex(self, path: str) -> None: def to_tex(self, path: str) -> None:
"""Save transcript as a LaTeX file (placeholder function, implementation needed). """Save transcript as a LaTeX file (placeholder function, implementation needed).
@@ -239,7 +259,7 @@ class Transcript:
path (str): Path to save the LaTeX file. path (str): Path to save the LaTeX file.
""" """
pass pass
def to_pdf(self, path: str) -> None: def to_pdf(self, path: str) -> None:
"""Save transcript as a PDF file (placeholder function, implementation needed). """Save transcript as a PDF file (placeholder function, implementation needed).
@@ -247,7 +267,7 @@ class Transcript:
path (str): Path to save the PDF file. path (str): Path to save the PDF file.
""" """
pass pass
def save(self, path: str, *args, **kwargs) -> None: def save(self, path: str, *args, **kwargs) -> None:
"""Save transcript to file with the given path and file format. """Save transcript to file with the given path and file format.
@@ -263,7 +283,7 @@ class Transcript:
Raises: Raises:
ValueError: If the file format specified in the path is unknown. ValueError: If the file format specified in the path is unknown.
""" """
if path.endswith(".json"): if path.endswith(".json"):
self.to_json(path, *args, **kwargs) self.to_json(path, *args, **kwargs)
elif path.endswith(".txt"): elif path.endswith(".txt"):
@@ -278,9 +298,9 @@ class Transcript:
self.to_pdf(path, *args, **kwargs) self.to_pdf(path, *args, **kwargs)
else: else:
raise ValueError("Unknown file format") raise ValueError("Unknown file format")
@classmethod @classmethod
def from_json(cls, json: Union[dict, str]) -> "Transcript": def from_json(cls, _json: Union[dict, str]) -> "Transcript":
"""Load transcript from json file """Load transcript from json file
Args: Args:
@@ -289,15 +309,13 @@ class Transcript:
Returns: Returns:
Transcript: Transcript object Transcript: Transcript object
""" """
if isinstance(json, dict): if isinstance(_json, dict):
return cls(json) return cls(_json)
else: else:
try: try:
transcript = json.loads(json) transcript = json.loads(_json)
except: except (TypeError, JSONDecodeError):
with open(json, "r") as f: with open(_json, "r") as f:
transcript = json.load(f) transcript = json.load(f)
return cls(transcript)
return cls(transcript)
-69
View File
@@ -1,69 +0,0 @@
import os
import subprocess as sp
# Numeric version components; they are joined into VERSION below in this order.
MAJOR = 0
MINOR = 1
MICRO = 0
MICRO_POST = 0
# When True, get_version() returns VERSION as-is; otherwise it appends a ".dev" suffix.
ISRELEASED = False
# Four-part base version string, e.g. '0.1.0.0'.
VERSION = '%d.%d.%d.%d' % (MAJOR, MINOR, MICRO, MICRO_POST)
# Return the git revision as a string
# taken from numpy/numpy
def git_version():
    """Return the hash of the current git HEAD commit.

    ``git rev-parse HEAD`` is run in the current working directory with a
    minimal, locale-neutral environment.

    Returns:
        str: The revision reported by git, or ``"Unknown"`` when git cannot
            be started, exits with a non-zero status, or prints nothing.
    """
    def _minimal_ext_cmd(cmd):
        # construct minimal environment
        env = {k: os.environ[k]
               for k in ('SYSTEMROOT', 'PATH', 'HOME') if k in os.environ}
        # LANGUAGE is used on win32
        env.update(LANGUAGE='C', LANG='C', LC_ALL='C')
        proc = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE, env=env)
        # A failing git (e.g. outside a repository) must not be mistaken for
        # a revision, so its output is discarded.
        return proc.stdout if proc.returncode == 0 else b''

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        GIT_REVISION = out.strip().decode('ascii')
    except OSError:
        GIT_REVISION = "Unknown"

    # An empty revision would make callers build a version ending in a bare
    # "+", so report it as unknown instead.
    if not GIT_REVISION:
        GIT_REVISION = "Unknown"

    return GIT_REVISION
def _get_git_version():
    """Return the git revision as seen from the parent of this file's directory.

    The process working directory is switched to that directory while
    ``git_version()`` runs and is restored afterwards, including when
    ``git_version()`` raises.

    Returns:
        str: The value returned by ``git_version()``.
    """
    cwd = os.getcwd()

    # go to the main directory
    fdir = os.path.dirname(os.path.abspath(__file__))
    maindir = os.path.abspath(os.path.join(fdir, ".."))
    os.chdir(maindir)
    try:
        # get git version
        return git_version()
    finally:
        # restore the cwd
        os.chdir(cwd)
def get_version(build_version=False):
    """Return the package version string.

    Args:
        build_version (bool): For unreleased versions, append a timestamp
            (``%Y%m%d%H%M%S``) instead of the git revision.

    Returns:
        str: ``VERSION`` when ``ISRELEASED`` is true. Otherwise ``VERSION``
            followed by ``.dev<timestamp>`` (``build_version`` true) or
            ``.dev0+<first 7 characters of the git revision>``.
    """
    if ISRELEASED:
        return VERSION

    # unreleased version
    if build_version:
        import datetime as dt
        date = dt.datetime.now().strftime("%Y%m%d%H%M%S")
        return VERSION + ".dev" + date

    # Git is only queried on the path that actually uses the revision.
    GIT_REVISION = _get_git_version()
    return VERSION + ".dev0+" + GIT_REVISION[:7]
-31
View File
@@ -1,31 +0,0 @@
# Declarative setuptools metadata and options for the scraibe package.
[metadata]
name = scraibe
# Version is read from the package's __version__ attribute.
version = attr: scraibe.__version__
author = Jacob Schmieder
author_email = Jacob.Schmieder@dbfz.de
# NOTE(review): this looks like template placeholder text - confirm and replace with a real summary.
description = My package description
long_description = file: README.md, LICENSE
platforms = Linux
keywords = transcription speech recognition whisper pyannote audio speech-to-text speech-to-text transcription speech-to-text recognition voice-to-speech
# NOTE(review): license says GPL-3.0 but the classifier below names OSL-3.0 - confirm which one applies.
license = GPL-3.0
classifiers =
Development Status :: 3 - Alpha
Environment :: GPU :: NVIDIA CUDA :: 11.2
License :: OSI Approved :: Open Software License 3.0 (OSL-3.0)
Topic :: Scientific/Engineering :: Artificial Intelligence
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
[options]
zip_safe = False
include_package_data = True
packages = find:
# NOTE(review): 3.7 is allowed here although the classifiers above start at 3.8 - confirm the intended minimum.
python_requires = >=3.7
install_requires =
requests
importlib-metadata; python_version<"3.8"
[options.entry_points]
# NOTE(review): "executable-name" looks like a placeholder for the console command name - confirm.
console_scripts =
executable-name = scraibe.cli:cli
-60
View File
@@ -1,60 +0,0 @@
from calendar import c
import pkg_resources
import os
from setuptools import setup, find_packages
module_name = "scraibe"
github_url = "https://github.com/JSchmie/ScAIbe"

file_dir = os.path.dirname(os.path.realpath(__file__))


def absdir(p):
    """Return ``p`` joined onto the directory that contains this setup script."""
    return os.path.join(file_dir, p)


############### versioning ###############
# version.py is executed in its own namespace (rather than imported) and its
# get_version() function is looked up from that namespace below.
verfile = os.path.abspath(os.path.join(module_name, "version.py"))
version = {"__file__": verfile}
with open(verfile, "r") as fp:
    exec(fp.read(), version)

############### setup ###############

# The mere presence of SCRAIBE_BUILD in the environment is passed to
# get_version() as its build_version flag.
build_version = "SCRAIBE_BUILD" in os.environ

if __name__ == "__main__":
    # Read requirements.txt through a context manager so the file handle is
    # closed once the requirement strings have been collected.
    requirements_path = os.path.join(os.path.dirname(__file__), "requirements.txt")
    with open(requirements_path) as requirements_file:
        install_requires = [
            str(r) for r in pkg_resources.parse_requirements(requirements_file)
        ]

    setup(
        name=module_name,
        version=version["get_version"](build_version),
        packages=find_packages(),
        python_requires=">=3.8",
        # NOTE(review): confirm that setup() accepts a `readme` keyword; the
        # long description is usually passed as `long_description`.
        readme="README.md",
        install_requires=install_requires,
        dependency_links=[
            'https://download.pytorch.org/whl/cu113',
        ],
        url=github_url,
        # NOTE(review): license is GPL-3 but the classifier below names
        # OSL-3.0 - confirm which one applies.
        license='GPL-3',
        author='Jacob Schmieder',
        author_email='Jacob.Schmieder@dbfz.de',
        description='Transcription tool for audio files based on Whisper and Pyannote',
        classifiers=[
            'Development Status :: 3 - Alpha',
            'Environment :: GPU :: NVIDIA CUDA :: 11.2',
            'License :: OSI Approved :: Open Software License 3.0 (OSL-3.0)',
            'Topic :: Scientific/Engineering :: Artificial Intelligence',
            'Programming Language :: Python :: 3.8',
            'Programming Language :: Python :: 3.9',
            'Programming Language :: Python :: 3.10'],
        keywords=['transcription', 'speech recognition', 'whisper', 'pyannote', 'audio',
                  'speech-to-text', 'speech-to-text transcription', 'speech-to-text recognition',
                  'voice-to-speech'],
        package_data={'scraibe.app': ["*.html", "*.svg"]},
        entry_points={'console_scripts':
                      ['scraibe = scraibe.cli:cli']}
    )
+84
View File
@@ -0,0 +1,84 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
# Put the parent directory first on sys.path.
# NOTE(review): presumably so autodoc can import the package - confirm.
sys.path.insert(0, os.path.abspath('../'))

# -- Project information -----------------------------------------------------

project = 'ScrAIbe: Streamlined Conversation Recording with Automated Intelligence Based Environment'
copyright = '2023, Jacob Schmieder'
author = 'Jacob Schmieder'

# Full release string (this is where alpha/beta/rc tags would go).
release = '0.1.1'

# -- General configuration ---------------------------------------------------

# Extension modules to load, listed one per line: the bundled 'sphinx.ext.*'
# extensions first, then the additional 'myst_parser'.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.ifconfig',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
    'sphinx.ext.napoleon',
    'myst_parser',
]

# Options for the napoleon extension.
napoleon_google_docstring = True
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = True
napoleon_include_private_with_doc = True
napoleon_include_special_with_doc = True
napoleon_use_admonition_for_examples = False
napoleon_use_admonition_for_notes = False
napoleon_use_admonition_for_references = False
napoleon_use_ivar = False
napoleon_use_param = True
napoleon_use_rtype = True

# Template directories, relative to this file's directory.
templates_path = ['_templates']

# Patterns (relative to the source directory) of files and directories that
# are ignored when collecting sources; they also apply to html_static_path
# and html_extra_path.
exclude_patterns = []

# File suffix -> source type used to read files with that suffix.
source_suffix = {
    '.rst': 'restructuredtext',
    '.txt': 'markdown',
    '.md': 'markdown',
}

# -- Options for HTML output -------------------------------------------------

# Theme for HTML and HTML Help pages.
html_theme = 'sphinx_rtd_theme'

# Directories with custom static files (e.g. style sheets), relative to this
# file's directory. They are copied after the builtin static files, so a
# "default.css" here overwrites the builtin "default.css".
html_static_path = ['_static']
+21
View File
@@ -0,0 +1,21 @@
Welcome to ScrAIbe: Streamlined Conversation Recording with Automated Intelligence Based Environment's documentation!
=====================================================================================================================
.. automodule:: scraibe
:members:
.. toctree::
:maxdepth: 2
:caption: Contents:
../README.md
modules
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
+7
View File
@@ -0,0 +1,7 @@
scraibe
=======
.. toctree::
:maxdepth: 4
scraibe
Binary file not shown.
Binary file not shown.
+96
View File
@@ -0,0 +1,96 @@
import pytest
from scraibe.audio import AudioProcessor
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
TEST_SR = 16000
SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768
@pytest.fixture
def probe_audio_processor():
    """Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate.

    The instance is built from the predefined test waveform (TEST_WAVEFORM) and sample rate (TEST_SR)
    and can be used as a dependency in other test functions.

    Returns:
        AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate.
    """
    return AudioProcessor(TEST_WAVEFORM, TEST_SR)
def test_AudioProcessor_init(probe_audio_processor):
    """
    Test the initialization of the AudioProcessor class.

    Verifies that the fixture object is an AudioProcessor whose waveform lives on the same device as,
    and is equal to, TEST_WAVEFORM, and whose sample rate equals TEST_SR.

    Args:
        probe_audio_processor (obj): An instance of the AudioProcessor class to be tested.

    Returns:
        None
    """
    processor = probe_audio_processor
    assert isinstance(processor, AudioProcessor)

    waveform = processor.waveform
    assert waveform.device == TEST_WAVEFORM.device
    assert torch.equal(waveform, TEST_WAVEFORM)

    assert processor.sr == TEST_SR
def test_cut(probe_audio_processor):
    """Test the cut function of the AudioProcessor class.

    Cuts the waveform between a start and an end position and checks that the first dimension
    of the result has (end - start) * TEST_SR entries.

    Returns:
        None
    """
    start, end = 4, 7
    expected_size = int((end - start) * TEST_SR)

    segment = probe_audio_processor.cut(start, end)

    assert segment.size(0) == expected_size
    # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
def test_audio_processor_invalid_sr():
    """Test the behavior of AudioProcessor when an invalid sample rate is provided.

    Uses the pytest.raises context manager to check that constructing an AudioProcessor
    with a list as sample rate raises a ValueError.

    Returns:
        None
    """
    with pytest.raises(ValueError):
        AudioProcessor(TEST_WAVEFORM, [44100, 48000])
def test_audio_processor_SAMPLE_RATE():
    """Test the default sample rate of the AudioProcessor class.

    Builds an AudioProcessor from the test waveform without passing a sample rate and checks
    that its ``sr`` attribute equals the constant SAMPLE_RATE.

    Returns:
        None
    """
    assert AudioProcessor(TEST_WAVEFORM).sr == SAMPLE_RATE
+52
View File
@@ -0,0 +1,52 @@
import pytest
from scraibe import Scraibe, Diariser, Transcriber, Transcript
import os
@pytest.fixture
def create_scraibe_instance():
    """Build a Scraibe instance, passing HF_TOKEN from the environment as use_auth_token when it is set."""
    token = os.environ.get("HF_TOKEN")
    if token is None:
        return Scraibe()
    return Scraibe(use_auth_token=token)
def test_scraibe_init(create_scraibe_instance):
    """Check the types of the transcriber and diariser attributes of a Scraibe instance."""
    scraibe = create_scraibe_instance
    expected_types = (("transcriber", Transcriber), ("diariser", Diariser))
    for attribute, expected_type in expected_types:
        assert isinstance(getattr(scraibe, attribute), expected_type)
def test_scraibe_autotranscribe(create_scraibe_instance):
    """Check that autotranscribe returns a Transcript for the test audio file."""
    result = create_scraibe_instance.autotranscribe('test/audio_test_2.mp4')
    assert isinstance(result, Transcript)
def test_scraibe_diarization(create_scraibe_instance):
    """Check that diarization returns a dict for the test audio file."""
    result = create_scraibe_instance.diarization('test/audio_test_2.mp4')
    assert isinstance(result, dict)
def test_scraibe_transcribe(create_scraibe_instance):
    """Check that transcribe returns a string for the test audio file."""
    result = create_scraibe_instance.transcribe('test/audio_test_2.mp4')
    assert isinstance(result, str)
""" def test_remove_audio_file(create_scraibe_instance):
model = create_scraibe_instance
with pytest.raises(ValueError):
model.remove_audio_file("non_existing_audio_file")
model.remove_audio_file("audio_test_2.mp4")
assert not os.path.exists("audio_test_2.mp4") """
""" def test_get_audio_file(create_scraibe_instance):
model = create_scraibe_instance
audio_file = os.path.exist("audio_test_2.mp4")
assert isinstance(audio_file, AudioProcessor)
assert isinstance(audio_file.waveform, torch.Tensor)
assert isinstance(audio_file.sr, torch.Tensor) """
+32
View File
@@ -0,0 +1,32 @@
import pytest
from scraibe import Diariser
@pytest.fixture
def diariser_instance():
    """Fixture for creating an instance of the Diariser class.

    The instance is built directly with ``Diariser('pyannote')``. Patching the ``_get_token`` method
    with ``mock.patch.object`` is currently disabled (see the commented-out line below), so no token
    is mocked. The returned object can be used as a dependency in other tests.

    Returns:
        Diariser(Obj): An instance of the Diariser class created from 'pyannote'.
    """
    # with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
    return Diariser('pyannote')
def test_Diariser_init(diariser_instance):
    """Test the initialization of the Diariser class.

    Checks that the 'model' attribute of the Diariser object provided by the fixture equals 'pyannote'.

    Args:
        diariser_instance (obj): instance of the Diariser class

    Returns:
        None
    """
    assert diariser_instance.model == 'pyannote'
+80
View File
@@ -0,0 +1,80 @@
import pytest
from scraibe import (Transcriber, WhisperTranscriber,
WhisperXTranscriber, load_transcriber)
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = "Hello World"
"""
@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] )
@patch("scraibe.Transcriber.load_model")
def test_transcriber(mock_load_model, audio_file, expected_transcription):
Args:
mock_load_model (_type_): _description_
audio_file (_type_): _description_
expected_transcription (_type_): _description_
mock_model = mock_load_model.return_value
mock_model.transcribe.return_value ={"text": expected_transcription}
transcriber = Transcriber.load_model(model="medium")
transcription_result = transcriber.transcribe(audio=audio_file)
assert transcription_result == expected_transcription """
@pytest.fixture
def whisper_instance():
    """Return the object load_transcriber builds for model 'medium' with whisper_type 'whisper'."""
    return load_transcriber('medium', whisper_type='whisper')
@pytest.fixture
def whisperx_instance():
    """Return the object load_transcriber builds for model 'medium' with whisper_type 'whisperx'."""
    return load_transcriber('medium', whisper_type='whisperx')
def test_whisper_base_initialization(whisper_instance):
    """The 'whisper' fixture object must be an instance of the Transcriber base class."""
    assert isinstance(whisper_instance, Transcriber)
def test_whisperx_base_initialization(whisperx_instance):
    """The 'whisperx' fixture object must be an instance of the Transcriber base class."""
    assert isinstance(whisperx_instance, Transcriber)
def test_whisper_transcriber_initialization(whisper_instance):
    """The 'whisper' fixture object must be a WhisperTranscriber."""
    assert isinstance(whisper_instance, WhisperTranscriber)
def test_whisperx_transcriber_initialization(whisperx_instance):
    """The 'whisperx' fixture object must be a WhisperXTranscriber."""
    assert isinstance(whisperx_instance, WhisperXTranscriber)
def test_wrong_transcriber_initialization():
    """load_transcriber must raise ValueError for the whisper_type 'wrong_whisper'."""
    with pytest.raises(ValueError):
        load_transcriber('medium', whisper_type='wrong_whisper')
def test_get_whisper_kwargs():
    """Transcriber._get_whisper_kwargs must not hand back the two test keyword arguments unchanged."""
    candidate_kwargs = {"arg1": 1, "arg3": 3}
    filtered = Transcriber._get_whisper_kwargs(**candidate_kwargs)
    assert not filtered == candidate_kwargs
def test_whisper_transcribe(whisper_instance):
    """Transcribing the test audio file with the 'whisper' fixture object must yield a string."""
    # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
    result = whisper_instance.transcribe('test/audio_test_2.mp4')
    assert isinstance(result, str)
def test_whisperx_transcribe(whisperx_instance):
    """Transcribing the test audio file with the 'whisperx' fixture object must yield a string."""
    # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
    result = whisperx_instance.transcribe('test/audio_test_2.mp4')
    assert isinstance(result, str)
-120
View File
@@ -1,120 +0,0 @@
import pytest
from scraibe import Transcriber
from unittest.mock import patch, mock_open
import os
def test_load_pyannote_model():
    """
    Load the pyannote pipeline from the local config file and check that it is a SpeakerDiarization
    """
    from pyannote.audio import Pipeline
    from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization

    config_path = "models/pyannote/speaker_diarization/config.yaml"
    loaded = Pipeline.from_pretrained(config_path)
    assert isinstance(loaded, SpeakerDiarization)
# Test Transcriber class
@pytest.fixture
def transcriber():
    """
    Prepare Transcriber for testing by loading the "medium" model with local=True
    Returns: Transcriber Object
    """
    return Transcriber.load_model("medium", local=True)
def test_Transcriber_init(transcriber):
    """
    Test that the object provided by the transcriber fixture is a Transcriber
    """
    assert isinstance(transcriber, Transcriber)
def test_transcription(transcriber):
    """
    Test that transcribing the test wav file yields a string
    """
    assert isinstance(transcriber.transcribe("tests/test.wav"), str)
def test_save_transcript_to_file(transcriber):
    """
    Test that Transcriber.save_transcript creates the output file, then delete that file
    """
    output_path = "tests/output.txt"
    text = transcriber.transcribe("tests/test.wav")
    Transcriber.save_transcript(text, output_path)
    assert os.path.exists(output_path)
    os.remove(output_path)
# Test Diariser class
from scraibe import Diariser
@pytest.fixture
def diarisation():
    """
    Prepare Diariser for testing by loading the local pyannote config with local=True
    Returns: Diariser Object
    """
    return Diariser.load_model("models/pyannote/speaker_diarization/config.yaml", local=True)
def test_Diarisation_init(diarisation):
    """
    Test that the object provided by the diarisation fixture is a Diariser
    """
    assert isinstance(diarisation, Diariser)
def test_diarisation(diarisation):
    """
    Test that diarizing the test wav file yields a dict
    """
    result = diarisation.diarization("tests/test.wav")
    assert isinstance(result, dict)
# Test AudioProcessor
from scraibe import AudioProcessor , TorchAudioProcessor
def test_AudioProcessor_init():
    """
    Test that AudioProcessor can be constructed from the test wav file
    """
    assert isinstance(AudioProcessor("tests/test.wav"), AudioProcessor)
def test_AudioProcessor_convert():
    """
    Test AudioProcessor convert

    Converts the test wav file to mp3 and checks that the output file exists.
    The converted file is removed here, whether or not the check passes, so
    that no other test has to clean up after this one.
    """
    audio = AudioProcessor("tests/test.wav")
    try:
        audio.convert_audio("tests/test.mp3", format="mp3")
        assert os.path.exists("tests/test.mp3")
    finally:
        if os.path.exists("tests/test.mp3"):
            os.remove("tests/test.mp3")


def test_TorchAudioProcessor_from_file():
    """
    Test TorchAudioProcessor initialization from a file

    This test no longer deletes tests/test.mp3; that file belongs to
    test_AudioProcessor_convert, and removing it here made the tests
    depend on their execution order.
    """
    audio = TorchAudioProcessor.from_file("tests/test.wav")
    assert isinstance(audio, TorchAudioProcessor)
def test_TorchAudioProcessor_from_ffmpeg():
    """
    Test that TorchAudioProcessor.from_ffmpeg builds a TorchAudioProcessor from the test wav file
    """
    assert isinstance(TorchAudioProcessor.from_ffmpeg("tests/test.wav"), TorchAudioProcessor)