removed kwargs confusion

This commit is contained in:
Jaikinator
2023-06-30 18:44:10 +02:00
parent cd35ad8903
commit 38d1f8f668
+7 -16
View File
@@ -1,6 +1,7 @@
import os
from whisper import Whisper, load_model
from typing import TypeVar , Union
from typing import TypeVar , Union , Optional
import torch
from glob import glob
from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper')
@@ -55,9 +56,10 @@ class Transcriber:
@classmethod
def load_model(cls,
model: str = "medium",
local : bool = True,
download_root: str = WHISPER_DEFAULT_PATH,
*args, **kwargs) -> 'Transcriber':
device: Optional[Union[str, torch.device]] = None,
in_memory: bool = False,
) -> 'Transcriber':
"""
Load whisper module
@@ -92,20 +94,9 @@ class Transcriber:
Whisper Object
"""
if local:
available_models = [os.path.basename(x) for x in
glob(os.path.join(download_root, "*"))]
for i, module in enumerate(available_models):
available_models[i] = module.split(".")[0]
if model not in available_models:
raise RuntimeError("Model not found. Consider downloading the "/
"model first. By deactivating the local flag, " /
"the model will be downloaded automatically.")
_model = load_model(model, download_root=download_root, *args, **kwargs)
_model = load_model(model, download_root=download_root,
device=device, in_memory=in_memory)
return cls(_model)