autotrancript works

This commit is contained in:
Jaikinator
2023-06-16 12:09:18 +02:00
parent 8ecc66cf29
commit 29e8a229dc
+16 -22
View File
@@ -1,13 +1,11 @@
from audio import AudioProcessor , TorchAudioProcessor from audio import AudioProcessor
from diarisation import Diariser from diarisation import Diariser
from transcriber import Transcriber, whisper from transcriber import Transcriber, whisper
from whisper import Whisper
from transcript_exporter import Transcript from transcript_exporter import Transcript
from typing import Union , TypeVar from typing import Union , TypeVar
from tqdm import trange from tqdm import trange
from pprint import pprint
import torch import torch
diarisation = TypeVar('diarisation') diarisation = TypeVar('diarisation')
@@ -35,6 +33,7 @@ class AutoTranscribe:
if whisper_model is None: if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", local=True) self.transcriber = Transcriber.load_model("medium", local=True)
elif isinstance(whisper_model, str): elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs) self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs)
else: else:
@@ -55,7 +54,8 @@ class AutoTranscribe:
Transcribe audiofile with whisper model and pyannote diarization model Transcribe audiofile with whisper model and pyannote diarization model
:param audiofile: path to audiofile or torch.Tensor :param audiofile: path to audiofile or torch.Tensor
:return: Transcript object :return: Transcript object which contains the transcript and can be used to
export the transcript to differnt formats.
""" """
audiofile = self.get_audiofile(audiofile) audiofile = self.get_audiofile(audiofile)
@@ -68,11 +68,13 @@ class AutoTranscribe:
print("Starting diarisation.") print("Starting diarisation.")
diarisation = self.diariser.diarization( dia_audio, diarisation = self.diariser.diarization(dia_audio,
*args , **kwargs) *args , **kwargs)
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audiofile.sr = torch.Tensor([audiofile.sr]).to(audiofile.waveform.device)
for i in trange(len(diarisation["segments"]), desc= "Transcribing"): for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
seg = diarisation["segments"][i] seg = diarisation["segments"][i]
@@ -84,12 +86,11 @@ class AutoTranscribe:
final_transcript[i] = {"speaker" : diarisation["speakers"][i], final_transcript[i] = {"speaker" : diarisation["speakers"][i],
"text" : transcript} "text" : transcript}
pprint(final_transcript) return Transcript(transcript, diarisation)
#return Transcript(transcript, diarisation)
@staticmethod @staticmethod
def get_audiofile(audiofile : Union[str, torch.Tensor], def get_audiofile(audiofile : Union[str, torch.Tensor],
*args, **kwargs) -> TorchAudioProcessor: *args, **kwargs) -> AudioProcessor:
""" """
Get audiofile as TorchAudioProcessor Get audiofile as TorchAudioProcessor
@@ -99,22 +100,15 @@ class AutoTranscribe:
waveform and sample_rate in torch.Tensor format. waveform and sample_rate in torch.Tensor format.
:rtype: TorchAudioProcessor :rtype: TorchAudioProcessor
""" """
if isinstance(audiofile, str): if isinstance(audiofile, str):
try: audiofile = AudioProcessor.from_file(audiofile)
audiofile = TorchAudioProcessor.from_file(audiofile)
except:
print("Could not load audiofile with torch audio." \
"Trying ffmpeg. using pydub.")
audiofile = TorchAudioProcessor.from_ffmpeg(audiofile)
if isinstance(audiofile, torch.Tensor): if isinstance(audiofile, torch.Tensor):
audiofile = TorchAudioProcessor(audiofile[0], audiofile[1]) audiofile = AudioProcessor(audiofile[0], audiofile[1])
if isinstance(audiofile, AudioProcessor): if not isinstance(audiofile, AudioProcessor):
audiofile = TorchAudioProcessor.from_audio_processor(audiofile) raise ValueError(f'Audiofile must be of type AudioProcessor,' \
if not isinstance(audiofile, TorchAudioProcessor):
raise ValueError(f'Audiofile must be of type TorchAudioProcessor,' \
f'not {type(audiofile)}') f'not {type(audiofile)}')
return audiofile return audiofile
@@ -122,4 +116,4 @@ class AutoTranscribe:
if __name__ == "__main__": if __name__ == "__main__":
AudioTranscriber = AutoTranscribe() AudioTranscriber = AutoTranscribe()
AudioTranscriber.transcribe("/home/jacob/PycharmProjects/autotranscript/tests/Kathi_interview.mp3" , num_speaker=2) AudioTranscriber.transcribe("tests/test.wav")