diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index ea36b93..1c2e4fb 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -4,14 +4,16 @@ from torch import Tensor import os from typing import TypeVar, Union import json - +from .misc import PYANNOTE_DEFAULT_PATH Annotation = TypeVar('Annotation') -PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), - "models", "pyannote", - "speaker_diarization", "config.yaml") - class Diariser: + """ + Diarisation class + This class is used to diarize an audio file using a pretrained model + from pyannote.audio. + :param model: model to use for diarization + """ def __init__(self, model,*args,**kwargs) -> None: self.model = model @@ -137,10 +139,11 @@ class Diariser: ------- Pipeline Object """ - + if local: diarization_model = Pipeline.from_pretrained(model,*args, **kwargs) else: + print("Loading model from HuggingFace") if token == "": token = cls._get_token() diarization_model = Pipeline.from_pretrained(model, use_auth_token = token, diff --git a/autotranscript/misc.py b/autotranscript/misc.py index 065e45d..716852e 100644 --- a/autotranscript/misc.py +++ b/autotranscript/misc.py @@ -4,83 +4,22 @@ from whisper import Whisper, load_model import os import glob from warnings import warn +import yaml -WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), +WHISPER_DEFAULT_PATH = os.path.join(os.path.dirname(__file__), "models", "whisper") -PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), +PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(__file__), "models", "pyannote", "speaker_diarization", "config.yaml") -def load_whisper_model(model: str ="medium", local : bool = False, download_root: str = WHISPER_DEFAULT_PATH) -> Whisper: + +def config_diarization_yaml(file): """ - Load modules from whisper - - Parameters - ---------- - whisper : str - whisper model - available models: - - - 
'tiny.en' - - 'tiny' - - 'base.en' - - 'base' - - 'small.en' - - 'small' - - 'medium.en' - - 'medium' - - 'large-v1' - - 'large-v2' - - 'large' - - local : bool - If true, load from local cache - - download_root : str - Path to download the model - - default: /models/whisper + Configure diarization pipeline from yaml file to use the model offline + and avoid manual file manipulation. - Returns - ------- - Whisper Object + :param file: yaml file + :type file: yaml """ - warn("load_whisper_model is deprecated. Use Transcriptor.load_model() instead.", DeprecationWarning) - if local: - available_models = [os.path.basename(x) for x in glob.glob(os.path.join(download_root, "*"))] - - for i, module in enumerate(available_models): - available_models[i] = module.split(".")[0] - - if model not in available_models: - raise RuntimeError("Model not found. Consider downloading the model first. By deactivating the local flag, the model will be downloaded automatically.") - - return load_model(model, download_root=download_root) - -def load_pyannote_model(model: str = PYANNOTE_DEFAULT_PATH, - token: str = "", - local : bool = True, - *args, **kwargs) -> Pipeline: - """ - Load modules from pyannote - - Parameters - ---------- - model : str - pyannote model - default: /models/pyannote/speaker_diarization/config.yaml - token : str - HUGGINGFACE_TOKEN - local : bool - If true, load from local cache - - Returns - ------- - Pipeline Object - """ - warn("load_pyannote_model is deprecated. 
Use Diarisation.load_model() instead.", DeprecationWarning) - if local: - return Pipeline.from_pretrained(model,*args, **kwargs) - else: - return Pipeline.from_pretrained(model, use_auth_token = token, *args, **kwargs) + \ No newline at end of file diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index 39c0842..82156cf 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -2,24 +2,12 @@ import os from whisper import Whisper, load_model from typing import TypeVar , Union from glob import glob - +from .misc import WHISPER_DEFAULT_PATH whisper = TypeVar('whisper') Tensor = TypeVar('Tensor') nparray = TypeVar('nparray') -def get_whisper_default_path() -> str: - """ - Get default path for whisper models - Returns - ------- - str - path - """ - _path = os.path.dirname(os.path.dirname(__file__)) - return os.path.join(_path, "models", "whisper") - -WHISPER_DEFAULT_PATH = get_whisper_default_path() class Transcriber: def __init__(self, model: whisper ) -> None: @@ -68,7 +56,7 @@ class Transcriber: def load_model(cls, model: str = "medium", local : bool = True, - download_root: str = WHISPER_DEFAULT_PATH , + download_root: str = WHISPER_DEFAULT_PATH, *args, **kwargs) -> 'Transcriber': """ Load whisper module