Reworking the HF wrapper, now without a bare except block (wow)!

This commit is contained in:
Marko Henning
2024-04-23 14:39:18 +02:00
parent f7927fd524
commit 7d8da3b81c
2 changed files with 62 additions and 61 deletions
+52 -50
View File
@@ -19,7 +19,6 @@ Constants:
- TOKEN_PATH (str): Path to the Pyannote token.
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
- PYANNOTE_FALLBACK_CONFIG (str): Fallback config for Pyannote models if default config does not work.
Usage:
from .diarisation import Diariser
@@ -39,8 +38,10 @@ from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from torch import device as torch_device
from torch.cuda import is_available, current_device
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, PYANNOTE_FALLBACK_CONFIG
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname(
@@ -184,7 +185,7 @@ class Diariser:
@classmethod
def load_model(cls,
model: str = PYANNOTE_FALLBACK_CONFIG,
model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None,
cache_token: bool = True,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
@@ -211,64 +212,65 @@ class Diariser:
Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""
try:
hf_model = PYANNOTE_DEFAULT_CONFIG
# if not use_auth_token:
# use_auth_token = cls._get_token()
_model = Pipeline.from_pretrained(
hf_model, use_auth_token=use_auth_token,
cache_dir=cache_dir, hparams_file=hparams_file
)
except:
print(f'Trying fallback to config file.. ')
_model = None
if _model is None:
if isinstance(model, str) and os.path.exists(model):
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)
path_to_model = config['pipeline']['params']['segmentation']
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "
"Trying to find it nearby the config file.")
elif os.path.exists(model) and not use_auth_token:
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)
pwd = model.split("/")[:-1]
pwd = "/".join(pwd)
path_to_model = config['pipeline']['params']['segmentation']
path_to_model = os.path.join(pwd, "pytorch_model.bin")
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
# list the files ending in .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
else:
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")
raise FileNotFoundError
pwd = model.split("/")[:-1]
pwd = "/".join(pwd)
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
path_to_model = os.path.join(pwd, "pytorch_model.bin")
config['pipeline']['params']['segmentation'] = path_to_model
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
# list the files ending in .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
else:
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")
with open(model, 'w') as file:
yaml.dump(config, file)
elif isinstance(model, tuple):
try:
_model = model[0]
HfApi().model_info(_model)
model = _model
use_auth_token = None
except RepositoryNotFoundError:
print(f'{model[0]} not found on Huggingface, \
trying {model[1]}')
_model = model[1]
HfApi().model_info(_model)
model = _model
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
else:
raise FileNotFoundError(f'No local model or directory found at {model}.')
config['pipeline']['params']['segmentation'] = path_to_model
with open(model, 'w') as file:
yaml.dump(config, file)
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
hparams_file = hparams_file,)
_model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
hparams_file=hparams_file,)
# try to move the model to the device
if device is None:
+2 -3
View File
@@ -13,10 +13,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = 'jaikinator/scraibe'
PYANNOTE_FALLBACK_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else 'pyannote/speaker-diarization-3.1'
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.