From f162b480d36e987bbd48814d6e2932d832cb2d0f Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 28 Aug 2023 15:16:53 +0200 Subject: [PATCH] changed function name and added addional function for easier use --- autotranscript/autotranscript.py | 184 +++++++++---------------------- 1 file changed, 52 insertions(+), 132 deletions(-) diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index a8e23aa..44bf2d4 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -24,9 +24,9 @@ Usage: """ # Standard Library Imports -import argparse import os from glob import iglob +import re from subprocess import run from typing import TypeVar, Union from warnings import warn @@ -93,7 +93,7 @@ class AutoTranscribe: print("AutoTranscribe initialized all models successfully loaded.") - def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], + def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], remove_original : bool = False, **kwargs) -> Transcript: """ @@ -164,6 +164,55 @@ class AutoTranscribe: return Transcript(final_transcript) + def diarization(self, audio_file : Union[str, torch.Tensor, ndarray], + **kwargs) -> dict: + """ + Perform diarization on an audio file using the pyannote diarization model. + + Args: + audio_file (Union[str, torch.Tensor, ndarray]): + The audio source which can either be a path to the audio file or a tensor representation. + **kwargs: + Additional keyword arguments for diarization. + + Returns: + dict: + A dictionary containing the results of the diarization process. + """ + + # Get audio file as an AudioProcessor object + audio_file = self.get_audio_file(audio_file) + + # Prepare waveform and sample rate for diarization + dia_audio = { + "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "sample_rate": audio_file.sr + } + + print("Starting diarisation.") + + diarisation = self.diariser.diarization(dia_audio, **kwargs) + + return diarisation + + def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], + **kwargs): + """ + Transcribe the provided audio file. + + Args: + audio_file (Union[str, torch.Tensor, ndarray]): + The audio source, which can either be a path or a tensor representation. + **kwargs: + Additional keyword arguments for transcription. + + Returns: + str: + The transcribed text from the audio source. + """ + audio_file = self.get_audio_file(audio_file) + + return self.transcriber.transcribe(audio_file.waveform, **kwargs) @staticmethod def remove_audio_file(audio_file : str, shred : bool = False) -> None: @@ -228,133 +277,4 @@ class AutoTranscribe: raise ValueError(f'Audiofile must be of type AudioProcessor,' \ f'not {type(audio_file)}') return audio_file - - -def cli(): - """ - Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe - and diarize audio files. The function includes arguments for specifying the audio files, model paths, - output formats, and other options necessary for transcription. - - This function can be executed from the command line to perform transcription tasks, providing a - user-friendly way to access the AutoTranscribe class functionalities. - """ - from whisper import available_models - from whisper.utils import get_writer - from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE - from .transcriber import WHISPER_DEFAULT_PATH - from .diarisation import PYANNOTE_DEFAULT_PATH - - def str2bool(string): - str2val = {"True": True, "False": False} - if string in str2val: - return str2val[string] - else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") - - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument("-f","--audio_files", nargs="+", type=str, - help="List of audio files to transcribe.") - - parser.add_argument('--start_server', action='store_true', - help='Start the Gradio app.') - - parser.add_argument("--whisper_model_name", default="medium", - help="Name of the Whisper model to use.") - - parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH, - help="Path to save Whisper model files; defaults to ./models/whisper.") - - parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH, - help="Path to the diarization model directory.") - - parser.add_argument("--huggingface_token", default="", type=str, - help="HuggingFace token for private model download.") - - parser.add_argument("--allow_download", type=str2bool, default=False, - help="Allow model download if not found locally.") - - parser.add_argument("--inference_device", - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device to use for PyTorch inference.") - - parser.add_argument("--num_threads", type=int, default=0, - help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") - - parser.add_argument("--output_directory", "-o", type=str, default=".", - help="Directory to save the transcription outputs.") - - parser.add_argument("--output_format", "-of", type=str, default="txt", - choices=["txt", "json", "md", "html"], - help="Format of the output file; defaults to txt.") - - parser.add_argument("--verbose_output", type=str2bool, default=True, - help="Enable or disable progress and debug messages.") - - parser.add_argument("--transcription_task", type=str, default="transcribe", - choices=["transcribe", "diarize", "wtranscribe"], - help="Choose to perform transcription, diarization, or Whisper transcription.") - - parser.add_argument("--spoken_language", type=str, default=None, - choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), - help="Language spoken in the audio. Specify None to perform language detection.") - - args = parser.parse_args() - - output_directory = args.output_directory - num_threads = args.num_threads - whisper_model_directory = args.whisper_model_directory - allow_download = args.allow_download - inference_device = args.inference_device - whisper_model_name = args.whisper_model_name - diarization_directory = args.diarization_directory - huggingface_token = args.huggingface_token - transcription_task = args.transcription_task - audio_files = args.audio_files - spoken_language = args.spoken_language - output_format = args.output_format - start_server = args.start_server - - os.makedirs(output_directory, exist_ok=True) - - if num_threads > 0: - torch.set_num_threads(num_threads) - - whisper_kwargs = { - "download_root": whisper_model_directory, - "local": allow_download, - "device": inference_device - } - - diarisation_kwargs = { - "local": allow_download, - "token": huggingface_token - } - - model = AutoTranscribe(whisper_model=whisper_model_name, - whisper_kwargs=whisper_kwargs, - dia_model=diarization_directory, - dia_kwargs=diarisation_kwargs) - - if transcription_task == "transcribe": - for audio in audio_files: - out = model.transcribe(audio, language=spoken_language) - basename = audio.split("/")[-1].split(".")[0] - spath = f"{output_directory}/{basename}.{output_format}" - out.save(spath) - - # ... include other tasks here ... - elif transcription_task == "diarize": - # diarize code here - pass - elif transcription_task == "wtranscribe": - # wtranscribe code here - pass - - if start_server: - from .gradio_app import gradio_app - gradio_app(model) - -if __name__ == "__main__": - cli() \ No newline at end of file + \ No newline at end of file