diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index 82156cf..0cd42bf 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -1,6 +1,7 @@ import os from whisper import Whisper, load_model -from typing import TypeVar , Union +from typing import TypeVar , Union , Optional +import torch from glob import glob from .misc import WHISPER_DEFAULT_PATH whisper = TypeVar('whisper') @@ -17,7 +18,7 @@ class Transcriber: """ self.model = model - def transcribe(self, audio : Union[str, Tensor, nparray] , + def transcribe(self, audio : Union[str, Tensor, nparray] , *args, **kwargs) -> str: """ transcribe audio file @@ -55,9 +56,10 @@ class Transcriber: @classmethod def load_model(cls, model: str = "medium", - local : bool = True, download_root: str = WHISPER_DEFAULT_PATH, - *args, **kwargs) -> 'Transcriber': + device: Optional[Union[str, torch.device]] = None, + in_memory: bool = False, + ) -> 'Transcriber': """ Load whisper module @@ -92,20 +94,9 @@ class Transcriber: Whisper Object """ - if local: - - available_models = [os.path.basename(x) for x in - glob(os.path.join(download_root, "*"))] - - for i, module in enumerate(available_models): - available_models[i] = module.split(".")[0] - - if model not in available_models: - raise RuntimeError("Model not found. Consider downloading the "/ - "model first. By deactivating the local flag, " / - "the model will be downloaded automatically.") - _model = load_model(model, download_root=download_root, *args, **kwargs) + _model = load_model(model, download_root=download_root, + device=device, in_memory=in_memory) return cls(_model)