updated cli

This commit is contained in:
Jaikinator
2023-08-28 17:01:53 +02:00
parent 5be187998e
commit 5937e81e31
+49 -21
View File
@@ -6,7 +6,7 @@ output formats, and other options necessary for transcription.
""" """
import os import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from turtle import st import json
from .transcriber import WHISPER_DEFAULT_PATH from .transcriber import WHISPER_DEFAULT_PATH
from .diarisation import PYANNOTE_DEFAULT_PATH from .diarisation import PYANNOTE_DEFAULT_PATH
@@ -85,7 +85,7 @@ def cli():
help="Enable or disable progress and debug messages.") help="Enable or disable progress and debug messages.")
parser.add_argument("--task", type=str, default= None, # unifinished code parser.add_argument("--task", type=str, default= None, # unifinished code
choices=["autoranscribe", "diarize", "autotranscribe+translate", "translate"], choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"],
help="Choose to perform transcription, diarization, or translation. \ help="Choose to perform transcription, diarization, or translation. \
If set to translate, the language argument must be specified.") If set to translate, the language argument must be specified.")
@@ -98,14 +98,15 @@ def cli():
arg_dict = vars(args) arg_dict = vars(args)
# configure output # configure output
out_folder = arg_dict.pop("output_directory")
os.makedirs(arg_dict.pop("output_directory"), exist_ok=True) os.makedirs(out_folder, exist_ok=True)
out_format = arg_dict.pop("output_format") out_format = arg_dict.pop("output_format")
# seup server arg: # seup server arg:
start_server = arg_dict.pop("start_server") start_server = arg_dict.pop("start_server")
task = arg_dict.pop("task")
if args.num_threads > 0: if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads")) set_num_threads(arg_dict.pop("num_threads"))
@@ -117,27 +118,54 @@ def cli():
class_kwargs[k] = v class_kwargs[k] = v
model = AutoTranscribe(**class_kwargs) model = AutoTranscribe(**class_kwargs)
# if transcription_task == "transcribe": if arg_dict["audio_files"]:
# for audio in audio_files: audio_files = args.pop("audio_files")
# out = model.transcribe(audio, language=spoken_language)
# basename = audio.split("/")[-1].split(".")[0]
# spath = f"{output_directory}/{basename}.{output_format}"
# out.save(spath)
# # ... include other tasks here ... if task == "autotranscribe" or task == "autotranscribe+translate":
# elif transcription_task == "diarize": for audio in audio_files:
# # diarize code here if task == "autotranscribe+translate":
# pass task = "translate"
# elif transcription_task == "wtranscribe": else:
# # wtranscribe code here task = "transcribe"
# pass
# if start_server: # unfinished code out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
# from .gradio_app import gradio_app basename = audio.split("/")[-1].split(".")[0]
# gradio_app(model) out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
elif task == "diarization":
for audio in audio_files:
if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.")
out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
if out_format == "txt":
with open(path, "w") as f:
f.write(out)
elif out_format == "json":
with open(path, "w") as f:
json.dump(json.dumps(out, indent= 3), f)
else:
raise ValueError(f"Unsupported output format for diarization{out_format}.")
elif task == "transcribe" or task == "translate":
for audio in audio_files:
out = model.transcribe(audio, task = task,
language=arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f:
f.write(out)
if start_server: # unfinished code
from .gradio_app import gradio_app
gradio_app(model)
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()