fixed bugs with hf tokens

This commit is contained in:
Jaikinator
2023-09-01 14:30:17 +02:00
parent d3e4c2dc75
commit f12c1396d6
2 changed files with 11 additions and 10 deletions
+9 -10
View File
@@ -8,6 +8,8 @@ import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import json
from sympy import use
from .autotranscript import AutoTranscribe
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
@@ -57,12 +59,9 @@ def cli():
parser.add_argument("--diarization_directory", type=str, default= None,
help="Path to the diarization model directory.")
parser.add_argument("--huggingface_token", default= None, type=str,
parser.add_argument("--hf_token", default= None, type=str,
help="HuggingFace token for private model download.")
parser.add_argument("--allow_download", type=str2bool, default=True,
help="Allow model download if not found locally.")
parser.add_argument("--inference_device",
default="cuda" if is_available() else "cpu",
help="Device to use for PyTorch inference.")
@@ -107,13 +106,13 @@ def cli():
if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads"))
class_kwargs = dict()
for k, v in arg_dict.items():
if v is not None:
class_kwargs[k] = v
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
'dia_model': arg_dict.pop("diarization_directory"),
'use_auth_token' : arg_dict.pop("hf_token")}
if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
model = AutoTranscribe(**class_kwargs)
+2
View File
@@ -208,6 +208,8 @@ class Diariser:
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
model = 'pyannote/speaker-diarization'
elif not os.path.exists(model) and use_auth_token is not None:
model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,