bug fixes
This commit is contained in:
+20
-22
@@ -8,12 +8,8 @@ import os
|
|||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from .transcriber import WHISPER_DEFAULT_PATH
|
|
||||||
from .diarisation import PYANNOTE_DEFAULT_PATH
|
|
||||||
from .autotranscript import AutoTranscribe
|
from .autotranscript import AutoTranscribe
|
||||||
|
|
||||||
from whisper import available_models
|
|
||||||
from whisper.utils import get_writer
|
|
||||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||||
from torch.cuda import is_available
|
from torch.cuda import is_available
|
||||||
from torch import set_num_threads
|
from torch import set_num_threads
|
||||||
@@ -47,10 +43,10 @@ def cli():
|
|||||||
help='Start the Gradio app.')
|
help='Start the Gradio app.')
|
||||||
|
|
||||||
parser.add_argument("--port", type=int, default= None,
|
parser.add_argument("--port", type=int, default= None,
|
||||||
help="Port to run the Gradio app on.")
|
help="Port to run the Gradio app on. Defaults to 7860.")
|
||||||
|
|
||||||
parser.add_argument("--server_name", type=str, default= "autotranscript",
|
parser.add_argument("--server_name", type=str, default= None,
|
||||||
help="Name of the Gradio app.")
|
help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.")
|
||||||
|
|
||||||
parser.add_argument("--whisper_model_name", default="medium",
|
parser.add_argument("--whisper_model_name", default="medium",
|
||||||
help="Name of the Whisper model to use.")
|
help="Name of the Whisper model to use.")
|
||||||
@@ -84,10 +80,11 @@ def cli():
|
|||||||
parser.add_argument("--verbose_output", type=str2bool, default=True,
|
parser.add_argument("--verbose_output", type=str2bool, default=True,
|
||||||
help="Enable or disable progress and debug messages.")
|
help="Enable or disable progress and debug messages.")
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default= None, # unifinished code
|
parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code
|
||||||
choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate"],
|
choices=["autotranscribe", "diarization",
|
||||||
|
"autotranscribe+translate", "translate", 'transcribe'],
|
||||||
help="Choose to perform transcription, diarization, or translation. \
|
help="Choose to perform transcription, diarization, or translation. \
|
||||||
If set to translate, the language argument must be specified.")
|
If set to translate, the output will be translated to English.")
|
||||||
|
|
||||||
parser.add_argument("--language", type=str, default=None,
|
parser.add_argument("--language", type=str, default=None,
|
||||||
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||||
@@ -120,8 +117,9 @@ def cli():
|
|||||||
|
|
||||||
model = AutoTranscribe(**class_kwargs)
|
model = AutoTranscribe(**class_kwargs)
|
||||||
|
|
||||||
|
|
||||||
if arg_dict["audio_files"]:
|
if arg_dict["audio_files"]:
|
||||||
audio_files = args.pop("audio_files")
|
audio_files = arg_dict.pop("audio_files")
|
||||||
|
|
||||||
if task == "autotranscribe" or task == "autotranscribe+translate":
|
if task == "autotranscribe" or task == "autotranscribe+translate":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
@@ -132,6 +130,7 @@ def cli():
|
|||||||
|
|
||||||
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
||||||
|
|
||||||
elif task == "diarization":
|
elif task == "diarization":
|
||||||
@@ -142,20 +141,18 @@ def cli():
|
|||||||
out = model.diarization(audio)
|
out = model.diarization(audio)
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
if out_format == "txt":
|
|
||||||
with open(path, "w") as f:
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
f.write(out)
|
|
||||||
elif out_format == "json":
|
with open(path, "w") as f:
|
||||||
with open(path, "w") as f:
|
json.dump(json.dumps(out, indent= 1), f)
|
||||||
json.dump(json.dumps(out, indent= 3), f)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported output format for diarization{out_format}.")
|
|
||||||
elif task == "transcribe" or task == "translate":
|
elif task == "transcribe" or task == "translate":
|
||||||
|
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
|
|
||||||
out = model.transcribe(audio, task = task,
|
out = model.transcribe(audio, task = task,
|
||||||
language=arg_dict.pop("language"),
|
language= arg_dict.pop("language"),
|
||||||
verbose = arg_dict.pop("verbose_output"))
|
verbose = arg_dict.pop("verbose_output"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
@@ -164,8 +161,9 @@ def cli():
|
|||||||
|
|
||||||
|
|
||||||
if start_server: # unfinished code
|
if start_server: # unfinished code
|
||||||
from .gradio_app import gradio_app
|
from .app.gradio_app import gradio_Interface
|
||||||
gradio_app(model)
|
gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
Reference in New Issue
Block a user