diff --git a/autotranscript/autotranscipt.py b/autotranscript/autotranscipt.py index c1225af..cbf2c9d 100644 --- a/autotranscript/autotranscipt.py +++ b/autotranscript/autotranscipt.py @@ -1,13 +1,11 @@ -from audio import AudioProcessor , TorchAudioProcessor - +from audio import AudioProcessor from diarisation import Diariser from transcriber import Transcriber, whisper -from whisper import Whisper from transcript_exporter import Transcript from typing import Union , TypeVar from tqdm import trange -from pprint import pprint import torch + diarisation = TypeVar('diarisation') @@ -35,6 +33,7 @@ class AutoTranscribe: if whisper_model is None: self.transcriber = Transcriber.load_model("medium", local=True) + elif isinstance(whisper_model, str): self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs) else: @@ -55,7 +54,8 @@ class AutoTranscribe: Transcribe audiofile with whisper model and pyannote diarization model :param audiofile: path to audiofile or torch.Tensor - :return: Transcript object + :return: Transcript object which contains the transcript and can be used to + export the transcript to differnt formats. """ audiofile = self.get_audiofile(audiofile) @@ -68,11 +68,13 @@ class AutoTranscribe: print("Starting diarisation.") - diarisation = self.diariser.diarization( dia_audio, + diarisation = self.diariser.diarization(dia_audio, *args , **kwargs) print("Diarisation finished. Starting transcription.") + audiofile.sr = torch.Tensor([audiofile.sr]).to(audiofile.waveform.device) + for i in trange(len(diarisation["segments"]), desc= "Transcribing"): seg = diarisation["segments"][i] @@ -84,12 +86,11 @@ class AutoTranscribe: final_transcript[i] = {"speaker" : diarisation["speakers"][i], "text" : transcript} - pprint(final_transcript) - #return Transcript(transcript, diarisation) + return Transcript(transcript, diarisation) @staticmethod def get_audiofile(audiofile : Union[str, torch.Tensor], - *args, **kwargs) -> TorchAudioProcessor: + *args, **kwargs) -> AudioProcessor: """ Get audiofile as TorchAudioProcessor @@ -99,22 +100,15 @@ class AutoTranscribe: waveform and sample_rate in torch.Tensor format. :rtype: TorchAudioProcessor """ + if isinstance(audiofile, str): - try: - audiofile = TorchAudioProcessor.from_file(audiofile) - except: - print("Could not load audiofile with torch audio." \ - "Trying ffmpeg. using pydub.") - audiofile = TorchAudioProcessor.from_ffmpeg(audiofile) + audiofile = AudioProcessor.from_file(audiofile) if isinstance(audiofile, torch.Tensor): - audiofile = TorchAudioProcessor(audiofile[0], audiofile[1]) + audiofile = AudioProcessor(audiofile[0], audiofile[1]) - if isinstance(audiofile, AudioProcessor): - audiofile = TorchAudioProcessor.from_audio_processor(audiofile) - - if not isinstance(audiofile, TorchAudioProcessor): - raise ValueError(f'Audiofile must be of type TorchAudioProcessor,' \ + if not isinstance(audiofile, AudioProcessor): + raise ValueError(f'Audiofile must be of type AudioProcessor,' \ f'not {type(audiofile)}') return audiofile @@ -122,4 +116,4 @@ class AutoTranscribe: if __name__ == "__main__": AudioTranscriber = AutoTranscribe() - AudioTranscriber.transcribe("/home/jacob/PycharmProjects/autotranscript/tests/Kathi_interview.mp3" , num_speaker=2) \ No newline at end of file + AudioTranscriber.transcribe("tests/test.wav") \ No newline at end of file