added kwargs support for load model

This commit is contained in:
Jaikinator
2023-06-19 15:56:46 +02:00
parent ae9a125d12
commit 66e73e1c6b
+3 -3
View File
@@ -6,7 +6,6 @@ from glob import glob
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
Tensor = TypeVar('Tensor') Tensor = TypeVar('Tensor')
nparray = TypeVar('nparray') nparray = TypeVar('nparray')
Transcriber = TypeVar('Transcriber')
def get_whisper_default_path() -> str: def get_whisper_default_path() -> str:
""" """
@@ -69,7 +68,8 @@ class Transcriber:
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
local : bool = True, local : bool = True,
download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber: download_root: str = WHISPER_DEFAULT_PATH ,
*args, **kwargs) -> 'Transcriber':
""" """
Load whisper module Load whisper module
@@ -117,7 +117,7 @@ class Transcriber:
"model first. By deactivating the local flag, " / "model first. By deactivating the local flag, " /
"the model will be downloaded automatically.") "the model will be downloaded automatically.")
_model = load_model(model, download_root=download_root) _model = load_model(model, download_root=download_root, *args, **kwargs)
return cls(_model) return cls(_model)