From 671c67415f6b0da6feca9ab9ff4e24bfa31187da Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Mon, 12 Jun 2023 11:29:28 +0200 Subject: [PATCH] reworked diarization feature --- autotranscript/diarisation.py | 238 ++++++++++++++++++++-------------- 1 file changed, 143 insertions(+), 95 deletions(-) diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index b7ee848..b0c9e84 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -1,62 +1,64 @@ -from audio_processor import AudioProcessor +from pyannote.audio import Pipeline from time import time import os +from typing import TypeVar -class Diarisation(AudioProcessor): - def __init__(self, audio_file: str, model,**kwargs) -> None: +Annotation = TypeVar('Annotation') - super().__init__(audio_file=audio_file) +PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), + "models", "pyannote", + "speaker_diarization", "config.yaml") + +class Diarisation: + def __init__(self, model,*args,**kwargs) -> None: self.model = model - def diarization(self, *args, **kwargs): + def diarization(self, audiofile : str , *args, **kwargs) -> Annotation: + """ + Diarization of audio file + :param audiofile: path to audio file + :param args: args for diarization model + :param kwargs: kwargs for diarization model + :return: diarization + """ - if "num_speakers" in kwargs: - num_speakers = kwargs['num_speakers'] - kwargs.pop('num_speakers') - else: - num_speakers = 2 + print(f'Start diarization of audio file: {audiofile}') - audiofilename = self.coreaudiofile + diarization = self.model(audiofile,*args, **kwargs) - print(f'Start diarization of audio file: {self.audiofilename}') + print('Diarization finished') - _stime = time() + out = self.format_diarization_output(diarization) - diarization = self.model(self.audio_file_path, num_speakers=num_speakers) + return out - print(f'Diarization finished in {time() - _stime} seconds') - self.diarization = diarization - - return diarization - - def format_diarization_output(self, *args, **kwargs): + @staticmethod + def format_diarization_output(dia : Annotation) -> dict: """ Format diarization output to a list of tuples - :param args: - :param kwargs: - :return: dict with speaker names as keys and list of tuples as values and list of different speakers + :param dia: diarization output + :return: dict with speaker names as keys and list of tuples + as values and list of different speakers """ + dia_list = list(dia.itertracks(yield_label=True)) diarization_output = {"speakers": [], "segments": []} - if not hasattr(self, 'diarization'): - # ensure diarization is run before formatting - self.diarization = self.diarization() - - - for segment, _, speaker in self.diarization.itertracks(yield_label=True): - diarization_output["speakers"].append(speaker) - diarization_output["segments"].append(segment) - normalized_output = [] index_start_speaker = 0 index_end_speaker = 0 current_speaker = str() + + ### + # Sometimes two consecutive speakers are the same + # This loop removes these duplicates + ### - for i, speaker in enumerate(diarization_output["speakers"]): + for i, (_, _, speaker) in enumerate(dia_list): + if i == 0: current_speaker = speaker @@ -64,7 +66,9 @@ class Diarisation(AudioProcessor): index_end_speaker = i - 1 - normalized_output.append([index_start_speaker, index_end_speaker, current_speaker]) + normalized_output.append([index_start_speaker, + index_end_speaker, + current_speaker]) index_start_speaker = i current_speaker = speaker @@ -72,73 +76,117 @@ class Diarisation(AudioProcessor): if i == len(diarization_output["speakers"]) - 1: index_end_speaker = i - normalized_output.append([index_start_speaker, index_end_speaker, current_speaker]) + normalized_output.append([index_start_speaker, + index_end_speaker, + current_speaker]) + + for outp in normalized_output: + #convert in milliseconds + start = dia_list[outp[0]][0].start * 1000 + end = dia_list[outp[1]][0].end * 1000 + diarization_output["segments"].append([start, end]) + diarization_output["speakers"].append(outp[2]) - self.normalized_output = normalized_output - self.diarization_output = diarization_output - - return diarization_output,normalized_output - - def create_temporary_wav(self,savefolder: str = "", savename: str = "", *args, **kwargs): + return diarization_output + + @classmethod + def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH, + token: str = "", + local : bool = True, + *args, **kwargs) -> Pipeline: """ - Create temporary wav file for diarization - :param savefolder: folder to save the temporary wav file - :param savename: name of the temporary wav file prefix - :param audiofile: audio file - :return: temporary wav file + Load modules from pyannote + + Parameters + ---------- + model : str + pyannote model + default: /models/pyannote/speaker_diarization/config.yaml + token : str + HUGGINGFACE_TOKEN + local : bool + If true, load from local cache + + Returns + ------- + Pipeline Object """ - - if savefolder == "": - folder = '.temp' - if not os.path.exists(folder): - os.makedirs(folder) + if local: + diarization_model = Pipeline.from_pretrained(model,*args, **kwargs) else: - folder = savefolder - - folder = os.path.realpath(folder) - - if savename == "": - savename = self.coreaudiofile + '.wav' - else: - savename = savename - - - if not os.path.exists(folder): - os.makedirs(folder) - - if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'): - self.format_diarization_output() - - - speaker = set(self.diarization_output["speakers"]) - num_speak_iter = [0 for _ in range(len(speaker))] - - for count, outp in enumerate(self.normalized_output): - start = self.diarization_output["segments"][outp[0]].start - end = self.diarization_output["segments"][outp[1]].end - - print("start: ", start) - print("end: ", end) - - start_milliseconds = start * 1000 - end_milliseconds = end * 1000 - - print("start_milliseconds: ", start_milliseconds) - print("end_milliseconds: ", end_milliseconds) - - print("cut audio") - - cut_audio = self.audio_file[start_milliseconds:end_milliseconds] - - print("save audio") - print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav") - cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav") - - return os.path.realpath(folder) + diarization_model = Pipeline.from_pretrained(model, use_auth_token = token, + *args, **kwargs) + + return cls(diarization_model) def __repr__(self): - return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" + return f"Diarisation(model={self.model})" def __str__(self): - return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" \ No newline at end of file + return f"Diarisation(model={self.model})" + + +if __name__ == '__main__': + + model = Diarisation.load_model() + print(model) + audiofile = "/home/jacob/PycharmProjects/autotranscript/tests/test.wav" + out = model.diarization(audiofile) + print(out) + + # # deprecated + # def create_temporary_wav(self, location_of_temp_folder : str = '.temp'): + # """ + # Create temporary wav file for diarization + # :param location_of_temp_folder: folder to save the temporary wav file + # default: .temp + # :param savename: name of the temporary wav file prefix + # :param audiofile: audio file + # :return: temporary wav file + # """ + # print("Linne 84 Diarisation.py create_temporary_wav :" / + # "location_of_temp_folder.split('/')[-1]",location_of_temp_folder.split('/')[-1]) + + # if location_of_temp_folder.split('/')[-1] != '.temp': + # folder =os.path.join(location_of_temp_folder, '.temp') + # else: + # folder = location_of_temp_folder + + # if not os.path.exists(folder): + # os.makedirs(folder) + + # folder = os.path.realpath(folder) + + # if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'): + # raise AttributeError("You need to run the diarization first") + + # speaker = set(self.diarization_output["speakers"]) + # num_speak_iter = [0 for _ in range(len(speaker))] + + # for count, outp in enumerate(self.normalized_output): + # print(outp) + # print(self.diarization_output["segments"][outp[0]]) + # print(self.diarization_output["segments"][outp[1]]) + + # start = self.diarization_output["segments"][outp[0]].start + # end = self.diarization_output["segments"][outp[1]].end + + # print("start: ", start) + # print("end: ", end) + + # start_milliseconds = start * 1000 + # end_milliseconds = end * 1000 + + # print("start_milliseconds: ", start_milliseconds) + # print("end_milliseconds: ", end_milliseconds) + + # print("cut audio") + + # cut_audio = self.audio_file[start_milliseconds:end_milliseconds] + + # print("save audio") + # print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav") + # cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav") + + # return os.path.realpath(folder) \ No newline at end of file