From e499c987038d3f137fd8f1ed3fecb2444b79a45e Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 25 Aug 2023 14:30:06 +0200 Subject: [PATCH 01/10] configured setup.py to handle pytorch --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e7da608..bf6a912 100644 --- a/setup.py +++ b/setup.py @@ -32,10 +32,13 @@ if __name__ == "__main__": open(os.path.join(os.path.dirname(__file__), "requirements.txt")) ) ], + dependency_links=[ + 'https://download.pytorch.org/whl/cu113', + ], url= github_url, license='', author='Jacob Schmieder', - author_email='', + author_email='Jacob.Schmieder@dbfz.de', description='Transcription tool for audio files based on Whisper and Pyannote', entry_points={'console_scripts': ['autotranscript = autotranscript.autotranscript:cli']} From 685fdfcfac6d639caf3aeb55687851cd8ab74f32 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 25 Aug 2023 14:32:14 +0200 Subject: [PATCH 02/10] resolved same argument name --- autotranscript/autotranscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index e053d6a..a8e23aa 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -285,7 +285,7 @@ def cli(): parser.add_argument("--output_directory", "-o", type=str, default=".", help="Directory to save the transcription outputs.") - parser.add_argument("--output_format", "-f", type=str, default="txt", + parser.add_argument("--output_format", "-of", type=str, default="txt", choices=["txt", "json", "md", "html"], help="Format of the output file; defaults to txt.") From a3b65d22aac2672d244b3c5307ba3a47bddd36d5 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 25 Aug 2023 14:32:34 +0200 Subject: [PATCH 03/10] added gradio and pytorch --- requirements.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/requirements.txt b/requirements.txt index b81b23c..6375a2c 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,14 @@ setuptools-rust~=1.5.2 tqdm>=4.65.0 +gradio~=3.36.1 +gradio-client~=0.2.7 + +# add pytorch to override the one installed by pyannote.audio + +torch~=1.11.0 +torchvision~=0.12.0 +torchaudio~=0.11.0 #optional: #dash~=2.10.2 From b2f332a4d25e70a76c33572973c86d0bd6027bf2 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 25 Aug 2023 14:33:35 +0200 Subject: [PATCH 04/10] updated environment --- environment.yml | 69 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/environment.yml b/environment.yml index aeb907b..7913480 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,3 @@ -name: whisper channels: - pytorch - defaults @@ -11,7 +10,7 @@ dependencies: - ca-certificates=2023.05.30=h06a4308_0 - certifi=2023.5.7=py39h06a4308_0 - cffi=1.15.1=py39h5eee18b_3 - - cryptography=39.0.1=py39h9ce1e76_0 + - cryptography=39.0.1=py39h9ce1e76_2 - cudatoolkit=11.3.1=h2bc3f7f_2 - ffmpeg=4.2.2=h20bf706_0 - flit-core=3.8.0=py39h06a4308_0 @@ -51,36 +50,40 @@ dependencies: - numpy=1.23.5=py39h14f4228_0 - numpy-base=1.23.5=py39h31eccc5_0 - openh264=2.1.1=h4ff587b_0 - - openssl=1.1.1t=h7f8727e_0 + - openssl=3.0.9=h7f8727e_0 - pillow=9.4.0=py39h6a678d5_0 - pip=23.0.1=py39h06a4308_0 - pycparser=2.21=pyhd3eb1b0_0 - pyopenssl=23.0.0=py39h06a4308_0 - pysocks=1.7.1=py39h06a4308_0 - - python=3.9.16=h7a1cb2a_2 + - python=3.9.16=h955ad1f_3 - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 - pytorch-mutex=1.0=cuda - readline=8.2=h5eee18b_0 - requests=2.28.1=py39h06a4308_1 - setuptools=65.6.3=py39h06a4308_0 - six=1.16.0=pyhd3eb1b0_1 - - sqlite=3.41.1=h5eee18b_0 + - sqlite=3.41.2=h5eee18b_0 - tk=8.6.12=h1ccaba5_0 - torchaudio=0.11.0=py39_cu113 - torchvision=0.12.0=py39_cu113 - - typing_extensions=4.4.0=py39h06a4308_0 - tzdata=2023c=h04d1e81_0 - wheel=0.38.4=py39h06a4308_0 - x264=1!157.20191217=h7b6447c_0 - - xz=5.2.10=h5eee18b_1 + - xz=5.4.2=h5eee18b_0 - zlib=1.2.13=h5eee18b_0 
- zstd=1.5.4=hc292b87_0 - pip: - absl-py==1.3.0 + - aiofiles==23.1.0 - aiohttp==3.8.3 - aiosignal==1.3.1 - alembic==1.9.1 + - altair==5.0.1 + - annotated-types==0.5.0 + - ansi2html==1.8.0 - antlr4-python3-runtime==4.9.3 + - anyio==3.7.1 - appdirs==1.4.4 - asteroid-filterbanks==0.4.0 - async-timeout==4.0.2 @@ -100,48 +103,76 @@ dependencies: - commonmark==0.9.1 - contourpy==1.0.6 - cycler==0.11.0 + - dash==2.12.1 + - dash-core-components==2.0.0 + - dash-html-components==2.0.0 + - dash-table==5.0.0 - decorator==4.4.2 - docopt==0.6.2 - einops==0.3.2 + - exceptiongroup==1.1.1 + - fastapi==0.100.0 - ffmpeg-python==0.2.0 + - ffmpy==0.3.0 - filelock==3.8.0 + - flask==2.2.5 - fonttools==4.38.0 - frozenlist==1.3.3 - fsspec==2022.11.0 - future==0.18.2 - google-auth==2.15.0 - google-auth-oauthlib==0.4.6 + - gradio==3.36.1 + - gradio-client==0.2.7 - greenlet==2.0.1 - grpcio==1.51.1 + - h11==0.14.0 - hmmlearn==0.2.8 - - huggingface-hub==0.11.0 + - httpcore==0.17.3 + - httpx==0.24.1 + - huggingface-hub==0.16.4 + - humanize==4.7.0 - hyperpyyaml==1.1.0 - imageio==2.23.0 - imageio-ffmpeg==0.4.7 - importlib-metadata==4.13.0 + - importlib-resources==5.12.0 + - iniconfig==2.0.0 + - itsdangerous==2.1.2 + - jinja2==3.1.2 - joblib==1.2.0 + - jsonschema==4.18.0 + - jsonschema-specifications==2023.6.1 - julius==0.2.7 - kiwisolver==1.4.4 - librosa==0.9.2 + - linkify-it-py==2.0.2 - lit==16.0.5.post0 - llvmlite==0.39.1 - mako==1.2.4 - markdown==3.4.1 + - markdown-it-py==2.2.0 - markupsafe==2.1.1 - - matplotlib==3.6.2 + - matplotlib==3.7.1 + - mdit-py-plugins==0.3.3 + - mdurl==0.1.2 - more-itertools==9.0.0 - moviepy==1.0.3 - mpmath==1.2.1 - multidict==6.0.4 + - nest-asyncio==1.5.7 - networkx==2.8.8 - numba==0.56.4 - oauthlib==3.2.2 - omegaconf==2.3.0 - openai-whisper==20230314 - optuna==3.0.5 + - orjson==3.9.2 - packaging==21.3 - pandas==1.5.2 - pbr==5.11.0 + - plotly==5.15.0 + - pluggy==1.0.0 - pooch==1.6.0 - prettytable==3.5.0 - primepy==1.3 @@ -154,23 +185,32 @@ dependencies: - 
pyannote-pipeline==2.3 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 + - pydantic==2.0.2 + - pydantic-core==2.1.2 - pydeprecate==0.3.2 - pydub==0.25.1 - pygments==2.13.0 - pyparsing==3.0.9 - pyperclip==1.8.2 + - pytest==7.3.1 - python-dateutil==2.8.2 + - python-multipart==0.0.6 - pytorch-lightning==1.6.5 - pytorch-metric-learning==1.6.3 - pytz==2022.7 - pyyaml==6.0 + - qtfaststart==1.8 + - referencing==0.29.1 - regex==2022.10.31 - requests-oauthlib==1.3.1 - resampy==0.4.2 + - retrying==1.3.4 - rich==12.6.0 + - rpds-py==0.8.10 - rsa==4.9 - ruamel-yaml==0.17.21 - ruamel-yaml-clib==0.2.7 + - ruff==0.0.272 - scikit-learn==1.2.0 - scipy==1.8.1 - semantic-version==2.10.0 @@ -180,19 +220,24 @@ dependencies: - shellingham==1.5.0 - simplejson==3.18.0 - singledispatchmethod==1.0 + - sniffio==1.3.0 - sortedcontainers==2.4.0 - soundfile==0.10.3.post1 - - speechbrain==0.5.13 + - speechbrain==0.5.14 - sqlalchemy==1.4.45 + - starlette==0.27.0 - stevedore==4.1.1 - sympy==1.11.1 - tabulate==0.9.0 + - tenacity==8.2.2 - tensorboard==2.11.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.1 - threadpoolctl==3.1.0 - tiktoken==0.3.1 - tokenizers==0.13.2 + - tomli==2.0.1 + - toolz==0.12.0 - torch-audiomentations==0.11.0 - torch-pitch-shift==1.2.2 - torchmetrics==0.11.0 @@ -200,8 +245,12 @@ dependencies: - transformers==4.24.0 - triton==2.0.0 - typer==0.7.0 + - typing-extensions==4.7.1 + - uc-micro-py==1.0.2 - urllib3==1.26.12 + - uvicorn==0.22.0 - wcwidth==0.2.5 + - websockets==11.0.3 - werkzeug==2.2.2 - yarl==1.8.2 - zipp==3.11.0 From f162b480d36e987bbd48814d6e2932d832cb2d0f Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 28 Aug 2023 15:16:53 +0200 Subject: [PATCH 05/10] changed function name and added additional function for easier use --- autotranscript/autotranscript.py | 184 +++++++++---------------------- 1 file changed, 52 insertions(+), 132 deletions(-) diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index a8e23aa..44bf2d4 100644 --- 
a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -24,9 +24,9 @@ Usage: """ # Standard Library Imports -import argparse import os from glob import iglob +import re from subprocess import run from typing import TypeVar, Union from warnings import warn @@ -93,7 +93,7 @@ class AutoTranscribe: print("AutoTranscribe initialized all models successfully loaded.") - def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], + def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], remove_original : bool = False, **kwargs) -> Transcript: """ @@ -164,6 +164,55 @@ class AutoTranscribe: return Transcript(final_transcript) + def diarization(self, audio_file : Union[str, torch.Tensor, ndarray], + **kwargs) -> dict: + """ + Perform diarization on an audio file using the pyannote diarization model. + + Args: + audio_file (Union[str, torch.Tensor, ndarray]): + The audio source which can either be a path to the audio file or a tensor representation. + **kwargs: + Additional keyword arguments for diarization. + + Returns: + dict: + A dictionary containing the results of the diarization process. + """ + + # Get audio file as an AudioProcessor object + audio_file = self.get_audio_file(audio_file) + + # Prepare waveform and sample rate for diarization + dia_audio = { + "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "sample_rate": audio_file.sr + } + + print("Starting diarisation.") + + diarisation = self.diariser.diarization(dia_audio, **kwargs) + + return diarisation + + def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], + **kwargs): + """ + Transcribe the provided audio file. + + Args: + audio_file (Union[str, torch.Tensor, ndarray]): + The audio source, which can either be a path or a tensor representation. + **kwargs: + Additional keyword arguments for transcription. + + Returns: + str: + The transcribed text from the audio source. 
+ """ + audio_file = self.get_audio_file(audio_file) + + return self.transcriber.transcribe(audio_file.waveform, **kwargs) @staticmethod def remove_audio_file(audio_file : str, shred : bool = False) -> None: @@ -228,133 +277,4 @@ class AutoTranscribe: raise ValueError(f'Audiofile must be of type AudioProcessor,' \ f'not {type(audio_file)}') return audio_file - - -def cli(): - """ - Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe - and diarize audio files. The function includes arguments for specifying the audio files, model paths, - output formats, and other options necessary for transcription. - - This function can be executed from the command line to perform transcription tasks, providing a - user-friendly way to access the AutoTranscribe class functionalities. - """ - from whisper import available_models - from whisper.utils import get_writer - from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE - from .transcriber import WHISPER_DEFAULT_PATH - from .diarisation import PYANNOTE_DEFAULT_PATH - - def str2bool(string): - str2val = {"True": True, "False": False} - if string in str2val: - return str2val[string] - else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") - - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument("-f","--audio_files", nargs="+", type=str, - help="List of audio files to transcribe.") - - parser.add_argument('--start_server', action='store_true', - help='Start the Gradio app.') - - parser.add_argument("--whisper_model_name", default="medium", - help="Name of the Whisper model to use.") - - parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH, - help="Path to save Whisper model files; defaults to ./models/whisper.") - - parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH, - help="Path to the diarization model directory.") - - 
parser.add_argument("--huggingface_token", default="", type=str, - help="HuggingFace token for private model download.") - - parser.add_argument("--allow_download", type=str2bool, default=False, - help="Allow model download if not found locally.") - - parser.add_argument("--inference_device", - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device to use for PyTorch inference.") - - parser.add_argument("--num_threads", type=int, default=0, - help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") - - parser.add_argument("--output_directory", "-o", type=str, default=".", - help="Directory to save the transcription outputs.") - - parser.add_argument("--output_format", "-of", type=str, default="txt", - choices=["txt", "json", "md", "html"], - help="Format of the output file; defaults to txt.") - - parser.add_argument("--verbose_output", type=str2bool, default=True, - help="Enable or disable progress and debug messages.") - - parser.add_argument("--transcription_task", type=str, default="transcribe", - choices=["transcribe", "diarize", "wtranscribe"], - help="Choose to perform transcription, diarization, or Whisper transcription.") - - parser.add_argument("--spoken_language", type=str, default=None, - choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), - help="Language spoken in the audio. 
Specify None to perform language detection.") - - args = parser.parse_args() - - output_directory = args.output_directory - num_threads = args.num_threads - whisper_model_directory = args.whisper_model_directory - allow_download = args.allow_download - inference_device = args.inference_device - whisper_model_name = args.whisper_model_name - diarization_directory = args.diarization_directory - huggingface_token = args.huggingface_token - transcription_task = args.transcription_task - audio_files = args.audio_files - spoken_language = args.spoken_language - output_format = args.output_format - start_server = args.start_server - - os.makedirs(output_directory, exist_ok=True) - - if num_threads > 0: - torch.set_num_threads(num_threads) - - whisper_kwargs = { - "download_root": whisper_model_directory, - "local": allow_download, - "device": inference_device - } - - diarisation_kwargs = { - "local": allow_download, - "token": huggingface_token - } - - model = AutoTranscribe(whisper_model=whisper_model_name, - whisper_kwargs=whisper_kwargs, - dia_model=diarization_directory, - dia_kwargs=diarisation_kwargs) - - if transcription_task == "transcribe": - for audio in audio_files: - out = model.transcribe(audio, language=spoken_language) - basename = audio.split("/")[-1].split(".")[0] - spath = f"{output_directory}/{basename}.{output_format}" - out.save(spath) - - # ... include other tasks here ... 
- elif transcription_task == "diarize": - # diarize code here - pass - elif transcription_task == "wtranscribe": - # wtranscribe code here - pass - - if start_server: - from .gradio_app import gradio_app - gradio_app(model) - -if __name__ == "__main__": - cli() \ No newline at end of file + \ No newline at end of file From 76310a8d1c48b3ed348ac258e245861de742c09a Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 28 Aug 2023 15:17:14 +0200 Subject: [PATCH 06/10] moved cli into extra file --- autotranscript/cli.py | 143 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 autotranscript/cli.py diff --git a/autotranscript/cli.py b/autotranscript/cli.py new file mode 100644 index 0000000..1507f3a --- /dev/null +++ b/autotranscript/cli.py @@ -0,0 +1,143 @@ +""" +Command-Line Interface (CLI) for the AutoTranscribe class, +allowing for user interaction to transcribe and diarize audio files. +The function includes arguments for specifying the audio files, model paths, +output formats, and other options necessary for transcription. +""" +import os +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from turtle import st + +from .transcriber import WHISPER_DEFAULT_PATH +from .diarisation import PYANNOTE_DEFAULT_PATH +from .autotranscript import AutoTranscribe + +from whisper import available_models +from whisper.utils import get_writer +from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE +from torch.cuda import is_available +from torch import set_num_threads + + +def cli(): + """ + Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe + and diarize audio files. The function includes arguments for specifying the audio files, model paths, + output formats, and other options necessary for transcription. 
+ + This function can be executed from the command line to perform transcription tasks, providing a + user-friendly way to access the AutoTranscribe class functionalities. + """ + + def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + + parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter) + + group = parser.add_mutually_exclusive_group() + + parser.add_argument("-f","--audio_files", nargs="+", type=str, default=None, + help="List of audio files to transcribe.") + + group.add_argument('--start_server', action='store_true', + help='Start the Gradio app.') + + parser.add_argument("--port", type=int, default= None, + help="Port to run the Gradio app on.") + + parser.add_argument("--server_name", type=str, default= "autotranscript", + help="Name of the Gradio app.") + + parser.add_argument("--whisper_model_name", default="medium", + help="Name of the Whisper model to use.") + + parser.add_argument("--whisper_model_directory", type=str, default= None, + help="Path to save Whisper model files; defaults to ./models/whisper.") + + parser.add_argument("--diarization_directory", type=str, default= None, + help="Path to the diarization model directory.") + + parser.add_argument("--huggingface_token", default= None, type=str, + help="HuggingFace token for private model download.") + + parser.add_argument("--allow_download", type=str2bool, default=True, + help="Allow model download if not found locally.") + + parser.add_argument("--inference_device", + default="cuda" if is_available() else "cpu", + help="Device to use for PyTorch inference.") + + parser.add_argument("--num_threads", type=int, default=0, + help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") + + parser.add_argument("--output_directory", "-o", type=str, default=".", + help="Directory to save the 
transcription outputs.") + + parser.add_argument("--output_format", "-of", type=str, default="txt", + choices=["txt", "json", "md", "html"], + help="Format of the output file; defaults to txt.") + + parser.add_argument("--verbose_output", type=str2bool, default=True, + help="Enable or disable progress and debug messages.") + + parser.add_argument("--task", type=str, default= None, # unifinished code + choices=["autoranscribe", "diarize", "autotranscribe+translate", "translate"], + help="Choose to perform transcription, diarization, or translation. \ + If set to translate, the language argument must be specified.") + + parser.add_argument("--language", type=str, default=None, + choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), + help="Language spoken in the audio. Specify None to perform language detection.") + + args = parser.parse_args() + + arg_dict = vars(args) + + # configure output + + os.makedirs(arg_dict.pop("output_directory"), exist_ok=True) + + out_format = arg_dict.pop("output_format") + + # seup server arg: + start_server = arg_dict.pop("start_server") + + + if args.num_threads > 0: + set_num_threads(arg_dict.pop("num_threads")) + + class_kwargs = dict() + + for k, v in arg_dict.items(): + if v is not None: + class_kwargs[k] = v + + + + model = AutoTranscribe(**class_kwargs) + + # if transcription_task == "transcribe": + # for audio in audio_files: + # out = model.transcribe(audio, language=spoken_language) + # basename = audio.split("/")[-1].split(".")[0] + # spath = f"{output_directory}/{basename}.{output_format}" + # out.save(spath) + + # # ... include other tasks here ... 
+ # elif transcription_task == "diarize": + # # diarize code here + # pass + # elif transcription_task == "wtranscribe": + # # wtranscribe code here + # pass + + # if start_server: # unfinished code + # from .gradio_app import gradio_app + # gradio_app(model) + +if __name__ == "__main__": + cli() \ No newline at end of file From bf8ee9accaeeca823d39a9eef11c461df9804704 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 28 Aug 2023 15:17:27 +0200 Subject: [PATCH 07/10] added cli to init --- autotranscript/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autotranscript/__init__.py b/autotranscript/__init__.py index 20bcc93..55b5bc3 100644 --- a/autotranscript/__init__.py +++ b/autotranscript/__init__.py @@ -6,5 +6,6 @@ from .transcript_exporter import * from .diarisation import * from .version import get_version as _get_version from .misc import * +from .cli import * __version__ = _get_version() From 5be187998e5bd9fd7e0c2ef944d25c9b1aa83a3d Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 28 Aug 2023 15:18:57 +0200 Subject: [PATCH 08/10] added functionality to translate --- autotranscript/transcriber.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index 81787da..e319372 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -166,6 +166,9 @@ class Transcriber: _possible_kwargs = Whisper.transcribe.__code__.co_varnames whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} + + if (task := kwargs.get("task")): + whisper_kwargs["task"] = task return whisper_kwargs From 5937e81e3139db1332e55b3d2ecb4e96e8451235 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 28 Aug 2023 17:01:53 +0200 Subject: [PATCH 09/10] updated cli --- autotranscript/cli.py | 72 ++++++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/autotranscript/cli.py b/autotranscript/cli.py index 1507f3a..e4c8e45 100644 --- 
a/autotranscript/cli.py +++ b/autotranscript/cli.py @@ -6,7 +6,7 @@ output formats, and other options necessary for transcription. """ import os from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter -from turtle import st +import json from .transcriber import WHISPER_DEFAULT_PATH from .diarisation import PYANNOTE_DEFAULT_PATH @@ -85,7 +85,7 @@ def cli(): help="Enable or disable progress and debug messages.") parser.add_argument("--task", type=str, default= None, # unifinished code - choices=["autoranscribe", "diarize", "autotranscribe+translate", "translate"], + choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"], help="Choose to perform transcription, diarization, or translation. \ If set to translate, the language argument must be specified.") @@ -98,14 +98,15 @@ def cli(): arg_dict = vars(args) # configure output - - os.makedirs(arg_dict.pop("output_directory"), exist_ok=True) + out_folder = arg_dict.pop("output_directory") + os.makedirs(out_folder, exist_ok=True) out_format = arg_dict.pop("output_format") # seup server arg: start_server = arg_dict.pop("start_server") + task = arg_dict.pop("task") if args.num_threads > 0: set_num_threads(arg_dict.pop("num_threads")) @@ -115,29 +116,56 @@ def cli(): for k, v in arg_dict.items(): if v is not None: class_kwargs[k] = v - model = AutoTranscribe(**class_kwargs) - # if transcription_task == "transcribe": - # for audio in audio_files: - # out = model.transcribe(audio, language=spoken_language) - # basename = audio.split("/")[-1].split(".")[0] - # spath = f"{output_directory}/{basename}.{output_format}" - # out.save(spath) - - # # ... include other tasks here ... 
- # elif transcription_task == "diarize": - # # diarize code here - # pass - # elif transcription_task == "wtranscribe": - # # wtranscribe code here - # pass + if arg_dict["audio_files"]: + audio_files = args.pop("audio_files") + + if task == "autotranscribe" or task == "autotranscribe+translate": + for audio in audio_files: + if task == "autotranscribe+translate": + task = "translate" + else: + task = "transcribe" + + out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + out.save(os.path.join(out_folder, f"{basename}.{out_format}")) + + elif task == "diarization": + for audio in audio_files: + if arg_dict.pop("verbose_output"): + print(f"Verbose not implemented for diarization.") + + out = model.diarization(audio) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + if out_format == "txt": + with open(path, "w") as f: + f.write(out) + elif out_format == "json": + with open(path, "w") as f: + json.dump(json.dumps(out, indent= 3), f) + else: + raise ValueError(f"Unsupported output format for diarization{out_format}.") + elif task == "transcribe" or task == "translate": + + for audio in audio_files: + + out = model.transcribe(audio, task = task, + language=arg_dict.pop("language"), + verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + with open(path, "w") as f: + f.write(out) + - # if start_server: # unfinished code - # from .gradio_app import gradio_app - # gradio_app(model) + if start_server: # unfinished code + from .gradio_app import gradio_app + gradio_app(model) if __name__ == "__main__": cli() \ No newline at end of file From 064a169b52d0e8b89c65ecb8c413eb7ea70938ac Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Wed, 30 Aug 2023 10:28:33 +0200 Subject: [PATCH 10/10] changed python 
version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bf6a912..f5a4351 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ if __name__ == "__main__": name=module_name, version=version["get_version"](build_version), packages=find_packages(), - python_requires="~=3.9", + python_requires=">=3.8", readme="README.md", install_requires = [str(r) for r in pkg_resources.parse_requirements( open(os.path.join(os.path.dirname(__file__), "requirements.txt"))