From 465700e88ba2bfffcae721c0b43ba8f5f960d8e7 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Thu, 29 Dec 2022 15:56:03 +0100 Subject: [PATCH] added diarisation option --- autotranscript/__main__.py | 560 ++++++++++++++++++++++++++++--------- transcribe.py | 5 +- 2 files changed, 426 insertions(+), 139 deletions(-) diff --git a/autotranscript/__main__.py b/autotranscript/__main__.py index 1b817d6..0baf46d 100644 --- a/autotranscript/__main__.py +++ b/autotranscript/__main__.py @@ -2,61 +2,445 @@ import whisper from time import time, sleep import os +import glob +import re +import shutil from typing import Union from pydub import AudioSegment -class Transcribe: - def __init__(self, audiofile : Union[bool, str, list] = None, model : str = "medium", language :str = "German"): - """ - Class to autotranscript audio and video files with the Whisper model - :param audiofile: audio file or list of audio files - :param model: model to use for transcription - :param language: language of the audio file - """ +from pyannote.audio import Pipeline - self.audiofile = audiofile +class AudioProcessor: + def __init__(self, audio_file:str): + self.audio_file_path = audio_file + self.audio_file = AudioSegment.from_file(audio_file, format=audio_file.split('.')[-1]) + self.audiofilename = audio_file.split('/')[-1][:-4] + self.coreaudiofile = audio_file.split('/')[-1][:-4] + self.audiofilefolder = os.path.dirname(audio_file) + self.audio_file_type = audio_file.split('.')[-1] + + + + def convert_audio(self, savefolder: str = "", savename: str = "", type: str = "wav", remove_orginal: bool = True): + """ + Convert video file or other audio files to mp3 file, ensures that the audio file is in the correct format for the + Whisper model + :param file: path to audio or video file + :param remove_orginal: remove original file + :return: mp3 file path + """ + print(f'Converting {self.audiofilename} to .{type} file') + + if savefolder == "": + savefolder = self.audiofilefolder + + if savename == "": + savename = self.coreaudiofile + f'.{type}' + else: + savename = savename + f'.{type}' + print(savefolder, savename) + savepath = os.path.join(savefolder, savename) + + self.audio_file.export(savepath, format=type) + + print(f'Converted {self.audiofilename} to {type}') + + if remove_orginal: + os.remove(self.audio_file_path) + print(f'File {self.audio_file_path} removed') + + self.audio_file_path = savepath + self.audio_file = AudioSegment.from_file(savepath, format=type) + + return self + + def to_mp3(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True): + """ + Convert audio file to mp3 file + :param file: audio file + :param remove_orginal: remove original file + :return: mp3 file path + """ + return self.convert_audio(savefolder = savefolder, savename = savename, type="mp3", remove_orginal=remove_orginal) + + def to_wav(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True): + """ + Convert audio file to wav file + :param file: audio file + :param remove_orginal: remove original file + :return: wav file path + """ + return self.convert_audio(savefolder = savefolder, savename = savename,type="wav", remove_orginal=remove_orginal) + + def slower_mp3(self, savefolder: str = "", savename: str = "", speed: float = 0.75, type: str = "mp3"): + """ + Slow down mp3 file + :param file: mp3 file + :param speed: speed + :return: None + """ + if savefolder == "": + savefolder = self.audiofilefolder + else: + savefolder = savefolder + + sound = self.audio_file + slow_sound = sound._spawn(sound.raw_data, overrides={ + "frame_rate": int(sound.frame_rate * speed) + }) + + speedstr = str(speed).replace('.', '') + + file_out = self.coreaudiofile + f'_{speedstr}.{type}' + + save_path = os.path.join(savefolder, file_out) + + slow_sound.export(save_path, format=type) + + return slow_sound + +class WhisperTranscription: + def __init__(self, audio_file: str , model, language: str = "German"): + + self.audio_file = audio_file + self.model = model self.language = language + def transcribe(self, language:str = "German"): """ - Create folder structure + Transcribe audio file + + language: language of the audio file + :return: transcript as string """ - self.currentpath,\ - self.audiopath,\ - self.transcriptionpath,\ - self.audiofiles = self.create_folder_structure() # create folder structure + audiofilename = self.audio_file.split('/')[-1] + print(f'Start transcribing Audio file: {audiofilename}') - print("loading model") - self.model = whisper.load_model(model) # load model - print("model loaded") + _stime = time() + result = self.model.transcribe(self.audio_file, verbose=True, language=self.language) + + print(f'Transcription finished in {time() - _stime} seconds') + + self.transcript = result + + return result["text"] + + def save_transcript(self, transcript:str = "", savefolder : str = "", savename: str = ""): + """ + Save transcript to file + :param transcript: transcript as string + :param savefolder: folder to save transcript + :param savename: name of the transcript file + :return: None + """ + if savefolder == "": + savefolder = os.path.dirname(self.audio_file) + else: + savefolder = savefolder + + if savename == "": + savename = self.audio_file.split('/')[-1][:-4] + '.txt' + else: + savename = savename + + if transcript == "": + transcript = self.transcript["text"] + + savepath = os.path.join(savefolder, savename) + + with open(savepath, 'w') as f: + f.write(transcript) + + print(f'Transcript saved to {savepath}') + +class Diarisation(AudioProcessor): + def __init__(self, audio_file: str, model,**kwargs): + + super().__init__(audio_file=audio_file) + + self.model = model + def diarization(self, *args, **kwargs): - def create_folder_structure(self): + if "num_speakers" in kwargs: + num_speakers = kwargs['num_speakers'] + else: + num_speakers = 2 + + audiofilename = self.coreaudiofile + + print(f'Start diarization of audio file: {self.audiofilename}') + + _stime = time() + + diarization = self.model(self.audio_file_path, num_speakers=num_speakers) + + print(f'Diarization finished in {time() - _stime} seconds') + self.diarization = diarization + + return diarization + + def format_diarization_output(self, *args, **kwargs): + """ + Format diarization output to a list of tuples + :param args: + :param kwargs: + :return: dict with speaker names as keys and list of tuples as values and list of different speakers + """ + + diarization_output = {"speakers": [], "segments": []} + + if not hasattr(self, 'diarization'): + # ensure diarization is run before formatting + self.diarization = self.diarization() + + + for segment, _, speaker in self.diarization.itertracks(yield_label=True): + diarization_output["speakers"].append(speaker) + diarization_output["segments"].append(segment) + + normalized_output = [] + index_start_speaker = 0 + index_end_speaker = 0 + current_speaker = str() + + for i, speaker in enumerate(diarization_output["speakers"]): + print(i, speaker) + if i == 0: + current_speaker = speaker + + if speaker != current_speaker: + print("Speaker change") + + index_end_speaker = i - 1 + + normalized_output.append([index_start_speaker, index_end_speaker, current_speaker]) + + index_start_speaker = i + current_speaker = speaker + + if i == len(diarization_output["speakers"]) - 1: + + index_end_speaker = i + normalized_output.append([index_start_speaker, index_end_speaker, current_speaker]) + + + self.normalized_output = normalized_output + self.diarization_output = diarization_output + + return diarization_output,normalized_output + + def create_temporary_wav(self,savefolder: str = "", savename: str = "", *args, **kwargs): + """ + Create temporary wav file for diarization + :param savefolder: folder to save the temporary wav file + :param savename: name of the temporary wav file prefix + :param audiofile: audio file + :return: temporary wav file + """ + + + if savefolder == "": + folder = '.temp' + if not os.path.exists(folder): + os.makedirs(folder) + else: + folder = savefolder + + folder = os.path.realpath(folder) + + if savename == "": + savename = self.coreaudiofile + '.wav' + else: + savename = savename + + + if not os.path.exists(folder): + os.makedirs(folder) + + if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'): + self.format_diarization_output() + + print("jkvndhjfvndfhjvndfijhvndvijkdvndfjklvndkvjl") + + speaker = set(self.diarization_output["speakers"]) + num_speak_iter = [0 for _ in range(len(speaker))] + + for count, outp in enumerate(self.normalized_output): + start = self.diarization_output["segments"][outp[0]].start + end = self.diarization_output["segments"][outp[1]].end + + print("start: ", start) + print("end: ", end) + + start_milliseconds = start * 1000 + end_milliseconds = end * 1000 + + print("start_milliseconds: ", start_milliseconds) + print("end_milliseconds: ", end_milliseconds) + + print("cut audio") + + cut_audio = self.audio_file[start_milliseconds:end_milliseconds] + + print("save audio") + print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav") + cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav") + + return os.path.realpath(folder) + + def __repr__(self): + return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" + def __str__(self): + return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" + + +class AutoTranscribe: + def __init__(self, audiofile: Union[str, bool, list] = None, + model: str = "medium", + language: str = "German", + diarisation: bool = False, + audioinput: str = "audiofiles", + transcriptionout: str = "transcriptions", + *args, **kwargs): + """ + AutoTranscribe + :param audiofile: audio file or list of audio files to transcribe + :param model: model name (default: medium) + :param language: language (default: German) + :param diarisation: diarisation (default: False) + """ + if audiofile is None: + audiofile = os.listdir(audioinput) # get all audio files in audioinput folder + + self.audiofile = os.path.realpath(audiofile) + self.language = language + self.diarisation = diarisation + if diarisation: + print("Diarisation is enabled") + print("Load Diarisation model") + self.diarisation_model = Pipeline.from_pretrained("pyannote/speaker-diarization", + use_auth_token = self._get_token()) + print("Load Diarisation model done") + + print(f"Load Whisper model {model}") + self.model = whisper.load_model(model) + print(f"Load Whisper model {model} done") + + self.currentpath, \ + self.audiopath, \ + self.transcriptionpath, \ + self.audiofiles = self.create_folder_structure(audioinput, transcriptionout) # create folder structure + + + def transcribe(self, *args, **kwargs): + + if isinstance(self.audiofile, str): + audiolist= [self.audiofile] # convert to list + elif isinstance(self.audiofile, list): + audiolist = self.audiofile + else: + audiolist = self.audiofiles + + print("Start transcribing audio files") + + if not set(audiolist).issubset(set(self.audiofiles)): + raise ValueError(f"Audio file {self.audiofile} not found in {self.audiopath}") + + + for audiofile in audiolist: + _start = time() + if not "/" in audiofile: + audiofile = os.path.join(self.audiopath, audiofile) + + if not self.check_if_allready_transcribed(audiofile): + + audio = AudioProcessor(audiofile) + + if not audiofile.endswith('wav'): + audio = audio.to_wav() + self.audiofile = audio.audio_file_path + + if "speed" in kwargs: + speed = kwargs['speed'] + print('Creating slower version of the audio file with speed {}'.format(speed)) + audio.slower_mp3(speed=speed) + + if not self.diarisation: + WhisperTranscription(audiofile, self.model, self.language + ).save_transcript(savefolder = self.transcriptionpath) + + else: + print("Start diarisation") + dia = Diarisation(audiofile, self.diarisation_model) + dia.diarization() + temppath = dia.create_temporary_wav() + + for file in sorted(os.listdir(temppath)): + print(file ) + fstring = "\\begin{drama}" \ + "\n\t\Character{F}{Frage}" \ + "\n\t\Character{A1}{Antwort}\n" \ + + files = glob.glob(temppath + "/*.wav") + + # Sort files according to the digits included in the filename + files = sorted(files, key=lambda x: float(re.findall("(\d+)", x)[0])) + + for file in files: + print("Start Whisper") + Whisper = WhisperTranscription(file, self.model, self.language).transcribe() + + if "SPEAKER_00" in file: + fstring += f"\n\Fragespeaks: \n {Whisper}" + + elif "SPEAKER_01" in file: + fstring += f"\n\Antwortspeaks: \n {Whisper}" + + fstring += "\n\end{drama}" + + print(fstring) + + with open(os.path.join(self.transcriptionpath, + os.path.basename(audiofile).split('.')[0] + '.tex'), 'w') as f: + f.write(fstring) + + print("Remove temporary files") + shutil.rmtree(temppath) + + print(f"Transcription of {audiofile} done in total of {time() - _start} seconds") + + def create_folder_structure(self, audiopath: str, transcriptionout: str): """ Create folder structure for audio and transcription files :return: currentpath, audiopath, transcriptionpath, audiofiles """ - currentpath = os.getcwd() # get current path + currentpath = os.getcwd() # get current path - if not os.path.exists(os.path.join(currentpath, 'audiofiles')): + if not os.path.exists(os.path.join(currentpath, audiopath)): print('Creating audiofiles folder') - os.makedirs(os.path.join(currentpath, 'audiofiles')) - if not os.path.exists(os.path.join(currentpath, 'transcription')): + os.makedirs(os.path.join(currentpath, audiopath)) + if not os.path.exists(os.path.join(currentpath, transcriptionout)): print('Creating transcription folder') - os.makedirs(os.path.join(currentpath, 'transcription')) + os.makedirs(os.path.join(currentpath, transcriptionout)) - audiopath = os.path.join(currentpath, 'audiofiles') # path to audio files - transcriptionpath = os.path.join(currentpath, 'transcription') # path to transcription files + audiopath = os.path.join(currentpath, audiopath) # path to audio files + transcriptionpath = os.path.join(currentpath, transcriptionout) # path to transcription files - audiofiles = os.listdir(audiopath) # list of audio files + _audiofiles = os.listdir(audiopath) # list of audio files + audiofiles = [] + for i in _audiofiles: + audiofiles.append(os.path.join(audiopath, i)) return currentpath, audiopath, transcriptionpath, audiofiles - def check_if_allready_transcribed(self, filename): + + def check_if_allready_transcribed(self, filename: str): """ Check if all audio files are already transcribed :param filename: audio file name @@ -68,115 +452,19 @@ class Transcribe: return True else: return False - def to_mp3(self,file, remove_orginal=True): - """ - Convert video file or other audio files to mp3 file, ensures that the audio file is in the correct format for the - Whisper model - :param file: audio or video file - :param remove_orginal: remove original file - :return: mp3 file path - """ - print(f'Converting {file} to mp3') - AudioSegment.from_file(file, format=file.split('.')[-1]).export(file[:-4] + '.mp3', format='mp3') - print(f'Converted {file} to mp3') - if remove_orginal: - os.remove(file) - print(f'File {file} removed') - return os.path.join(file[:-4] + '.mp3') - - def slower_mp3(self, file, speed=0.5): - """ - Slow down mp3 file - :param file: mp3 file - :param speed: speed - :return: None - """ - if not os.path.exists(os.path.join(self.transcriptionpath, 'slower_version')): - print('Creating slower_version folder') - os.makedirs(os.path.join(self.transcriptionpath, 'slower_version')) - - path = os.path.join(self.transcriptionpath, 'slower_version') - - sound = AudioSegment.from_file(file, format="mp3") - slow_sound = sound._spawn(sound.raw_data, overrides={ - "frame_rate": int(sound.frame_rate * speed) - }) - speedstr = str(speed).replace('.', '') - file_out = file.split('/')[-1][:-4] + f'_{speedstr}.mp3' - save_path = os.path.join(path, file_out) - slow_sound.export(save_path, format="mp3") - - def transcribe(self, speed = 0.75): - - if self.audiofile is not None: - if self.audiofile in self.audiofiles: - audiofile = os.path.join(self.audiopath, self.audiofile) - else: - raise ValueError('Audio file not found') - - if not self.check_if_allready_transcribed(self.audiofile): - - if not audiofile.endswith('.mp3'): - print('Converting video to audio') - audiofile = self.to_mp3(audiofile) - if speed != 1: - print('Creating slower version of the audio file with speed {}'.format(speed)) - self.slower_mp3(audiofile, speed=speed) - - print(f'Start transcribing Audio file: {audiofile}') - _stime = time() - result = self.model.transcribe(audiofile, verbose=True, language= self.language) - - print(f'Transcription finished in {time() - _stime} seconds') - - txtfilename = str(audiofile.split('/')[-1][:-4]) + '.txt' - - savepath = os.path.join(self.transcriptionpath, txtfilename) - - with open(savepath, 'w') as f: - f.write(result["text"]) - - elif self.audiofile is None or isinstance(self.audiofile, list): - print('No audio file specified or list of audio files') - print(f"{len(self.audiofiles)} audio files found in {self.audiopath}") - print("Start transcribing all audio files") - i = 0 - for audiofile in self.audiofiles: - - audiofile = os.path.join(self.audiopath, audiofile) - - if not self.check_if_allready_transcribed(audiofile): - - if not audiofile.endswith('.mp3'): - audiofile = self.to_mp3(audiofile) - if speed != 1: - print('Creating slower version of the audio file with speed {}'.format(speed)) - self.slower_mp3(audiofile, speed=speed) - - print(f'Start transcribing Audio file: {audiofile}') - _stime = time() - result = self.model.transcribe(audiofile, verbose=True, language=self.language) - print(f'Transcription finished in {time() - _stime} seconds') - - txtfilename = str(audiofile.split('/')[-1][:-4]) + '.txt' - - savepath = os.path.join(self.transcriptionpath, txtfilename) - - with open(savepath, 'w') as f: - f.write(result["text"]) - - i += 1 - print(f'{i} of {len(self.audiofiles)} files transcribed') - + @classmethod + def _get_token(self): + # check ig .pyannotetoken.txt exists + path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.pyannotetoken') + if os.path.exists(path): + with open(path, 'r') as f: + token = f.read() else: - raise ValueError('Audio file not found') + raise ValueError('No token found. Please create a token at https://huggingface.co/settings/token' + ' and save it in a file called .pyannotetoken.txt') + return token - print('Transcription finished') - - def __call__(self): - return self.transcribe() def __repr__(self): - return f"Transcribe(audiofile={self.audiofile}, model={self.model}, language={self.language})" - def __str__(self): - return f"Transcribe(audiofile={self.audiofile}, model={self.model}, language={self.language})" - + return f"AutoTranscribe(audiofile={self.audiofile}, model={self.model}, language={self.language}, diarisation={self.diarisation})" + def __call__(self, *args, **kwargs): + return self.transcribe(*args, **kwargs) diff --git a/transcribe.py b/transcribe.py index 6be0c17..e7c62fa 100644 --- a/transcribe.py +++ b/transcribe.py @@ -1,4 +1,3 @@ -from autotranscript import Transcribe - -Transcribe().transcribe() +from autotranscript import AutoTranscribe +AutoTranscribe(diarisation=True).transcribe()