fixed kwargs confusion

and resolved path issues
This commit is contained in:
Jaikinator
2023-06-30 18:44:39 +02:00
parent 38d1f8f668
commit 907913f2bf
+15 -14
View File
@@ -2,9 +2,10 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
import os
from pathlib import Path
from typing import TypeVar, Union
import json
from .misc import PYANNOTE_DEFAULT_PATH
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
Annotation = TypeVar('Annotation')
class Diariser:
@@ -118,10 +119,12 @@ class Diariser:
return token
@classmethod
def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH,
token: str = "",
local : bool = True,
*args, **kwargs) -> Pipeline:
def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG,
token: str = None,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None
) -> Pipeline:
"""
Load modules from pyannote
@@ -139,17 +142,15 @@ class Diariser:
-------
Pipeline Object
"""
if local:
diarization_model = Pipeline.from_pretrained(model,*args, **kwargs)
else:
print("Loading model from HuggingFace")
if token == "":
if not os.path.exists(model) and token is None:
token = cls._get_token()
diarization_model = Pipeline.from_pretrained(model, use_auth_token = token,
*args, **kwargs)
return cls(diarization_model)
_model = Pipeline.from_pretrained(model,
use_auth_token = token,
cache_dir = cache_dir,
hparams_file = hparams_file,)
return cls(_model)
@staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict: