auto transcript
This commit is contained in:
@@ -0,0 +1,125 @@
|
||||
from audio import AudioProcessor , TorchAudioProcessor
|
||||
|
||||
from diarisation import Diariser
|
||||
from transcriber import Transcriber, whisper
|
||||
from whisper import Whisper
|
||||
from transcript_exporter import Transcript
|
||||
from typing import Union , TypeVar
|
||||
from tqdm import trange
|
||||
from pprint import pprint
|
||||
import torch
|
||||
# Placeholder type variable that appears in the annotation of the `dia_model`
# argument of AutoTranscribe.__init__ (presumably standing in for an
# already-loaded diarisation model object -- confirm against Diariser).
diarisation = TypeVar('diarisation')
|
||||
|
||||
|
||||
class AutoTranscribe:
    """
    Core API class of the autotranscript package.

    It transcribes audio files with a whisper model and a pyannote
    diarization model, i.e. it performs a fully automatic transcription
    of audio files including speaker attribution.
    """

    def __init__(self,
                 whisper_model: Union[bool, str, whisper] = None,
                 dia_model: Union[bool, str, diarisation] = None,
                 dia_kwargs: Union[dict, None] = None,
                 whisper_kwargs: Union[dict, None] = None) -> None:
        """
        Load (or adopt) the transcription and diarisation models.

        :param whisper_model: name/path of the whisper model to load, or an
            already constructed transcriber object. ``None`` loads the
            ``"medium"`` model with ``local=True``.
        :param dia_model: path of the pyannote diarization model to load, or
            an already constructed diariser object. ``None`` loads the
            default model of ``Diariser.load_model()``.
        :param dia_kwargs: kwargs forwarded to ``Diariser.load_model`` when
            ``dia_model`` is a string. ``None`` means no extra kwargs.
        :param whisper_kwargs: kwargs forwarded to ``Transcriber.load_model``
            when ``whisper_model`` is a string. ``None`` means no extra
            kwargs.
        """
        # Use None as the default and build fresh dicts here, so that no
        # mutable default object is shared between instances.
        dia_kwargs = {} if dia_kwargs is None else dia_kwargs
        whisper_kwargs = {} if whisper_kwargs is None else whisper_kwargs

        if whisper_model is None:
            self.transcriber = Transcriber.load_model("medium", local=True)
        elif isinstance(whisper_model, str):
            self.transcriber = Transcriber.load_model(whisper_model,
                                                      **whisper_kwargs)
        else:
            # Anything else is used as the transcriber object as-is.
            self.transcriber = whisper_model

        if dia_model is None:
            self.diariser = Diariser.load_model()
        elif isinstance(dia_model, str):
            self.diariser = Diariser.load_model(dia_model, **dia_kwargs)
        else:
            # Anything else is used as the diariser object as-is.
            self.diariser = dia_model

        print("AutoTranscribe initialized all models successfully loaded.")

    def transcribe(self, audiofile: Union[str, torch.Tensor],
                   *args, **kwargs) -> Transcript:
        """
        Transcribe audiofile with whisper model and pyannote diarization model.

        The audio is diarised first; every resulting segment is then cut out
        of the audio and transcribed on its own. The per-segment results
        (``{"speaker": ..., "text": ...}`` keyed by segment index) are
        pretty-printed.

        :param audiofile: path to audiofile or torch.Tensor
        :param args: positional arguments forwarded to BOTH the diariser and
            the transcriber call.
        :param kwargs: keyword arguments forwarded to BOTH the diariser and
            the transcriber call.
        :return: intended to be a Transcript object; at the moment the result
            is only printed and nothing is returned (see TODO below).
        """
        audiofile = self.get_audiofile(audiofile)

        # The diariser receives the waveform with an added leading dimension
        # of size 1 (assumes a 1-D waveform -- TODO confirm).
        dia_audio = {
            "waveform": audiofile.waveform.reshape(1, len(audiofile.waveform)),
            "sample_rate": audiofile.sr,
        }

        print("Starting diarisation.")

        # Named `dia_result` so it does not shadow the module-level
        # `diarisation` TypeVar.
        dia_result = self.diariser.diarization(dia_audio, *args, **kwargs)

        print("Diarisation finished. Starting transcription.")

        segments = dia_result["segments"]
        final_transcript = {}

        for i in trange(len(segments), desc="Transcribing"):
            seg = segments[i]

            # seg[0] / seg[1] are handed to `cut` as the segment boundaries.
            audio = audiofile.cut(seg[0], seg[1])

            # NOTE(review): the same *args/**kwargs that went to the diariser
            # are passed on here as well -- confirm the transcriber accepts
            # them.
            text = self.transcriber.transcribe(audio, *args, **kwargs)

            final_transcript[i] = {"speaker": dia_result["speakers"][i],
                                   "text": text}

        pprint(final_transcript)
        # TODO: build and return a Transcript from the collected segments
        # once the exporter is wired in; until then this method returns None.

    @staticmethod
    def get_audiofile(audiofile: Union[str, torch.Tensor],
                      *args, **kwargs) -> TorchAudioProcessor:
        """
        Get audiofile as TorchAudioProcessor.

        :param audiofile: path to audiofile or torch.Tensor
        :type audiofile: Union[str, torch.Tensor]
        :param args: unused, accepted for call compatibility.
        :param kwargs: unused, accepted for call compatibility.
        :return: object of audiofile containing
            waveform and sample_rate in torch.Tensor format.
        :rtype: TorchAudioProcessor
        :raises ValueError: if the input could not be turned into a
            TorchAudioProcessor.
        """
        if isinstance(audiofile, str):
            try:
                audiofile = TorchAudioProcessor.from_file(audiofile)
            except Exception:
                # Catch Exception rather than using a bare except, so that
                # KeyboardInterrupt / SystemExit are not swallowed by the
                # ffmpeg fallback.
                print("Could not load audiofile with torch audio. "
                      "Trying ffmpeg. using pydub.")
                audiofile = TorchAudioProcessor.from_ffmpeg(audiofile)

        if isinstance(audiofile, torch.Tensor):
            # Elements 0 and 1 of the tensor become the two constructor
            # arguments (presumably waveform and sample rate -- TODO confirm).
            audiofile = TorchAudioProcessor(audiofile[0], audiofile[1])

        if isinstance(audiofile, AudioProcessor):
            audiofile = TorchAudioProcessor.from_audio_processor(audiofile)

        if not isinstance(audiofile, TorchAudioProcessor):
            raise ValueError(f'Audiofile must be of type TorchAudioProcessor, '
                             f'not {type(audiofile)}')
        return audiofile
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run: load the default models and transcribe a local
    # sample recording, passing num_speaker=2 through to the pipeline.
    sample_path = "/home/jacob/PycharmProjects/autotranscript/tests/Kathi_interview.mp3"

    auto_transcriber = AutoTranscribe()
    auto_transcriber.transcribe(sample_path, num_speaker=2)
|
||||
Reference in New Issue
Block a user