changed function name and added addional function for easier use
This commit is contained in:
@@ -24,9 +24,9 @@ Usage:
|
||||
"""
|
||||
|
||||
# Standard Library Imports
|
||||
import argparse
|
||||
import os
|
||||
from glob import iglob
|
||||
import re
|
||||
from subprocess import run
|
||||
from typing import TypeVar, Union
|
||||
from warnings import warn
|
||||
@@ -93,7 +93,7 @@ class AutoTranscribe:
|
||||
|
||||
print("AutoTranscribe initialized all models successfully loaded.")
|
||||
|
||||
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
remove_original : bool = False,
|
||||
**kwargs) -> Transcript:
|
||||
"""
|
||||
@@ -164,6 +164,55 @@ class AutoTranscribe:
|
||||
|
||||
return Transcript(final_transcript)
|
||||
|
||||
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
**kwargs) -> dict:
|
||||
"""
|
||||
Perform diarization on an audio file using the pyannote diarization model.
|
||||
|
||||
Args:
|
||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
||||
The audio source which can either be a path to the audio file or a tensor representation.
|
||||
**kwargs:
|
||||
Additional keyword arguments for diarization.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
A dictionary containing the results of the diarization process.
|
||||
"""
|
||||
|
||||
# Get audio file as an AudioProcessor object
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
|
||||
"sample_rate": audio_file.sr
|
||||
}
|
||||
|
||||
print("Starting diarisation.")
|
||||
|
||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||
|
||||
return diarisation
|
||||
|
||||
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
**kwargs):
|
||||
"""
|
||||
Transcribe the provided audio file.
|
||||
|
||||
Args:
|
||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
||||
The audio source, which can either be a path or a tensor representation.
|
||||
**kwargs:
|
||||
Additional keyword arguments for transcription.
|
||||
|
||||
Returns:
|
||||
str:
|
||||
The transcribed text from the audio source.
|
||||
"""
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
|
||||
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||
@staticmethod
|
||||
def remove_audio_file(audio_file : str,
|
||||
shred : bool = False) -> None:
|
||||
@@ -229,132 +278,3 @@ class AutoTranscribe:
|
||||
f'not {type(audio_file)}')
|
||||
return audio_file
|
||||
|
||||
|
||||
def cli():
|
||||
"""
|
||||
Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe
|
||||
and diarize audio files. The function includes arguments for specifying the audio files, model paths,
|
||||
output formats, and other options necessary for transcription.
|
||||
|
||||
This function can be executed from the command line to perform transcription tasks, providing a
|
||||
user-friendly way to access the AutoTranscribe class functionalities.
|
||||
"""
|
||||
from whisper import available_models
|
||||
from whisper.utils import get_writer
|
||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||
from .transcriber import WHISPER_DEFAULT_PATH
|
||||
from .diarisation import PYANNOTE_DEFAULT_PATH
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
else:
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument("-f","--audio_files", nargs="+", type=str,
|
||||
help="List of audio files to transcribe.")
|
||||
|
||||
parser.add_argument('--start_server', action='store_true',
|
||||
help='Start the Gradio app.')
|
||||
|
||||
parser.add_argument("--whisper_model_name", default="medium",
|
||||
help="Name of the Whisper model to use.")
|
||||
|
||||
parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH,
|
||||
help="Path to save Whisper model files; defaults to ./models/whisper.")
|
||||
|
||||
parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH,
|
||||
help="Path to the diarization model directory.")
|
||||
|
||||
parser.add_argument("--huggingface_token", default="", type=str,
|
||||
help="HuggingFace token for private model download.")
|
||||
|
||||
parser.add_argument("--allow_download", type=str2bool, default=False,
|
||||
help="Allow model download if not found locally.")
|
||||
|
||||
parser.add_argument("--inference_device",
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Device to use for PyTorch inference.")
|
||||
|
||||
parser.add_argument("--num_threads", type=int, default=0,
|
||||
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
|
||||
|
||||
parser.add_argument("--output_directory", "-o", type=str, default=".",
|
||||
help="Directory to save the transcription outputs.")
|
||||
|
||||
parser.add_argument("--output_format", "-of", type=str, default="txt",
|
||||
choices=["txt", "json", "md", "html"],
|
||||
help="Format of the output file; defaults to txt.")
|
||||
|
||||
parser.add_argument("--verbose_output", type=str2bool, default=True,
|
||||
help="Enable or disable progress and debug messages.")
|
||||
|
||||
parser.add_argument("--transcription_task", type=str, default="transcribe",
|
||||
choices=["transcribe", "diarize", "wtranscribe"],
|
||||
help="Choose to perform transcription, diarization, or Whisper transcription.")
|
||||
|
||||
parser.add_argument("--spoken_language", type=str, default=None,
|
||||
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||
help="Language spoken in the audio. Specify None to perform language detection.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
output_directory = args.output_directory
|
||||
num_threads = args.num_threads
|
||||
whisper_model_directory = args.whisper_model_directory
|
||||
allow_download = args.allow_download
|
||||
inference_device = args.inference_device
|
||||
whisper_model_name = args.whisper_model_name
|
||||
diarization_directory = args.diarization_directory
|
||||
huggingface_token = args.huggingface_token
|
||||
transcription_task = args.transcription_task
|
||||
audio_files = args.audio_files
|
||||
spoken_language = args.spoken_language
|
||||
output_format = args.output_format
|
||||
start_server = args.start_server
|
||||
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
if num_threads > 0:
|
||||
torch.set_num_threads(num_threads)
|
||||
|
||||
whisper_kwargs = {
|
||||
"download_root": whisper_model_directory,
|
||||
"local": allow_download,
|
||||
"device": inference_device
|
||||
}
|
||||
|
||||
diarisation_kwargs = {
|
||||
"local": allow_download,
|
||||
"token": huggingface_token
|
||||
}
|
||||
|
||||
model = AutoTranscribe(whisper_model=whisper_model_name,
|
||||
whisper_kwargs=whisper_kwargs,
|
||||
dia_model=diarization_directory,
|
||||
dia_kwargs=diarisation_kwargs)
|
||||
|
||||
if transcription_task == "transcribe":
|
||||
for audio in audio_files:
|
||||
out = model.transcribe(audio, language=spoken_language)
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
spath = f"{output_directory}/{basename}.{output_format}"
|
||||
out.save(spath)
|
||||
|
||||
# ... include other tasks here ...
|
||||
elif transcription_task == "diarize":
|
||||
# diarize code here
|
||||
pass
|
||||
elif transcription_task == "wtranscribe":
|
||||
# wtranscribe code here
|
||||
pass
|
||||
|
||||
if start_server:
|
||||
from .gradio_app import gradio_app
|
||||
gradio_app(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Reference in New Issue
Block a user