fixed kwargs confusion
and resolved path issues
This commit is contained in:
@@ -2,9 +2,10 @@ from pyannote.audio import Pipeline
|
|||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
from typing import TypeVar, Union
|
from typing import TypeVar, Union
|
||||||
import json
|
import json
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH
|
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
class Diariser:
|
class Diariser:
|
||||||
@@ -118,10 +119,12 @@ class Diariser:
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH,
|
def load_model(cls,
|
||||||
token: str = "",
|
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||||
local : bool = True,
|
token: str = None,
|
||||||
*args, **kwargs) -> Pipeline:
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
|
hparams_file: Union[str, Path] = None
|
||||||
|
) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
Load modules from pyannote
|
Load modules from pyannote
|
||||||
|
|
||||||
@@ -139,17 +142,15 @@ class Diariser:
|
|||||||
-------
|
-------
|
||||||
Pipeline Object
|
Pipeline Object
|
||||||
"""
|
"""
|
||||||
|
if not os.path.exists(model) and token is None:
|
||||||
if local:
|
|
||||||
diarization_model = Pipeline.from_pretrained(model,*args, **kwargs)
|
|
||||||
else:
|
|
||||||
print("Loading model from HuggingFace")
|
|
||||||
if token == "":
|
|
||||||
token = cls._get_token()
|
token = cls._get_token()
|
||||||
diarization_model = Pipeline.from_pretrained(model, use_auth_token = token,
|
|
||||||
*args, **kwargs)
|
|
||||||
|
|
||||||
return cls(diarization_model)
|
_model = Pipeline.from_pretrained(model,
|
||||||
|
use_auth_token = token,
|
||||||
|
cache_dir = cache_dir,
|
||||||
|
hparams_file = hparams_file,)
|
||||||
|
|
||||||
|
return cls(_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_diarisation_kwargs(**kwargs) -> dict:
|
def _get_diarisation_kwargs(**kwargs) -> dict:
|
||||||
|
|||||||
Reference in New Issue
Block a user