removed kwargs confusion

This commit is contained in:
Jaikinator
2023-06-30 18:44:10 +02:00
parent cd35ad8903
commit 38d1f8f668
+8 -17
View File
@@ -1,6 +1,7 @@
import os import os
from whisper import Whisper, load_model from whisper import Whisper, load_model
from typing import TypeVar , Union from typing import TypeVar , Union , Optional
import torch
from glob import glob from glob import glob
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
@@ -17,7 +18,7 @@ class Transcriber:
""" """
self.model = model self.model = model
def transcribe(self, audio : Union[str, Tensor, nparray] , def transcribe(self, audio : Union[str, Tensor, nparray] ,
*args, **kwargs) -> str: *args, **kwargs) -> str:
""" """
transcribe audio file transcribe audio file
@@ -55,9 +56,10 @@ class Transcriber:
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
local : bool = True,
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
*args, **kwargs) -> 'Transcriber': device: Optional[Union[str, torch.device]] = None,
in_memory: bool = False,
) -> 'Transcriber':
""" """
Load whisper module Load whisper module
@@ -92,20 +94,9 @@ class Transcriber:
Whisper Object Whisper Object
""" """
if local:
available_models = [os.path.basename(x) for x in _model = load_model(model, download_root=download_root,
glob(os.path.join(download_root, "*"))] device=device, in_memory=in_memory)
for i, module in enumerate(available_models):
available_models[i] = module.split(".")[0]
if model not in available_models:
raise RuntimeError("Model not found. Consider downloading the "/
"model first. By deactivating the local flag, " /
"the model will be downloaded automatically.")
_model = load_model(model, download_root=download_root, *args, **kwargs)
return cls(_model) return cls(_model)