diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index a3927f1..069866a 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -1,10 +1,12 @@ import os -from typing import TypeVar +from typing import TypeVar , Union from whisper import load_model from glob import glob whisper = TypeVar('whisper') +Tensor = TypeVar('Tensor') +nparray = TypeVar('nparray') Transcriber = TypeVar('Transcriber') def get_whisper_default_path() -> str: @@ -29,20 +31,24 @@ class Transcriber: """ self.model = model - - def transcribe(self, file : str, language:str = "German"): + def transcribe(self, audio : Union[str, Tensor, nparray] , + *args, **kwargs) -> str: """ transcribe audio file :param file: audio file to transcribe - :param language: language of the audio file + :param args: additional arguments + :param kwargs: additional keyword arguments + example: + - language: language of the audio file :return: transcript as string """ - result = self.model.transcribe(file, language = language) + + result = self.model.transcribe(audio, *args, **kwargs) return result["text"] @staticmethod - def save_transcript(transcript:str , save_path : str) -> None: + def save_transcript(transcript : str , save_path : str) -> None: """ Save transcript to file :param transcript: transcript as string @@ -57,10 +63,10 @@ class Transcriber: print(f'Transcript saved to {save_path}') @classmethod - def load_whisper_model(cls, - model: str = "medium", - local : bool = True, - download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber: + def load_model(cls, + model: str = "medium", + local : bool = True, + download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber: """ Load whisper module @@ -97,7 +103,8 @@ class Transcriber: if local: - available_models = [os.path.basename(x) for x in glob(os.path.join(download_root, "*"))] + available_models = [os.path.basename(x) for x in + glob(os.path.join(download_root, "*"))] for i, module in enumerate(available_models): available_models[i] = module.split(".")[0]