diff --git a/autotranscript/autotranscipt.py b/autotranscript/autotranscipt.py new file mode 100644 index 0000000..c1225af --- /dev/null +++ b/autotranscript/autotranscipt.py @@ -0,0 +1,125 @@ +from audio import AudioProcessor , TorchAudioProcessor + +from diarisation import Diariser +from transcriber import Transcriber, whisper +from whisper import Whisper +from transcript_exporter import Transcript +from typing import Union , TypeVar +from tqdm import trange +from pprint import pprint +import torch +diarisation = TypeVar('diarisation') + + +class AutoTranscribe: + def __init__(self, + whisper_model: Union[bool, str, whisper] = None, + dia_model : Union[bool, str, diarisation] = None, + dia_kwargs : dict = {}, + whisper_kwargs : dict = {}) -> None: + """ + AutoTranscribe class + + This class is the core Api Class of the autotranscript package. + It allows to transcribe audio files with a whisper model and + pyannote diarization model. + + Therefore it is do a fully automatic transcription of audio files. + + :param whisper_model: path to whisper model or whisper model + :param dia_model: path to pyannote diarization model + :param dia_kwargs: kwargs for pyannote diarization model + :param whisper_kwargs: kwargs for whisper model + + """ + + if whisper_model is None: + self.transcriber = Transcriber.load_model("medium", local=True) + elif isinstance(whisper_model, str): + self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs) + else: + self.transcriber = whisper_model + + if dia_model is None: + self.diariser = Diariser.load_model() + elif isinstance(dia_model, str): + self.diariser = Diariser.load_model(dia_model, **dia_kwargs) + else: + self.diariser = dia_model + + print("AutoTranscribe initialized all models successfully loaded.") + + def transcribe(self, audiofile : Union[str, torch.Tensor], + *args, **kwargs) -> Transcript: + """ + Transcribe audiofile with whisper model and pyannote diarization model + + :param audiofile: path to audiofile or torch.Tensor + :return: Transcript object + """ + + audiofile = self.get_audiofile(audiofile) + + final_transcript = dict() + + dia_audio = {"waveform" : + audiofile.waveform.reshape(1,len(audiofile.waveform)), + "sample_rate": audiofile.sr} + + print("Starting diarisation.") + + diarisation = self.diariser.diarization( dia_audio, + *args , **kwargs) + + print("Diarisation finished. Starting transcription.") + + for i in trange(len(diarisation["segments"]), desc= "Transcribing"): + + seg = diarisation["segments"][i] + + audio = audiofile.cut(seg[0], seg[1]) + + transcript = self.transcriber.transcribe(audio, *args , **kwargs) + + final_transcript[i] = {"speaker" : diarisation["speakers"][i], + "text" : transcript} + + pprint(final_transcript) + #return Transcript(transcript, diarisation) + + @staticmethod + def get_audiofile(audiofile : Union[str, torch.Tensor], + *args, **kwargs) -> TorchAudioProcessor: + """ + Get audiofile as TorchAudioProcessor + + :param audiofile: path to audiofile or torch.Tensor + :type audiofile: Union[str, torch.Tensor] + :return: object of audiofile containes + waveform and sample_rate in torch.Tensor format. + :rtype: TorchAudioProcessor + """ + if isinstance(audiofile, str): + try: + audiofile = TorchAudioProcessor.from_file(audiofile) + except: + print("Could not load audiofile with torch audio." \ + "Trying ffmpeg. using pydub.") + audiofile = TorchAudioProcessor.from_ffmpeg(audiofile) + + if isinstance(audiofile, torch.Tensor): + audiofile = TorchAudioProcessor(audiofile[0], audiofile[1]) + + if isinstance(audiofile, AudioProcessor): + audiofile = TorchAudioProcessor.from_audio_processor(audiofile) + + if not isinstance(audiofile, TorchAudioProcessor): + raise ValueError(f'Audiofile must be of type TorchAudioProcessor,' \ + f'not {type(audiofile)}') + return audiofile + + +if __name__ == "__main__": + + AudioTranscriber = AutoTranscribe() + AudioTranscriber.transcribe("/home/jacob/PycharmProjects/autotranscript/tests/Kathi_interview.mp3" , num_speaker=2) \ No newline at end of file