updated cli
This commit is contained in:
+50
-22
@@ -6,7 +6,7 @@ output formats, and other options necessary for transcription.
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
from turtle import st
|
import json
|
||||||
|
|
||||||
from .transcriber import WHISPER_DEFAULT_PATH
|
from .transcriber import WHISPER_DEFAULT_PATH
|
||||||
from .diarisation import PYANNOTE_DEFAULT_PATH
|
from .diarisation import PYANNOTE_DEFAULT_PATH
|
||||||
@@ -85,7 +85,7 @@ def cli():
|
|||||||
help="Enable or disable progress and debug messages.")
|
help="Enable or disable progress and debug messages.")
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default= None, # unfinished code
|
parser.add_argument("--task", type=str, default= None, # unfinished code
|
||||||
choices=["autoranscribe", "diarize", "autotranscribe+translate", "translate"],
|
choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"],
|
||||||
help="Choose to perform transcription, diarization, or translation. \
|
help="Choose to perform transcription, diarization, or translation. \
|
||||||
If set to translate, the language argument must be specified.")
|
If set to translate, the language argument must be specified.")
|
||||||
|
|
||||||
@@ -98,14 +98,15 @@ def cli():
|
|||||||
arg_dict = vars(args)
|
arg_dict = vars(args)
|
||||||
|
|
||||||
# configure output
|
# configure output
|
||||||
|
out_folder = arg_dict.pop("output_directory")
|
||||||
os.makedirs(arg_dict.pop("output_directory"), exist_ok=True)
|
os.makedirs(out_folder, exist_ok=True)
|
||||||
|
|
||||||
out_format = arg_dict.pop("output_format")
|
out_format = arg_dict.pop("output_format")
|
||||||
|
|
||||||
# setup server arg:
|
# setup server arg:
|
||||||
start_server = arg_dict.pop("start_server")
|
start_server = arg_dict.pop("start_server")
|
||||||
|
|
||||||
|
task = arg_dict.pop("task")
|
||||||
|
|
||||||
if args.num_threads > 0:
|
if args.num_threads > 0:
|
||||||
set_num_threads(arg_dict.pop("num_threads"))
|
set_num_threads(arg_dict.pop("num_threads"))
|
||||||
@@ -115,29 +116,56 @@ def cli():
|
|||||||
for k, v in arg_dict.items():
|
for k, v in arg_dict.items():
|
||||||
if v is not None:
|
if v is not None:
|
||||||
class_kwargs[k] = v
|
class_kwargs[k] = v
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model = AutoTranscribe(**class_kwargs)
|
model = AutoTranscribe(**class_kwargs)
|
||||||
|
|
||||||
# if transcription_task == "transcribe":
|
if arg_dict["audio_files"]:
|
||||||
# for audio in audio_files:
|
audio_files = args.pop("audio_files")
|
||||||
# out = model.transcribe(audio, language=spoken_language)
|
|
||||||
# basename = audio.split("/")[-1].split(".")[0]
|
if task == "autotranscribe" or task == "autotranscribe+translate":
|
||||||
# spath = f"{output_directory}/{basename}.{output_format}"
|
for audio in audio_files:
|
||||||
# out.save(spath)
|
if task == "autotranscribe+translate":
|
||||||
|
task = "translate"
|
||||||
# # ... include other tasks here ...
|
else:
|
||||||
# elif transcription_task == "diarize":
|
task = "transcribe"
|
||||||
# # diarize code here
|
|
||||||
# pass
|
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
||||||
# elif transcription_task == "wtranscribe":
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
# # wtranscribe code here
|
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
||||||
# pass
|
|
||||||
|
elif task == "diarization":
|
||||||
|
for audio in audio_files:
|
||||||
|
if arg_dict.pop("verbose_output"):
|
||||||
|
print(f"Verbose not implemented for diarization.")
|
||||||
|
|
||||||
|
out = model.diarization(audio)
|
||||||
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
|
if out_format == "txt":
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(out)
|
||||||
|
elif out_format == "json":
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(json.dumps(out, indent= 3), f)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported output format for diarization{out_format}.")
|
||||||
|
elif task == "transcribe" or task == "translate":
|
||||||
|
|
||||||
|
for audio in audio_files:
|
||||||
|
|
||||||
|
out = model.transcribe(audio, task = task,
|
||||||
|
language=arg_dict.pop("language"),
|
||||||
|
verbose = arg_dict.pop("verbose_output"))
|
||||||
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(out)
|
||||||
|
|
||||||
|
|
||||||
# if start_server: # unfinished code
|
if start_server: # unfinished code
|
||||||
# from .gradio_app import gradio_app
|
from .gradio_app import gradio_app
|
||||||
# gradio_app(model)
|
gradio_app(model)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
Reference in New Issue
Block a user