From f12c1396d6aab5d2920a9ad91f9fd1b449f28ed6 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 1 Sep 2023 14:30:17 +0200 Subject: [PATCH] fixed bugs with hf tokens --- autotranscript/cli.py | 19 +++++++++---------- autotranscript/diarisation.py | 2 ++ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/autotranscript/cli.py b/autotranscript/cli.py index 5fa0774..183f6c5 100644 --- a/autotranscript/cli.py +++ b/autotranscript/cli.py @@ -8,6 +8,8 @@ import os from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json +from sympy import use + from .autotranscript import AutoTranscribe from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE @@ -57,12 +59,9 @@ def cli(): parser.add_argument("--diarization_directory", type=str, default= None, help="Path to the diarization model directory.") - parser.add_argument("--huggingface_token", default= None, type=str, + parser.add_argument("--hf_token", default= None, type=str, help="HuggingFace token for private model download.") - parser.add_argument("--allow_download", type=str2bool, default=True, - help="Allow model download if not found locally.") - parser.add_argument("--inference_device", default="cuda" if is_available() else "cpu", help="Device to use for PyTorch inference.") @@ -107,13 +106,13 @@ def cli(): if args.num_threads > 0: set_num_threads(arg_dict.pop("num_threads")) - - class_kwargs = dict() - for k, v in arg_dict.items(): - if v is not None: - class_kwargs[k] = v - + class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"), + 'dia_model': arg_dict.pop("diarization_directory"), + 'use_auth_token' : arg_dict.pop("hf_token")} + + if arg_dict["whisper_model_directory"]: + class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") model = AutoTranscribe(**class_kwargs) diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index 682c145..8b476bc 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -208,6 +208,8 @@ class Diariser: if not os.path.exists(model) and use_auth_token is None: use_auth_token = cls._get_token() model = 'pyannote/speaker-diarization' + elif not os.path.exists(model) and use_auth_token is not None: + model = 'pyannote/speaker-diarization' _model = Pipeline.from_pretrained(model, use_auth_token = use_auth_token,