added cli
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from autotranscript.audio import AudioProcessor
|
||||
from autotranscript.diarisation import Diariser
|
||||
from autotranscript.transcriber import Transcriber, whisper
|
||||
from autotranscript.transcript_exporter import Transcript
|
||||
from .audio import AudioProcessor
|
||||
from .diarisation import Diariser
|
||||
from .transcriber import Transcriber, whisper
|
||||
from .transcript_exporter import Transcript
|
||||
from typing import Union , TypeVar
|
||||
from tqdm import trange
|
||||
import torch
|
||||
@@ -9,6 +9,8 @@ import os
|
||||
from glob import iglob
|
||||
from subprocess import run
|
||||
from warnings import warn
|
||||
import argparse
|
||||
|
||||
|
||||
diarisation = TypeVar('diarisation')
|
||||
|
||||
@@ -160,4 +162,104 @@ class AutoTranscribe:
|
||||
if not isinstance(audiofile, AudioProcessor):
|
||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||
f'not {type(audiofile)}')
|
||||
return audiofile
|
||||
return audiofile
|
||||
|
||||
|
||||
def cli():
    """Command-line entry point: transcribe and/or diarize audio files.

    Parses the command line, builds an ``AutoTranscribe`` model from the
    selected Whisper model/device/language, and then runs one of three
    tasks over every audio file given:

    * ``transcribe``  -- ``model.transcribe`` and save the result as
      ``<output_dir>/<stem>.<output_format>``.
    * ``diarize``     -- ``model.diariser.diarization`` and save the result
      as ``<output_dir>/<stem>.json``.
    * ``wtranscribe`` -- ``model.transcriber.transcribe`` and hand the result
      to the writer returned by whisper's ``get_writer``.

    ``<stem>`` is the audio file name without its directory and without its
    final extension. The output directory is created if it does not exist.
    """
    from whisper.utils import get_writer
    from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
    from .transcriber import WHISPER_DEFAULT_PATH

    def str2bool(string):
        """Convert the literal strings "True"/"False" to booleans for argparse."""
        str2val = {"True": True, "False": False}
        if string in str2val:
            return str2val[string]
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")

    def output_stem(audio_path):
        """Return the file name of *audio_path* without directory or final extension.

        Only the last extension is stripped, so "rec.2023.wav" and
        "rec.2024.wav" map to different output names instead of both
        collapsing to "rec".
        """
        return os.path.splitext(os.path.basename(audio_path))[0]

    # fmt: off
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("audio", nargs="+", type=str,
                        help="audio file(s) to transcribe")

    parser.add_argument("--wmodel", default="medium",
                        help="name of the Whisper model to use")
    parser.add_argument("--wmodel_dir", type=str, default=WHISPER_DEFAULT_PATH,
                        help="the path to save model files; uses ./models/whisper by default")

    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="device to use for PyTorch inference")
    parser.add_argument("--threads", type=int, default=0,
                        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    parser.add_argument("--output_dir", "-o", type=str, default=".",
                        help="directory to save the outputs")
    # NOTE(review): the help text says all formats are produced when the
    # option is omitted, but the default here is "txt" -- confirm intent.
    parser.add_argument("--output_format", "-f", type=str, default="txt",
                        choices=["txt", "json", "md", "html"],
                        help="format of the output file; if not specified, all available formats will be produced")

    # NOTE(review): --verbose is parsed but not consumed anywhere below.
    parser.add_argument("--verbose", type=str2bool, default=True,
                        help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe",
                        choices=["transcribe", "diarize", "wtranscribe"],
                        help="whether to perfrom transcription and diazation or only one of them")
    parser.add_argument("--language", type=str, default=None,
                        choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
                        help="language spoken in the audio, specify None to perform language detection")
    # fmt: on

    args = parser.parse_args().__dict__
    model_name: str = args.pop("wmodel")
    model_dir: str = args.pop("wmodel_dir")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    task = args.pop("task")
    device: str = args.pop("device")
    audio_files = args.pop("audio")
    os.makedirs(output_dir, exist_ok=True)

    # Only override torch's thread count when the user asked for one.
    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)

    wkwargs = {
        "download_root": model_dir,
        "device": device,
        "language": args.pop("language"),
    }

    model = AutoTranscribe(whisper_model=model_name, whisper_kwargs=wkwargs)

    if task == "transcribe":
        for audio in audio_files:
            out = model.transcribe(audio)
            spath = os.path.join(output_dir,
                                 f"{output_stem(audio)}.{output_format}")
            out.save(spath)

    elif task == "diarize":
        warn("Diarization is still in beta and may not work as expected.",
             RuntimeWarning)
        for audio in audio_files:
            out = model.diariser.diarization(audio)
            spath = os.path.join(output_dir, f"{output_stem(audio)}.json")
            out.save(spath)
            # Report only after the save call has returned.
            print(f"diairization results saved to {spath}")

    elif task == "wtranscribe":
        # NOTE(review): the writer comes from whisper; the formats it accepts
        # may not match this CLI's --output_format choices (e.g. "md",
        # "html") -- confirm against the installed whisper version.
        writer = get_writer(output_format, output_dir)
        warn("whisper transcription is poorly supported and may not work as expected. "
             "It is recommendet to use the whisper cli directly",
             RuntimeWarning)
        for audio in audio_files:
            out = model.transcriber.transcribe(audio, diarisation=True)
            writer(out, audio)
|
||||
|
||||
# Run the command-line interface when this module is executed as the main
# program. NOTE: this module uses package-relative imports
# (e.g. `from .audio import AudioProcessor`), so it must be executed as part
# of its package (for example with `python -m`), not as a standalone file.
if __name__ == "__main__":
    cli()
|
||||
Reference in New Issue
Block a user