fixed bugs with hf tokens

This commit is contained in:
Jaikinator
2023-09-01 14:30:17 +02:00
parent d3e4c2dc75
commit f12c1396d6
2 changed files with 11 additions and 10 deletions
+9 -10
View File
@@ -8,6 +8,8 @@ import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import json import json
from sympy import use
from .autotranscript import AutoTranscribe from .autotranscript import AutoTranscribe
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
@@ -57,12 +59,9 @@ def cli():
parser.add_argument("--diarization_directory", type=str, default= None, parser.add_argument("--diarization_directory", type=str, default= None,
help="Path to the diarization model directory.") help="Path to the diarization model directory.")
parser.add_argument("--huggingface_token", default= None, type=str, parser.add_argument("--hf_token", default= None, type=str,
help="HuggingFace token for private model download.") help="HuggingFace token for private model download.")
parser.add_argument("--allow_download", type=str2bool, default=True,
help="Allow model download if not found locally.")
parser.add_argument("--inference_device", parser.add_argument("--inference_device",
default="cuda" if is_available() else "cpu", default="cuda" if is_available() else "cpu",
help="Device to use for PyTorch inference.") help="Device to use for PyTorch inference.")
@@ -107,13 +106,13 @@ def cli():
if args.num_threads > 0: if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads")) set_num_threads(arg_dict.pop("num_threads"))
class_kwargs = dict()
for k, v in arg_dict.items(): class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
if v is not None: 'dia_model': arg_dict.pop("diarization_directory"),
class_kwargs[k] = v 'use_auth_token' : arg_dict.pop("hf_token")}
if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
model = AutoTranscribe(**class_kwargs) model = AutoTranscribe(**class_kwargs)
+2
View File
@@ -208,6 +208,8 @@ class Diariser:
if not os.path.exists(model) and use_auth_token is None: if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token() use_auth_token = cls._get_token()
model = 'pyannote/speaker-diarization' model = 'pyannote/speaker-diarization'
elif not os.path.exists(model) and use_auth_token is not None:
model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token, use_auth_token = use_auth_token,