fixed kwargs confusion
and resolved path issues
This commit is contained in:
@@ -2,9 +2,10 @@ from pyannote.audio import Pipeline
|
||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TypeVar, Union
|
||||
import json
|
||||
from .misc import PYANNOTE_DEFAULT_PATH
|
||||
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
class Diariser:
|
||||
@@ -118,10 +119,12 @@ class Diariser:
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH,
|
||||
token: str = "",
|
||||
local : bool = True,
|
||||
*args, **kwargs) -> Pipeline:
|
||||
def load_model(cls,
|
||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||
token: str = None,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None
|
||||
) -> Pipeline:
|
||||
"""
|
||||
Load modules from pyannote
|
||||
|
||||
@@ -139,17 +142,15 @@ class Diariser:
|
||||
-------
|
||||
Pipeline Object
|
||||
"""
|
||||
|
||||
if local:
|
||||
diarization_model = Pipeline.from_pretrained(model,*args, **kwargs)
|
||||
else:
|
||||
print("Loading model from HuggingFace")
|
||||
if token == "":
|
||||
if not os.path.exists(model) and token is None:
|
||||
token = cls._get_token()
|
||||
diarization_model = Pipeline.from_pretrained(model, use_auth_token = token,
|
||||
*args, **kwargs)
|
||||
|
||||
return cls(diarization_model)
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token = token,
|
||||
cache_dir = cache_dir,
|
||||
hparams_file = hparams_file,)
|
||||
|
||||
return cls(_model)
|
||||
|
||||
@staticmethod
|
||||
def _get_diarisation_kwargs(**kwargs) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user