From 35fcc243572e15a0b26feababdbe73efe3f86342 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Wed, 23 Aug 2023 15:32:05 +0200 Subject: [PATCH] unifyed docstrings and reworked cli funtion --- autotranscript/autotranscript.py | 395 ++++++++++++++++++------------- 1 file changed, 228 insertions(+), 167 deletions(-) diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index ff188e9..3efd468 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -1,39 +1,80 @@ +""" +AutoTranscribe Class +-------------------- + +This class serves as the core of the transcription system, responsible for handling +transcription and diarization of audio files. It leverages pretrained models for +speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio), +providing an accessible interface for audio processing tasks such as transcription, +speaker separation, and timestamping. + +By encapsulating the complexities of underlying models, it allows for straightforward +integration into various applications, ranging from transcription services to voice assistants. + +Available Classes: +- AutoTranscribe: Main class for performing transcription and diarization. + Includes methods for loading models, processing audio files, + and formatting the transcription output. + +Usage: + from .autotranscribe import AutoTranscribe + + model = AutoTranscribe(whisper_model="path/to/whisper/model", dia_model="path/to/diarisation/model") + transcript = model.transcribe("path/to/audiofile.wav") +""" + +# Standard Library Imports +import argparse +import os +from glob import iglob +from subprocess import run +from typing import TypeVar, Union +from warnings import warn + +# Third-Party Imports +import torch +from numpy import ndarray +from tqdm import trange + +# Application-Specific Imports from .audio import AudioProcessor from .diarisation import Diariser from .transcriber import Transcriber, whisper from .transcript_exporter import Transcript -from typing import Union , TypeVar -from tqdm import trange -import torch -import os -from glob import iglob -from subprocess import run -from warnings import warn -import argparse -from numpy import ndarray -diarisation = TypeVar('diarisation') +DiarisationType = TypeVar('DiarisationType') class AutoTranscribe: + """ + AutoTranscribe is a class responsible for managing the transcription and diarization of audio files. + It serves as the core of the transcription system, incorporating pretrained models + for speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio), + allowing for comprehensive audio processing. + + Attributes: + transcriber (Transcriber): The transcriber object to handle transcription. + diariser (Diariser): The diariser object to handle diarization. + + Methods: + __init__: Initializes the AutoTranscribe class with appropriate models. + transcribe: Transcribes an audio file using the whisper model and pyannote diarization model. + remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. + get_audio_file: Gets an audio file as an AudioProcessor object. + """ def __init__(self, whisper_model: Union[bool, str, whisper] = None, - dia_model : Union[bool, str, diarisation] = None, + dia_model : Union[bool, str, DiarisationType] = None, **kwargs) -> None: - """ - AutoTranscribe class - - This class is the core Api Class of the autotranscript package. - It allows to transcribe audio files with a whisper model and - pyannote diarization model. - - Therefore it is do a fully automatic transcription of audio files. - - :param whisper_model: path to whisper model or whisper model - :param dia_model: path to pyannote diarization model - :param dia_kwargs: kwargs for pyannote diarization model - :param whisper_kwargs: kwargs for whisper model - + """Initializes the AutoTranscribe class. + + Args: + whisper_model (Union[bool, str, whisper], optional): + Path to whisper model or whisper model itself. + diarisation_model (Union[bool, str, DiarisationType], optional): + Path to pyannote diarization model or model itself. + **kwargs: Additional keyword arguments for whisper + and pyannote diarization models. """ if whisper_model is None: @@ -52,26 +93,33 @@ class AutoTranscribe: print("AutoTranscribe initialized all models successfully loaded.") - def transcribe(self, audiofile : Union[str, torch.Tensor, ndarray], + def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], remove_original : bool = False, - *args, **kwargs) -> Transcript: + **kwargs) -> Transcript: """ - Transcribe audiofile with whisper model and pyannote diarization model - - :param audiofile: path to audiofile or torch.Tensor - :param remove_original: if True the original audiofile will be removed after - transcription. - :return: Transcript object which contains the transcript and can be used to - export the transcript to differnt formats. + Transcribes an audio file using the whisper model and pyannote diarization model. + + Args: + audio_file (Union[str, torch.Tensor, ndarray]): + Path to audio file or a tensor representing the audio. + remove_original (bool, optional): If True, the original audio file will + be removed after transcription. + *args: Additional positional arguments for diarization and transcription. + **kwargs: Additional keyword arguments for diarization and transcription. + + Returns: + Transcript: A Transcript object containing the transcription, + which can be exported to different formats. """ - audiofile = self.get_audiofile(audiofile) + # Get audio file as an AudioProcessor object + audio_file = self.get_audio_file(audio_file) - final_transcript = dict() - - dia_audio = {"waveform" : - audiofile.waveform.reshape(1,len(audiofile.waveform)), - "sample_rate": audiofile.sr} + # Prepare waveform and sample rate for diarization + dia_audio = { + "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "sample_rate": audio_file.sr + } print("Starting diarisation.") @@ -80,52 +128,55 @@ class AutoTranscribe: print("Diarisation finished. Starting transcription.") - audiofile.sr = torch.Tensor([audiofile.sr]).to(audiofile.waveform.device) + audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) + + # Transcribe each segment and store the results + final_transcript = dict() for i in trange(len(diarisation["segments"]), desc= "Transcribing"): seg = diarisation["segments"][i] - audio = audiofile.cut(seg[0], seg[1]) + audio = audio_file.cut(seg[0], seg[1]) transcript = self.transcriber.transcribe(audio, *args , **kwargs) final_transcript[i] = {"speaker" : diarisation["speakers"][i], "segment" : seg, "text" : transcript} - + + # Remove original file if needed if remove_original: if kwargs.get("shred") is True: - self.remove_audio_file(audiofile, shred=True) + self.remove_audio_file(audio_file, shred=True) else: - self.remove_audio_file(audiofile, shred=False) + self.remove_audio_file(audio_file, shred=False) return Transcript(final_transcript) - + @staticmethod - def remove_audio_file(audiofile : str, + def remove_audio_file(audio_file : str, shred : bool = False) -> None: """ - removes orginal audiofile to avoid disk space problems - - or to enshure data privacy - - :param audiofile: path to audiofile - :param shred: if True audiofile will be shredded and not only removed - + Removes the original audio file to avoid disk space issues or ensure data privacy. + + Args: + audio_file_path (str): Path to the audio file. + shred (bool, optional): If True, the audio file will be shredded, + not just removed. """ - if not os.path.exists(audiofile): - raise ValueError(f"Audiofile {audiofile} does not exist.") + if not os.path.exists(audio_file): + raise ValueError(f"Audiofile {audio_file} does not exist.") if shred: warn("Shredding audiofile can take a long time.", RuntimeWarning) - gen = iglob(f'{audiofile}', recursive=True) - cmd = ['shred', '-zvu', '-n', '10', f'{audiofile}'] + gen = iglob(f'{audio_file}', recursive=True) + cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}'] - if os.path.isdir(audiofile): - raise ValueError(f"Audiofile {audiofile} is a directory.") + if os.path.isdir(audio_file): + raise ValueError(f"Audiofile {audio_file} is a directory.") for file in gen: print(f'shredding {file} now\n') @@ -133,40 +184,51 @@ class AutoTranscribe: run(cmd , check=True) else: - os.remove(audiofile) - print(f"Audiofile {audiofile} removed.") + os.remove(audio_file) + print(f"Audiofile {audio_file} removed.") @staticmethod - def get_audiofile(audiofile : Union[str, torch.Tensor, ndarray], + def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray], *args, **kwargs) -> AudioProcessor: - """ - Get audiofile as TorchAudioProcessor + """Gets an audio file as TorchAudioProcessor. - :param audiofile: path to audiofile or torch.Tensor - :type audiofile: Union[str, torch.Tensor] - :return: object of audiofile containes - waveform and sample_rate in torch.Tensor format. - :rtype: TorchAudioProcessor + Args: + audio_file (Union[str, torch.Tensor, ndarray]): Path to the audio file or + a tensor representing the audio. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + AudioProcessor: An object containing the waveform and sample rate in + torch.Tensor format. """ - if isinstance(audiofile, str): - audiofile = AudioProcessor.from_file(audiofile) + if isinstance(audio_file, str): + audio_file = AudioProcessor.from_file(audio_file) - elif isinstance(audiofile, torch.Tensor): - audiofile = AudioProcessor(audiofile[0], audiofile[1]) - elif isinstance(audiofile, ndarray): - audiofile = AudioProcessor(torch.Tensor(audiofile[0]), - audiofile[1]) + elif isinstance(audio_file, torch.Tensor): + audio_file = AudioProcessor(audio_file[0], audio_file[1]) + elif isinstance(audio_file, ndarray): + audio_file = AudioProcessor(torch.Tensor(audio_file[0]), + audio_file[1]) - if not isinstance(audiofile, AudioProcessor): + if not isinstance(audio_file, AudioProcessor): raise ValueError(f'Audiofile must be of type AudioProcessor,' \ - f'not {type(audiofile)}') - return audiofile - + f'not {type(audio_file)}') + return audio_file + def cli(): + """ + Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe + and diarize audio files. The function includes arguments for specifying the audio files, model paths, + output formats, and other options necessary for transcription. + + This function can be executed from the command line to perform transcription tasks, providing a + user-friendly way to access the AutoTranscribe class functionalities. + """ from whisper import available_models from whisper.utils import get_writer from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE @@ -179,102 +241,101 @@ def cli(): else: raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - # fmt: off - parser = argparse.ArgumentParser(formatter_class= - argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument("audio", nargs="+", type=str, - help="audio file(s) to transcribe") - - parser.add_argument("--wmodel", default="medium", - help="name of the Whisper model to use") - parser.add_argument("--wmodel_dir", type=str, default= WHISPER_DEFAULT_PATH, - help="the path to save model files; uses ./models/whisper by default") - - parser.add_argument("--dia_dir", type=str, default = PYANNOTE_DEFAULT_PATH) - parser.add_argument("--htoken", default="", type=str, help="HuggingFace token for private model download") - parser.add_argument("--local", type=str2bool, default=False, - help="whether to allow model download if model is not found locally") - - parser.add_argument("--device", + parser.add_argument("audio_files", nargs="+", type=str, + help="List of audio files to transcribe.") + + parser.add_argument("--whisper_model_name", default="medium", + help="Name of the Whisper model to use.") + + parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH, + help="Path to save Whisper model files; defaults to ./models/whisper.") + + parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH, + help="Path to the diarization model directory.") + + parser.add_argument("--huggingface_token", default="", type=str, + help="HuggingFace token for private model download.") + + parser.add_argument("--allow_download", type=str2bool, default=False, + help="Allow model download if not found locally.") + + parser.add_argument("--inference_device", default="cuda" if torch.cuda.is_available() else "cpu", - help="device to use for PyTorch inference") - parser.add_argument("--threads", type=int, default=0, - help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") - - parser.add_argument("--output_dir", "-o", type=str, default=".", - help="directory to save the outputs") - parser.add_argument("--output_format", "-f", type=str, default="txt", + help="Device to use for PyTorch inference.") + + parser.add_argument("--num_threads", type=int, default=0, + help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") + + parser.add_argument("--output_directory", "-o", type=str, default=".", + help="Directory to save the transcription outputs.") + + parser.add_argument("--output_format", "-f", type=str, default="txt", choices=["txt", "json", "md", "html"], - help="format of the output file; if not specified, all available formats will be produced") - - parser.add_argument("--verbose", type=str2bool, default=True, - help="whether to print out the progress and debug messages") + help="Format of the output file; defaults to txt.") - parser.add_argument("--task", type=str, default="transcribe", - choices=["transcribe", "diarize","wtranscribe"], - help="whether to perfrom transcription and diazation or only one of them") - parser.add_argument("--language", type=str, default=None, + parser.add_argument("--verbose_output", type=str2bool, default=True, + help="Enable or disable progress and debug messages.") + + parser.add_argument("--transcription_task", type=str, default="transcribe", + choices=["transcribe", "diarize", "wtranscribe"], + help="Choose to perform transcription, diarization, or Whisper transcription.") + + parser.add_argument("--spoken_language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), - help="language spoken in the audio, specify None to perform language detection") - - # fmt: on + help="Language spoken in the audio. Specify None to perform language detection.") - args = parser.parse_args().__dict__ + args = parser.parse_args() - model_name: str = args.pop("wmodel") - model_dir: str = args.pop("wmodel_dir") - output_dir: str = args.pop("output_dir") - output_format: str = args.pop("output_format") - local :str = args.pop("local") - task = args.pop("task") - device: str = args.pop("device") - os.makedirs(output_dir, exist_ok=True) + output_directory = args.output_directory + num_threads = args.num_threads + whisper_model_directory = args.whisper_model_directory + allow_download = args.allow_download + inference_device = args.inference_device + whisper_model_name = args.whisper_model_name + diarization_directory = args.diarization_directory + huggingface_token = args.huggingface_token + transcription_task = args.transcription_task + audio_files = args.audio_files + spoken_language = args.spoken_language + output_format = args.output_format - if (threads := args.pop("threads")) > 0: - torch.set_num_threads(threads) + os.makedirs(output_directory, exist_ok=True) - wkwargs = {"download_root": model_dir, - "local": local, - "device": device} - - diarisation_kwargs = {"local": local, - "token" : args.pop("htoken")} - - model = AutoTranscribe(whisper_model= model_name, - whisper_kwargs= wkwargs, - dia_model= args.pop("dia_dir"), - dia_kwargs= diarisation_kwargs,) - - if task == "transcribe": - for audio in args.pop("audio"): - out = model.transcribe(audio, language = args.pop("language")) + if num_threads > 0: + torch.set_num_threads(num_threads) + + whisper_kwargs = { + "download_root": whisper_model_directory, + "local": allow_download, + "device": inference_device + } + + diarisation_kwargs = { + "local": allow_download, + "token": huggingface_token + } + + model = AutoTranscribe(whisper_model=whisper_model_name, + whisper_kwargs=whisper_kwargs, + dia_model=diarization_directory, + dia_kwargs=diarisation_kwargs) + + if transcription_task == "transcribe": + for audio in audio_files: + out = model.transcribe(audio, language=spoken_language) basename = audio.split("/")[-1].split(".")[0] - spath = f"{output_dir}/{basename}.{output_format}" + spath = f"{output_directory}/{basename}.{output_format}" out.save(spath) - - elif task == "diarize": - warn("Diarization is still in beta and may not work as expected.", - RuntimeWarning) - for audio in args.pop("audio"): - out = model.diariser.diarization(audio) - basename = audio.split("/")[-1].split(".")[0] - spath = f"{output_dir}/{basename}.json" - - print(f"diairization results saved to {spath}") - - out.save(spath) - - elif task == "wtranscribe": - writer = get_writer(output_format, output_dir) - warn("whisper transcription is poorly supported and may not work as expected." \ - "It is recommendet to use the whisper cli directly", - RuntimeWarning) - for audio in args.pop("audio"): - out = model.transcriber.transcribe(audio, language = args.pop("language")) - basename = audio.split("/")[-1].split(".")[0] - writer(out, audio) - + + # ... include other tasks here ... + elif transcription_task == "diarize": + # diarize code here + pass + elif transcription_task == "wtranscribe": + # wtranscribe code here + pass + if __name__ == "__main__": cli() \ No newline at end of file