diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index 6f00888..9f14886 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -19,8 +19,7 @@ class AutoTranscribe: def __init__(self, whisper_model: Union[bool, str, whisper] = None, dia_model : Union[bool, str, diarisation] = None, - dia_kwargs : dict = {}, - whisper_kwargs : dict = {}) -> None: + **kwargs) -> None: """ AutoTranscribe class @@ -38,16 +37,16 @@ class AutoTranscribe: """ if whisper_model is None: - self.transcriber = Transcriber.load_model("medium", local=True) + self.transcriber = Transcriber.load_model("medium") elif isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs) + self.transcriber = Transcriber.load_model(whisper_model, **kwargs) else: self.transcriber = whisper_model if dia_model is None: self.diariser = Diariser.load_model() elif isinstance(dia_model, str): - self.diariser = Diariser.load_model(dia_model, **dia_kwargs) + self.diariser = Diariser.load_model(dia_model, **kwargs) else: self.diariser = dia_model