fixed bugs with hf tokens
This commit is contained in:
+9
-10
@@ -8,6 +8,8 @@ import os
|
||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||
import json
|
||||
|
||||
from sympy import use
|
||||
|
||||
from .autotranscript import AutoTranscribe
|
||||
|
||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||
@@ -57,12 +59,9 @@ def cli():
|
||||
parser.add_argument("--diarization_directory", type=str, default= None,
|
||||
help="Path to the diarization model directory.")
|
||||
|
||||
parser.add_argument("--huggingface_token", default= None, type=str,
|
||||
parser.add_argument("--hf_token", default= None, type=str,
|
||||
help="HuggingFace token for private model download.")
|
||||
|
||||
parser.add_argument("--allow_download", type=str2bool, default=True,
|
||||
help="Allow model download if not found locally.")
|
||||
|
||||
parser.add_argument("--inference_device",
|
||||
default="cuda" if is_available() else "cpu",
|
||||
help="Device to use for PyTorch inference.")
|
||||
@@ -107,13 +106,13 @@ def cli():
|
||||
|
||||
if args.num_threads > 0:
|
||||
set_num_threads(arg_dict.pop("num_threads"))
|
||||
|
||||
class_kwargs = dict()
|
||||
|
||||
for k, v in arg_dict.items():
|
||||
if v is not None:
|
||||
class_kwargs[k] = v
|
||||
|
||||
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
|
||||
'dia_model': arg_dict.pop("diarization_directory"),
|
||||
'use_auth_token' : arg_dict.pop("hf_token")}
|
||||
|
||||
if arg_dict["whisper_model_directory"]:
|
||||
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
||||
|
||||
model = AutoTranscribe(**class_kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user