changed function name and added addional function for easier use

This commit is contained in:
Jaikinator
2023-08-28 15:16:53 +02:00
parent b2f332a4d2
commit f162b480d3
+52 -132
View File
@@ -24,9 +24,9 @@ Usage:
""" """
# Standard Library Imports # Standard Library Imports
import argparse
import os import os
from glob import iglob from glob import iglob
import re
from subprocess import run from subprocess import run
from typing import TypeVar, Union from typing import TypeVar, Union
from warnings import warn from warnings import warn
@@ -93,7 +93,7 @@ class AutoTranscribe:
print("AutoTranscribe initialized all models successfully loaded.") print("AutoTranscribe initialized all models successfully loaded.")
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
remove_original : bool = False, remove_original : bool = False,
**kwargs) -> Transcript: **kwargs) -> Transcript:
""" """
@@ -164,6 +164,55 @@ class AutoTranscribe:
return Transcript(final_transcript) return Transcript(final_transcript)
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
**kwargs) -> dict:
"""
Perform diarization on an audio file using the pyannote diarization model.
Args:
audio_file (Union[str, torch.Tensor, ndarray]):
The audio source which can either be a path to the audio file or a tensor representation.
**kwargs:
Additional keyword arguments for diarization.
Returns:
dict:
A dictionary containing the results of the diarization process.
"""
# Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
"sample_rate": audio_file.sr
}
print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs)
return diarisation
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
**kwargs):
"""
Transcribe the provided audio file.
Args:
audio_file (Union[str, torch.Tensor, ndarray]):
The audio source, which can either be a path or a tensor representation.
**kwargs:
Additional keyword arguments for transcription.
Returns:
str:
The transcribed text from the audio source.
"""
audio_file = self.get_audio_file(audio_file)
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
@staticmethod @staticmethod
def remove_audio_file(audio_file : str, def remove_audio_file(audio_file : str,
shred : bool = False) -> None: shred : bool = False) -> None:
@@ -228,133 +277,4 @@ class AutoTranscribe:
raise ValueError(f'Audiofile must be of type AudioProcessor,' \ raise ValueError(f'Audiofile must be of type AudioProcessor,' \
f'not {type(audio_file)}') f'not {type(audio_file)}')
return audio_file return audio_file
def cli():
"""
Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe
and diarize audio files. The function includes arguments for specifying the audio files, model paths,
output formats, and other options necessary for transcription.
This function can be executed from the command line to perform transcription tasks, providing a
user-friendly way to access the AutoTranscribe class functionalities.
"""
from whisper import available_models
from whisper.utils import get_writer
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
from .transcriber import WHISPER_DEFAULT_PATH
from .diarisation import PYANNOTE_DEFAULT_PATH
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-f","--audio_files", nargs="+", type=str,
help="List of audio files to transcribe.")
parser.add_argument('--start_server', action='store_true',
help='Start the Gradio app.')
parser.add_argument("--whisper_model_name", default="medium",
help="Name of the Whisper model to use.")
parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH,
help="Path to save Whisper model files; defaults to ./models/whisper.")
parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH,
help="Path to the diarization model directory.")
parser.add_argument("--huggingface_token", default="", type=str,
help="HuggingFace token for private model download.")
parser.add_argument("--allow_download", type=str2bool, default=False,
help="Allow model download if not found locally.")
parser.add_argument("--inference_device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to use for PyTorch inference.")
parser.add_argument("--num_threads", type=int, default=0,
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
parser.add_argument("--output_directory", "-o", type=str, default=".",
help="Directory to save the transcription outputs.")
parser.add_argument("--output_format", "-of", type=str, default="txt",
choices=["txt", "json", "md", "html"],
help="Format of the output file; defaults to txt.")
parser.add_argument("--verbose_output", type=str2bool, default=True,
help="Enable or disable progress and debug messages.")
parser.add_argument("--transcription_task", type=str, default="transcribe",
choices=["transcribe", "diarize", "wtranscribe"],
help="Choose to perform transcription, diarization, or Whisper transcription.")
parser.add_argument("--spoken_language", type=str, default=None,
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.")
args = parser.parse_args()
output_directory = args.output_directory
num_threads = args.num_threads
whisper_model_directory = args.whisper_model_directory
allow_download = args.allow_download
inference_device = args.inference_device
whisper_model_name = args.whisper_model_name
diarization_directory = args.diarization_directory
huggingface_token = args.huggingface_token
transcription_task = args.transcription_task
audio_files = args.audio_files
spoken_language = args.spoken_language
output_format = args.output_format
start_server = args.start_server
os.makedirs(output_directory, exist_ok=True)
if num_threads > 0:
torch.set_num_threads(num_threads)
whisper_kwargs = {
"download_root": whisper_model_directory,
"local": allow_download,
"device": inference_device
}
diarisation_kwargs = {
"local": allow_download,
"token": huggingface_token
}
model = AutoTranscribe(whisper_model=whisper_model_name,
whisper_kwargs=whisper_kwargs,
dia_model=diarization_directory,
dia_kwargs=diarisation_kwargs)
if transcription_task == "transcribe":
for audio in audio_files:
out = model.transcribe(audio, language=spoken_language)
basename = audio.split("/")[-1].split(".")[0]
spath = f"{output_directory}/{basename}.{output_format}"
out.save(spath)
# ... include other tasks here ...
elif transcription_task == "diarize":
# diarize code here
pass
elif transcription_task == "wtranscribe":
# wtranscribe code here
pass
if start_server:
from .gradio_app import gradio_app
gradio_app(model)
if __name__ == "__main__":
cli()