removed kwargs confusion
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from whisper import Whisper, load_model
|
from whisper import Whisper, load_model
|
||||||
from typing import TypeVar , Union
|
from typing import TypeVar , Union , Optional
|
||||||
|
import torch
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from .misc import WHISPER_DEFAULT_PATH
|
from .misc import WHISPER_DEFAULT_PATH
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
@@ -55,9 +56,10 @@ class Transcriber:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
local : bool = True,
|
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
*args, **kwargs) -> 'Transcriber':
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
in_memory: bool = False,
|
||||||
|
) -> 'Transcriber':
|
||||||
"""
|
"""
|
||||||
Load whisper module
|
Load whisper module
|
||||||
|
|
||||||
@@ -92,20 +94,9 @@ class Transcriber:
|
|||||||
Whisper Object
|
Whisper Object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if local:
|
|
||||||
|
|
||||||
available_models = [os.path.basename(x) for x in
|
_model = load_model(model, download_root=download_root,
|
||||||
glob(os.path.join(download_root, "*"))]
|
device=device, in_memory=in_memory)
|
||||||
|
|
||||||
for i, module in enumerate(available_models):
|
|
||||||
available_models[i] = module.split(".")[0]
|
|
||||||
|
|
||||||
if model not in available_models:
|
|
||||||
raise RuntimeError("Model not found. Consider downloading the "/
|
|
||||||
"model first. By deactivating the local flag, " /
|
|
||||||
"the model will be downloaded automatically.")
|
|
||||||
|
|
||||||
_model = load_model(model, download_root=download_root, *args, **kwargs)
|
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user