Add default path to pyannote model with fallback option.

This commit is contained in:
Marko Henning
2024-04-19 17:36:34 +02:00
parent b075271b89
commit f7927fd524
2 changed files with 58 additions and 45 deletions
+56 -44
View File
@@ -19,6 +19,7 @@ Constants:
- TOKEN_PATH (str): Path to the Pyannote token. - TOKEN_PATH (str): Path to the Pyannote token.
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models. - PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models. - PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work.
Usage: Usage:
from .diarisation import Diariser from .diarisation import Diariser
@@ -39,7 +40,7 @@ from torch import Tensor
from torch import device as torch_device from torch import device as torch_device
from torch.cuda import is_available, current_device from torch.cuda import is_available, current_device
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
@@ -183,7 +184,7 @@ class Diariser:
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG, model: str = PYANNOTE_FALLBACK_CONFIG,
use_auth_token: str = None, use_auth_token: str = None,
cache_token: bool = True, cache_token: bool = True,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
@@ -210,53 +211,64 @@ class Diariser:
Returns: Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
""" """
try:
hf_model = PYANNOTE_DEFAULT_CONFIG
# if not use_auth_token:
# use_auth_token = cls._get_token()
_model = Pipeline.from_pretrained(
hf_model, use_auth_token=use_auth_token,
cache_dir=cache_dir, hparams_file=hparams_file
)
except:
print(f'Trying fallback to config file.. ')
_model = None
if _model is None:
if cache_token and use_auth_token is not None:
if cache_token and use_auth_token is not None: cls._save_token(use_auth_token)
cls._save_token(use_auth_token)
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
elif os.path.exists(model) and not use_auth_token:
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)
path_to_model = config['pipeline']['params']['segmentation']
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
pwd = model.split("/")[:-1] if not os.path.exists(model) and use_auth_token is None:
pwd = "/".join(pwd) use_auth_token = cls._get_token()
path_to_model = os.path.join(pwd, "pytorch_model.bin") elif os.path.exists(model) and not use_auth_token:
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)
path_to_model = config['pipeline']['params']['segmentation']
if not os.path.exists(path_to_model): if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \ warnings.warn(f"Model not found at {path_to_model}. "\
'Trying to find it nearby .bin files instead.") "Trying to find it nearby the config file.")
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] pwd = model.split("/")[:-1]
if len(bin_files) == 1: pwd = "/".join(pwd)
path_to_model = os.path.join(pwd, bin_files[0])
else: path_to_model = os.path.join(pwd, "pytorch_model.bin")
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \ if not os.path.exists(path_to_model):
"or setup a huggingface token.") warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
warnings.warn(f"Found model at {path_to_model} overwriting config file.") # list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
config['pipeline']['params']['segmentation'] = path_to_model if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
with open(model, 'w') as file: else:
yaml.dump(config, file) warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
_model = Pipeline.from_pretrained(model, "or setup a huggingface token.")
use_auth_token = use_auth_token,
cache_dir = cache_dir, warnings.warn(f"Found model at {path_to_model} overwriting config file.")
hparams_file = hparams_file,)
config['pipeline']['params']['segmentation'] = path_to_model
with open(model, 'w') as file:
yaml.dump(config, file)
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
hparams_file = hparams_file,)
# try to move the model to the device # try to move the model to the device
if device is None: if device is None:
+2 -1
View File
@@ -13,7 +13,8 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe'
PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else 'pyannote/speaker-diarization-3.1' else 'pyannote/speaker-diarization-3.1'