Reworking the hf wrapper, now without blank except block (wow)!
This commit is contained in:
+60
-58
@@ -19,7 +19,6 @@ Constants:
|
||||
- TOKEN_PATH (str): Path to the Pyannote token.
|
||||
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
|
||||
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
|
||||
- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work.
|
||||
|
||||
Usage:
|
||||
from .diarisation import Diariser
|
||||
@@ -39,8 +38,10 @@ from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
from torch import device as torch_device
|
||||
from torch.cuda import is_available, current_device
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
@@ -184,7 +185,7 @@ class Diariser:
|
||||
|
||||
@classmethod
|
||||
def load_model(cls,
|
||||
model: str = PYANNOTE_FALLBACK_CONFIG,
|
||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||
use_auth_token: str = None,
|
||||
cache_token: bool = True,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
@@ -211,64 +212,65 @@ class Diariser:
|
||||
Returns:
|
||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||
"""
|
||||
try:
|
||||
hf_model = PYANNOTE_DEFAULT_CONFIG
|
||||
# if not use_auth_token:
|
||||
# use_auth_token = cls._get_token()
|
||||
_model = Pipeline.from_pretrained(
|
||||
hf_model, use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir, hparams_file=hparams_file
|
||||
)
|
||||
except:
|
||||
print(f'Trying fallback to config file.. ')
|
||||
_model = None
|
||||
if _model is None:
|
||||
|
||||
if cache_token and use_auth_token is not None:
|
||||
cls._save_token(use_auth_token)
|
||||
|
||||
if not os.path.exists(model) and use_auth_token is None:
|
||||
use_auth_token = cls._get_token()
|
||||
|
||||
elif os.path.exists(model) and not use_auth_token:
|
||||
# check if model can be found locally nearby the config file
|
||||
with open(model, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
path_to_model = config['pipeline']['params']['segmentation']
|
||||
if isinstance(model, str) and os.path.exists(model):
|
||||
# check if model can be found locally nearby the config file
|
||||
with open(model, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
path_to_model = config['pipeline']['params']['segmentation']
|
||||
|
||||
if not os.path.exists(path_to_model):
|
||||
warnings.warn(f"Model not found at {path_to_model}. "
|
||||
"Trying to find it nearby the config file.")
|
||||
|
||||
pwd = model.split("/")[:-1]
|
||||
pwd = "/".join(pwd)
|
||||
|
||||
path_to_model = os.path.join(pwd, "pytorch_model.bin")
|
||||
|
||||
if not os.path.exists(path_to_model):
|
||||
warnings.warn(f"Model not found at {path_to_model}. "\
|
||||
"Trying to find it nearby the config file.")
|
||||
|
||||
pwd = model.split("/")[:-1]
|
||||
pwd = "/".join(pwd)
|
||||
|
||||
path_to_model = os.path.join(pwd, "pytorch_model.bin")
|
||||
warnings.warn(f"Model not found at {path_to_model}. \
|
||||
'Trying to find it nearby .bin files instead.")
|
||||
# list elementes with the ending .bin
|
||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
||||
if len(bin_files) == 1:
|
||||
path_to_model = os.path.join(pwd, bin_files[0])
|
||||
else:
|
||||
warnings.warn("Found more than one .bin file. "\
|
||||
"or none. Please specify the path to the model " \
|
||||
"or setup a huggingface token.")
|
||||
raise FileNotFoundError
|
||||
|
||||
if not os.path.exists(path_to_model):
|
||||
warnings.warn(f"Model not found at {path_to_model}. \
|
||||
'Trying to find it nearby .bin files instead.")
|
||||
# list elementes with the ending .bin
|
||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
||||
if len(bin_files) == 1:
|
||||
path_to_model = os.path.join(pwd, bin_files[0])
|
||||
else:
|
||||
warnings.warn("Found more than one .bin file. "\
|
||||
"or none. Please specify the path to the model " \
|
||||
"or setup a huggingface token.")
|
||||
|
||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
||||
|
||||
config['pipeline']['params']['segmentation'] = path_to_model
|
||||
|
||||
with open(model, 'w') as file:
|
||||
yaml.dump(config, file)
|
||||
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token = use_auth_token,
|
||||
cache_dir = cache_dir,
|
||||
hparams_file = hparams_file,)
|
||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
||||
|
||||
config['pipeline']['params']['segmentation'] = path_to_model
|
||||
|
||||
with open(model, 'w') as file:
|
||||
yaml.dump(config, file)
|
||||
elif isinstance(model, tuple):
|
||||
try:
|
||||
_model = model[0]
|
||||
HfApi().model_info(_model)
|
||||
model = _model
|
||||
use_auth_token = None
|
||||
except RepositoryNotFoundError:
|
||||
print(f'{model[0]} not found on Huggingface, \
|
||||
trying {model[1]}')
|
||||
_model = model[1]
|
||||
HfApi().model_info(_model)
|
||||
model = _model
|
||||
if cache_token and use_auth_token is not None:
|
||||
cls._save_token(use_auth_token)
|
||||
|
||||
if not os.path.exists(model) and use_auth_token is None:
|
||||
use_auth_token = cls._get_token()
|
||||
else:
|
||||
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
||||
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
hparams_file=hparams_file,)
|
||||
|
||||
# try to move the model to the device
|
||||
if device is None:
|
||||
|
||||
Reference in New Issue
Block a user