removed kwargs confusion
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
from whisper import Whisper, load_model
|
||||
from typing import TypeVar , Union
|
||||
from typing import TypeVar , Union , Optional
|
||||
import torch
|
||||
from glob import glob
|
||||
from .misc import WHISPER_DEFAULT_PATH
|
||||
whisper = TypeVar('whisper')
|
||||
@@ -55,9 +56,10 @@ class Transcriber:
|
||||
@classmethod
|
||||
def load_model(cls,
|
||||
model: str = "medium",
|
||||
local : bool = True,
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
*args, **kwargs) -> 'Transcriber':
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
in_memory: bool = False,
|
||||
) -> 'Transcriber':
|
||||
"""
|
||||
Load whisper module
|
||||
|
||||
@@ -92,20 +94,9 @@ class Transcriber:
|
||||
Whisper Object
|
||||
"""
|
||||
|
||||
if local:
|
||||
|
||||
available_models = [os.path.basename(x) for x in
|
||||
glob(os.path.join(download_root, "*"))]
|
||||
|
||||
for i, module in enumerate(available_models):
|
||||
available_models[i] = module.split(".")[0]
|
||||
|
||||
if model not in available_models:
|
||||
raise RuntimeError("Model not found. Consider downloading the "/
|
||||
"model first. By deactivating the local flag, " /
|
||||
"the model will be downloaded automatically.")
|
||||
|
||||
_model = load_model(model, download_root=download_root, *args, **kwargs)
|
||||
_model = load_model(model, download_root=download_root,
|
||||
device=device, in_memory=in_memory)
|
||||
|
||||
return cls(_model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user