From 907913f2bfa1cc342642db2fa90e9c65c55ecfd1 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Fri, 30 Jun 2023 18:44:39 +0200 Subject: [PATCH] fixed kwargs confusion and resolved path issues --- autotranscript/diarisation.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index 1c2e4fb..bb364e9 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -2,9 +2,10 @@ from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor import os +from pathlib import Path from typing import TypeVar, Union import json -from .misc import PYANNOTE_DEFAULT_PATH +from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH Annotation = TypeVar('Annotation') class Diariser: @@ -118,10 +119,12 @@ class Diariser: return token @classmethod - def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH, - token: str = "", - local : bool = True, - *args, **kwargs) -> Pipeline: + def load_model(cls, + model: str = PYANNOTE_DEFAULT_CONFIG, + token: str = None, + cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, + hparams_file: Union[str, Path] = None + ) -> Pipeline: """ Load modules from pyannote @@ -139,17 +142,15 @@ class Diariser: ------- Pipeline Object """ - - if local: - diarization_model = Pipeline.from_pretrained(model,*args, **kwargs) - else: - print("Loading model from HuggingFace") - if token == "": - token = cls._get_token() - diarization_model = Pipeline.from_pretrained(model, use_auth_token = token, - *args, **kwargs) - - return cls(diarization_model) + if not os.path.exists(model) and token is None: + token = cls._get_token() + + _model = Pipeline.from_pretrained(model, + use_auth_token = token, + cache_dir = cache_dir, + hparams_file = hparams_file,) + + return cls(_model) @staticmethod def _get_diarisation_kwargs(**kwargs) -> dict: