"""
Command-Line Interface (CLI) for the AutoTranscribe class,
allowing for user interaction to transcribe and diarize audio files.
The `cli` function includes arguments for specifying the audio files, model paths,
output formats, and other options necessary for transcription.
"""
|
|
import os
|
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
|
import json
|
|
|
|
from .autotranscript import AutoTranscribe
|
|
|
|
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
|
from torch.cuda import is_available
|
|
from torch import set_num_threads
|
|
|
|
|
|
def _build_parser():
    """Build the argument parser used by :func:`cli`.

    Returns:
        ArgumentParser: parser configured with every option of the CLI
        (audio files, Gradio server settings, model locations, output
        settings, task and language selection).
    """

    def str2bool(string):
        """Map the literal strings ``"True"``/``"False"`` to booleans.

        Raises:
            ValueError: if *string* is anything else; argparse reports this
                as an invalid argument value.
        """
        str2val = {"True": True, "False": False}
        if string in str2val:
            return str2val[string]
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)

    group = parser.add_mutually_exclusive_group()

    parser.add_argument("-f", "--audio_files", nargs="+", type=str, default=None,
                        help="List of audio files to transcribe.")

    group.add_argument('--start_server', action='store_true',
                       help='Start the Gradio app.')

    parser.add_argument("--port", type=int, default=None,
                        help="Port to run the Gradio app on. Defaults to 7860.")

    parser.add_argument("--server_name", type=str, default=None,
                        help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.")

    parser.add_argument("--whisper_model_name", default="medium",
                        help="Name of the Whisper model to use.")

    parser.add_argument("--whisper_model_directory", type=str, default=None,
                        help="Path to save Whisper model files; defaults to ./models/whisper.")

    parser.add_argument("--diarization_directory", type=str, default=None,
                        help="Path to the diarization model directory.")

    parser.add_argument("--huggingface_token", default=None, type=str,
                        help="HuggingFace token for private model download.")

    parser.add_argument("--allow_download", type=str2bool, default=True,
                        help="Allow model download if not found locally.")

    parser.add_argument("--inference_device",
                        default="cuda" if is_available() else "cpu",
                        help="Device to use for PyTorch inference.")

    parser.add_argument("--num_threads", type=int, default=0,
                        help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")

    parser.add_argument("--output_directory", "-o", type=str, default=".",
                        help="Directory to save the transcription outputs.")

    parser.add_argument("--output_format", "-of", type=str, default="txt",
                        choices=["txt", "json", "md", "html"],
                        help="Format of the output file; defaults to txt.")

    parser.add_argument("--verbose_output", type=str2bool, default=True,
                        help="Enable or disable progress and debug messages.")

    # TODO: task handling was marked as unfinished by the original author.
    parser.add_argument("--task", type=str, default='autotranscribe',
                        choices=["autotranscribe", "diarization",
                                 "autotranscribe+translate", "translate", 'transcribe'],
                        help="Choose to perform transcription, diarization, or translation. "
                             "If set to translate, the output will be translated to English.")

    parser.add_argument("--language", type=str, default=None,
                        choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
                        help="Language spoken in the audio. Specify None to perform language detection.")

    return parser


def _output_path(out_folder, audio, out_format):
    """Return ``<out_folder>/<audio stem>.<out_format>`` for an input audio path.

    The stem is derived with ``os.path`` so that platform-specific path
    separators are honoured and only the final extension is stripped
    (``a.1.wav`` -> ``a.1``), which keeps distinct inputs from colliding.
    """
    stem = os.path.splitext(os.path.basename(audio))[0]
    return os.path.join(out_folder, f"{stem}.{out_format}")


def cli():
    """
    Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe
    and diarize audio files.

    The function parses the command line, creates the output directory if needed, builds an
    ``AutoTranscribe`` instance from every remaining option whose value is not ``None``, and then:

    * processes each file given via ``--audio_files`` according to ``--task`` and writes one
      output file per input into ``--output_directory``;
    * launches the Gradio app when ``--start_server`` is set.

    This function can be executed from the command line to perform transcription tasks, providing a
    user-friendly way to access the AutoTranscribe class functionalities.
    """
    args = _build_parser().parse_args()

    # vars() exposes the namespace's own dict: popping a key here also removes
    # the attribute from ``args``.
    arg_dict = vars(args)

    # configure output
    out_folder = arg_dict.pop("output_directory")
    os.makedirs(out_folder, exist_ok=True)
    out_format = arg_dict.pop("output_format")

    # setup server arg
    start_server = arg_dict.pop("start_server")

    task = arg_dict.pop("task")

    # Always remove num_threads so the constructor kwargs do not depend on its
    # value; it is only used to configure torch here.
    num_threads = arg_dict.pop("num_threads")
    if num_threads > 0:
        set_num_threads(num_threads)

    # Forward only the options the user (or a default) actually set.
    class_kwargs = {k: v for k, v in arg_dict.items() if v is not None}

    model = AutoTranscribe(**class_kwargs)

    audio_files = arg_dict.get("audio_files")

    if audio_files:
        # Read per-run options once: popping them inside the loops raised
        # KeyError as soon as a second audio file was processed.
        language = arg_dict.get("language")
        verbose = arg_dict.get("verbose_output")

        if task in ("autotranscribe", "autotranscribe+translate"):
            # Resolve the underlying task once, outside the loop, so every
            # file gets the same treatment (reassigning ``task`` per iteration
            # silently switched later files from translate to transcribe).
            if task == "autotranscribe+translate":
                model_task = "translate"
            else:
                model_task = "transcribe"

            for audio in audio_files:
                out = model.autotranscribe(audio, task=model_task,
                                           language=language, verbose=verbose)
                path = _output_path(out_folder, audio, out_format)
                print(f'Saving {os.path.basename(path)} to {out_folder}')
                out.save(path)

        elif task == "diarization":
            if verbose:
                print(f"Verbose not implemented for diarization.")

            for audio in audio_files:
                out = model.diarization(audio)
                path = _output_path(out_folder, audio, out_format)

                print(f'Saving {os.path.basename(path)} to {out_folder}')

                with open(path, "w", encoding="utf-8") as f:
                    # Serialize once; dumping an already-dumped string wrote a
                    # quoted, escaped JSON string instead of a JSON document.
                    json.dump(out, f, indent=1)

        elif task in ("transcribe", "translate"):
            for audio in audio_files:
                out = model.transcribe(audio, task=task,
                                       language=language,
                                       verbose=verbose)
                path = _output_path(out_folder, audio, out_format)

                print(f'Saving {os.path.basename(path)} to {out_folder}')

                # Transcripts may contain any language; do not depend on the
                # platform's default encoding.
                with open(path, "w", encoding="utf-8") as f:
                    f.write(out)

    if start_server:  # TODO: marked as unfinished by the original author.
        # Imported lazily so the Gradio app module is only loaded when the
        # server is actually requested.
        from .app.gradio_app import gradio_Interface
        gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
|
|
|
|
|
|
# Run the command-line interface when this module is executed as a script.
if __name__ == "__main__":
    cli()