updated cli
This commit is contained in:
+50
-22
@@ -6,7 +6,7 @@ output formats, and other options necessary for transcription.
|
||||
"""
|
||||
import os
|
||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||
from turtle import st
|
||||
import json
|
||||
|
||||
from .transcriber import WHISPER_DEFAULT_PATH
|
||||
from .diarisation import PYANNOTE_DEFAULT_PATH
|
||||
@@ -85,7 +85,7 @@ def cli():
|
||||
help="Enable or disable progress and debug messages.")
|
||||
|
||||
parser.add_argument("--task", type=str, default= None, # unifinished code
|
||||
choices=["autoranscribe", "diarize", "autotranscribe+translate", "translate"],
|
||||
choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"],
|
||||
help="Choose to perform transcription, diarization, or translation. \
|
||||
If set to translate, the language argument must be specified.")
|
||||
|
||||
@@ -98,14 +98,15 @@ def cli():
|
||||
arg_dict = vars(args)
|
||||
|
||||
# configure output
|
||||
|
||||
os.makedirs(arg_dict.pop("output_directory"), exist_ok=True)
|
||||
out_folder = arg_dict.pop("output_directory")
|
||||
os.makedirs(out_folder, exist_ok=True)
|
||||
|
||||
out_format = arg_dict.pop("output_format")
|
||||
|
||||
# set up server arg:
|
||||
start_server = arg_dict.pop("start_server")
|
||||
|
||||
task = arg_dict.pop("task")
|
||||
|
||||
if args.num_threads > 0:
|
||||
set_num_threads(arg_dict.pop("num_threads"))
|
||||
@@ -115,29 +116,56 @@ def cli():
|
||||
for k, v in arg_dict.items():
|
||||
if v is not None:
|
||||
class_kwargs[k] = v
|
||||
|
||||
|
||||
|
||||
model = AutoTranscribe(**class_kwargs)
|
||||
|
||||
# if transcription_task == "transcribe":
|
||||
# for audio in audio_files:
|
||||
# out = model.transcribe(audio, language=spoken_language)
|
||||
# basename = audio.split("/")[-1].split(".")[0]
|
||||
# spath = f"{output_directory}/{basename}.{output_format}"
|
||||
# out.save(spath)
|
||||
|
||||
# # ... include other tasks here ...
|
||||
# elif transcription_task == "diarize":
|
||||
# # diarize code here
|
||||
# pass
|
||||
# elif transcription_task == "wtranscribe":
|
||||
# # wtranscribe code here
|
||||
# pass
|
||||
if arg_dict["audio_files"]:
|
||||
audio_files = args.pop("audio_files")
|
||||
|
||||
if task == "autotranscribe" or task == "autotranscribe+translate":
|
||||
for audio in audio_files:
|
||||
if task == "autotranscribe+translate":
|
||||
task = "translate"
|
||||
else:
|
||||
task = "transcribe"
|
||||
|
||||
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
||||
|
||||
elif task == "diarization":
|
||||
for audio in audio_files:
|
||||
if arg_dict.pop("verbose_output"):
|
||||
print(f"Verbose not implemented for diarization.")
|
||||
|
||||
out = model.diarization(audio)
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||
if out_format == "txt":
|
||||
with open(path, "w") as f:
|
||||
f.write(out)
|
||||
elif out_format == "json":
|
||||
with open(path, "w") as f:
|
||||
json.dump(json.dumps(out, indent= 3), f)
|
||||
else:
|
||||
raise ValueError(f"Unsupported output format for diarization{out_format}.")
|
||||
elif task == "transcribe" or task == "translate":
|
||||
|
||||
for audio in audio_files:
|
||||
|
||||
out = model.transcribe(audio, task = task,
|
||||
language=arg_dict.pop("language"),
|
||||
verbose = arg_dict.pop("verbose_output"))
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||
with open(path, "w") as f:
|
||||
f.write(out)
|
||||
|
||||
|
||||
# if start_server: # unfinished code
|
||||
# from .gradio_app import gradio_app
|
||||
# gradio_app(model)
|
||||
if start_server: # unfinished code
|
||||
from .gradio_app import gradio_app
|
||||
gradio_app(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Reference in New Issue
Block a user