diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index 0a29528..2097f2f 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -38,8 +38,7 @@ class AutoTranscribe: """ if whisper_model is None: - self.transcriber = Transcriber.load_model("medium", local=True) - + self.transcriber = Transcriber.load_model("medium", local=True) elif isinstance(whisper_model, str): self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs) else: @@ -170,6 +169,7 @@ def cli(): from whisper.utils import get_writer from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from .transcriber import WHISPER_DEFAULT_PATH + from .diarisation import PYANNOTE_DEFAULT_PATH def str2bool(string): str2val = {"True": True, "False": False} if string in str2val: @@ -190,6 +190,10 @@ def cli(): parser.add_argument("--wmodel_dir", type=str, default= WHISPER_DEFAULT_PATH, help="the path to save model files; uses ./models/whisper by default") + parser.add_argument("--dia_model", type=str, default = PYANNOTE_DEFAULT_PATH) + + parser.add_argument("--allow_download", type= bool, default=True, + help="whether to allow model download if model is not found locally") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") @@ -219,6 +223,7 @@ def cli(): model_dir: str = args.pop("wmodel_dir") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") + local :str = args.pop("allow_download") task = args.pop("task") device: str = args.pop("device") os.makedirs(output_dir, exist_ok=True) @@ -227,14 +232,17 @@ def cli(): torch.set_num_threads(threads) wkwargs = {"download_root": model_dir, - "device": device, - "language" : args.pop("language")} - - model = AutoTranscribe(whisper_model= model_name, whisper_kwargs= wkwargs) + "local": local, + "device": device} + diarisation_kwargs = {"local": local} + model = AutoTranscribe(whisper_model= model_name, + whisper_kwargs= wkwargs, + dia_model= args.pop("dia_model"), + dia_kwargs_kwargs= diarisation_kwargs,) if task == "transcribe": for audio in args.pop("audio"): - out = model.transcribe(audio) + out = model.transcribe(audio, language = args.pop("language")) basename = audio.split("/")[-1].split(".")[0] spath = f"{output_dir}/{basename}.{output_format}" out.save(spath) @@ -257,7 +265,7 @@ def cli(): "It is recommendet to use the whisper cli directly", RuntimeWarning) for audio in args.pop("audio"): - out = model.transcriber.transcribe(audio, diarisation=True) + out = model.transcriber.transcribe(audio, language = args.pop("language")) basename = audio.split("/")[-1].split(".")[0] writer(out, audio)