Add default path to pyannote model with fallback option.
This commit is contained in:
+15
-3
@@ -19,6 +19,7 @@ Constants:
|
|||||||
- TOKEN_PATH (str): Path to the Pyannote token.
|
- TOKEN_PATH (str): Path to the Pyannote token.
|
||||||
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
|
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
|
||||||
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
|
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
|
||||||
|
- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from .diarisation import Diariser
|
from .diarisation import Diariser
|
||||||
@@ -39,7 +40,7 @@ from torch import Tensor
|
|||||||
from torch import device as torch_device
|
from torch import device as torch_device
|
||||||
from torch.cuda import is_available, current_device
|
from torch.cuda import is_available, current_device
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
@@ -183,7 +184,7 @@ class Diariser:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
model: str = PYANNOTE_FALLBACK_CONFIG,
|
||||||
use_auth_token: str = None,
|
use_auth_token: str = None,
|
||||||
cache_token: bool = True,
|
cache_token: bool = True,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
@@ -210,7 +211,18 @@ class Diariser:
|
|||||||
Returns:
|
Returns:
|
||||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
hf_model = PYANNOTE_DEFAULT_CONFIG
|
||||||
|
# if not use_auth_token:
|
||||||
|
# use_auth_token = cls._get_token()
|
||||||
|
_model = Pipeline.from_pretrained(
|
||||||
|
hf_model, use_auth_token=use_auth_token,
|
||||||
|
cache_dir=cache_dir, hparams_file=hparams_file
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
print(f'Trying fallback to config file.. ')
|
||||||
|
_model = None
|
||||||
|
if _model is None:
|
||||||
|
|
||||||
if cache_token and use_auth_token is not None:
|
if cache_token and use_auth_token is not None:
|
||||||
cls._save_token(use_auth_token)
|
cls._save_token(use_auth_token)
|
||||||
|
|||||||
+2
-1
@@ -13,7 +13,8 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
|||||||
|
|
||||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe'
|
||||||
|
PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
else 'pyannote/speaker-diarization-3.1'
|
else 'pyannote/speaker-diarization-3.1'
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user