fixed kwargs confusion

and resolved path issues
This commit is contained in:
Jaikinator
2023-06-30 18:44:39 +02:00
parent 38d1f8f668
commit 907913f2bf
+17 -16
View File
@@ -2,9 +2,10 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
import os import os
from pathlib import Path
from typing import TypeVar, Union from typing import TypeVar, Union
import json import json
from .misc import PYANNOTE_DEFAULT_PATH from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
class Diariser: class Diariser:
@@ -118,10 +119,12 @@ class Diariser:
return token return token
@classmethod @classmethod
def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH, def load_model(cls,
token: str = "", model: str = PYANNOTE_DEFAULT_CONFIG,
local : bool = True, token: str = None,
*args, **kwargs) -> Pipeline: cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None
) -> Pipeline:
""" """
Load modules from pyannote Load modules from pyannote
@@ -139,17 +142,15 @@ class Diariser:
------- -------
Pipeline Object Pipeline Object
""" """
if not os.path.exists(model) and token is None:
if local: token = cls._get_token()
diarization_model = Pipeline.from_pretrained(model,*args, **kwargs)
else: _model = Pipeline.from_pretrained(model,
print("Loading model from HuggingFace") use_auth_token = token,
if token == "": cache_dir = cache_dir,
token = cls._get_token() hparams_file = hparams_file,)
diarization_model = Pipeline.from_pretrained(model, use_auth_token = token,
*args, **kwargs) return cls(_model)
return cls(diarization_model)
@staticmethod @staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict: def _get_diarisation_kwargs(**kwargs) -> dict: