fixed bugs with hf tokens
This commit is contained in:
@@ -8,6 +8,8 @@ import os
|
|||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from sympy import use
|
||||||
|
|
||||||
from .autotranscript import AutoTranscribe
|
from .autotranscript import AutoTranscribe
|
||||||
|
|
||||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||||
@@ -57,12 +59,9 @@ def cli():
|
|||||||
parser.add_argument("--diarization_directory", type=str, default= None,
|
parser.add_argument("--diarization_directory", type=str, default= None,
|
||||||
help="Path to the diarization model directory.")
|
help="Path to the diarization model directory.")
|
||||||
|
|
||||||
parser.add_argument("--huggingface_token", default= None, type=str,
|
parser.add_argument("--hf_token", default= None, type=str,
|
||||||
help="HuggingFace token for private model download.")
|
help="HuggingFace token for private model download.")
|
||||||
|
|
||||||
parser.add_argument("--allow_download", type=str2bool, default=True,
|
|
||||||
help="Allow model download if not found locally.")
|
|
||||||
|
|
||||||
parser.add_argument("--inference_device",
|
parser.add_argument("--inference_device",
|
||||||
default="cuda" if is_available() else "cpu",
|
default="cuda" if is_available() else "cpu",
|
||||||
help="Device to use for PyTorch inference.")
|
help="Device to use for PyTorch inference.")
|
||||||
@@ -108,12 +107,12 @@ def cli():
|
|||||||
if args.num_threads > 0:
|
if args.num_threads > 0:
|
||||||
set_num_threads(arg_dict.pop("num_threads"))
|
set_num_threads(arg_dict.pop("num_threads"))
|
||||||
|
|
||||||
class_kwargs = dict()
|
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
|
||||||
|
'dia_model': arg_dict.pop("diarization_directory"),
|
||||||
for k, v in arg_dict.items():
|
'use_auth_token' : arg_dict.pop("hf_token")}
|
||||||
if v is not None:
|
|
||||||
class_kwargs[k] = v
|
|
||||||
|
|
||||||
|
if arg_dict["whisper_model_directory"]:
|
||||||
|
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
||||||
|
|
||||||
model = AutoTranscribe(**class_kwargs)
|
model = AutoTranscribe(**class_kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -208,6 +208,8 @@ class Diariser:
|
|||||||
if not os.path.exists(model) and use_auth_token is None:
|
if not os.path.exists(model) and use_auth_token is None:
|
||||||
use_auth_token = cls._get_token()
|
use_auth_token = cls._get_token()
|
||||||
model = 'pyannote/speaker-diarization'
|
model = 'pyannote/speaker-diarization'
|
||||||
|
elif not os.path.exists(model) and use_auth_token is not None:
|
||||||
|
model = 'pyannote/speaker-diarization'
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token = use_auth_token,
|
use_auth_token = use_auth_token,
|
||||||
|
|||||||
Reference in New Issue
Block a user