diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py new file mode 100644 index 0000000..a3927f1 --- /dev/null +++ b/autotranscript/transcriber.py @@ -0,0 +1,112 @@ + +import os +from typing import TypeVar +from whisper import load_model +from glob import glob + +whisper = TypeVar('whisper') +Transcriber = TypeVar('Transcriber') + +def get_whisper_default_path() -> str: + """ + Get default path for whisper models + + Returns + ------- + str + path + """ + _path = os.path.dirname(os.path.dirname(__file__)) + return os.path.join(_path, "models", "whisper") + +WHISPER_DEFAULT_PATH = get_whisper_default_path() + +class Transcriber: + def __init__(self, model: whisper ) -> None: + """ + Initialize Transcriber class with a whisper model + :param model: whisper model + """ + self.model = model + + + def transcribe(self, file : str, language:str = "German"): + """ + transcribe audio file + :param file: audio file to transcribe + :param language: language of the audio file + :return: transcript as string + """ + result = self.model.transcribe(file, language = language) + + return result["text"] + + @staticmethod + def save_transcript(transcript:str , save_path : str) -> None: + """ + Save transcript to file + :param transcript: transcript as string + :param savepath: path to save the transcript + :return: None + """ + + with open(save_path, 'w') as f: + f.write(transcript) + f.close() + + print(f'Transcript saved to {save_path}') + + @classmethod + def load_whisper_model(cls, + model: str = "medium", + local : bool = True, + download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber: + """ + Load whisper module + + Parameters + ---------- + whisper : str + whisper model + available models: + + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large' + + local : bool + If true, load from local cache + + download_root : str + Path to download the model + + default: /models/whisper + + Returns + ------- + Whisper Object + """ + + if local: + + available_models = [os.path.basename(x) for x in glob(os.path.join(download_root, "*"))] + + for i, module in enumerate(available_models): + available_models[i] = module.split(".")[0] + + if model not in available_models: + raise RuntimeError("Model not found. Consider downloading the "/ + "model first. By deactivating the local flag, " / + "the model will be downloaded automatically.") + + _model = load_model(model, download_root=download_root) + + return cls(_model)