reworked diarization feature

This commit is contained in:
Jaikinator
2023-06-12 11:29:28 +02:00
parent 6710f05eaf
commit 671c67415f
+140 -92
View File
@@ -1,61 +1,63 @@
from audio_processor import AudioProcessor from pyannote.audio import Pipeline
from time import time from time import time
import os import os
from typing import TypeVar
# Type placeholder used only in annotations for the diarization model's result
# (presumably stands in for pyannote's annotation type -- TODO confirm).
Annotation = TypeVar('Annotation')

# Default pipeline configuration, resolved relative to the parent of this
# file's directory: <parent>/models/pyannote/speaker_diarization/config.yaml
PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                     "models", "pyannote",
                                     "speaker_diarization", "config.yaml")
class Diarisation:
def __init__(self, model,*args,**kwargs) -> None:
self.model = model self.model = model
def diarization(self, audiofile: str, *args, **kwargs) -> dict:
    """
    Run the diarization model on an audio file and format its result.

    :param audiofile: path to audio file
    :param args: positional arguments forwarded to the diarization model
    :param kwargs: keyword arguments forwarded to the diarization model
    :return: the value returned by ``format_diarization_output`` for the
             model's raw result (not the raw result itself)
    """
    print(f'Start diarization of audio file: {audiofile}')
    diarization = self.model(audiofile, *args, **kwargs)
    print('Diarization finished')
    return self.format_diarization_output(diarization)
print(f'Diarization finished in {time() - _stime} seconds') @staticmethod
def format_diarization_output(dia: "Annotation") -> dict:
    """
    Collapse a diarization result into merged speaker turns.

    Consecutive tracks that carry the same speaker label are merged into a
    single turn that runs from the start of the first track to the end of
    the last track of that run.

    :param dia: diarization output; must provide
                ``itertracks(yield_label=True)`` yielding
                ``(segment, track, speaker)`` triples whose ``segment`` has
                ``start`` and ``end`` attributes (presumably in seconds,
                since they are scaled by 1000 to milliseconds)
    :return: dict with two parallel lists: ``"speakers"`` (one label per
             merged turn) and ``"segments"`` (``[start_ms, end_ms]`` per turn)
    """
    diarization_output = {"speakers": [], "segments": []}
    speakers = diarization_output["speakers"]
    segments = diarization_output["segments"]

    # Single pass: either extend the turn in progress or open a new one.
    # The turn in progress is always already stored, so the final turn can
    # never be dropped (the index-based version tested the loop index against
    # the length of the still-empty "speakers" list and lost it).
    for segment, _, speaker in dia.itertracks(yield_label=True):
        end_ms = segment.end * 1000  # convert to milliseconds
        if speakers and speakers[-1] == speaker:
            # Same speaker as the previous track: stretch the current turn.
            segments[-1][1] = end_ms
        else:
            speakers.append(speaker)
            segments.append([segment.start * 1000, end_ms])

    return diarization_output
def create_temporary_wav(self,savefolder: str = "", savename: str = "", *args, **kwargs): @classmethod
def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH,
               token: str = "",
               local: bool = True,
               *args, **kwargs) -> "Diarisation":
    """
    Load a pyannote pipeline and wrap it in an instance of this class.

    Parameters
    ----------
    model : str
        Passed as first argument to ``Pipeline.from_pretrained``.
        default: ``PYANNOTE_DEFAULT_PATH``
    token : str
        HUGGINGFACE_TOKEN; forwarded as ``use_auth_token`` only when
        ``local`` is False, ignored otherwise.
    local : bool
        If true, load without passing a token.
    *args, **kwargs
        Forwarded unchanged to ``Pipeline.from_pretrained``.

    Returns
    -------
    Diarisation
        ``cls`` instance wrapping the loaded pipeline (not the bare
        Pipeline object).
    """
    if local:
        diarization_model = Pipeline.from_pretrained(model, *args, **kwargs)
    else:
        diarization_model = Pipeline.from_pretrained(model, *args,
                                                     use_auth_token=token,
                                                     **kwargs)
    return cls(diarization_model)
if savename == "":
savename = self.coreaudiofile + '.wav'
else:
savename = savename
if not os.path.exists(folder):
os.makedirs(folder)
if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'):
self.format_diarization_output()
speaker = set(self.diarization_output["speakers"])
num_speak_iter = [0 for _ in range(len(speaker))]
for count, outp in enumerate(self.normalized_output):
start = self.diarization_output["segments"][outp[0]].start
end = self.diarization_output["segments"][outp[1]].end
print("start: ", start)
print("end: ", end)
start_milliseconds = start * 1000
end_milliseconds = end * 1000
print("start_milliseconds: ", start_milliseconds)
print("end_milliseconds: ", end_milliseconds)
print("cut audio")
cut_audio = self.audio_file[start_milliseconds:end_milliseconds]
print("save audio")
print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav")
cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav")
return os.path.realpath(folder)
def __repr__(self): def __repr__(self):
return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" return f"Diarisation(model={self.model})"
def __str__(self): def __str__(self):
return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" return f"Diarisation(model={self.model})"
if __name__ == '__main__':
    # Manual smoke test: load the default pipeline and diarize one file.
    import sys

    # Take the audio path from the command line when given; otherwise fall
    # back to the previously hard-coded developer test file.
    audiofile = (sys.argv[1] if len(sys.argv) > 1
                 else "/home/jacob/PycharmProjects/autotranscript/tests/test.wav")
    model = Diarisation.load_model()
    print(model)
    out = model.diarization(audiofile)
    print(out)
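# # Deprecated: kept for reference only. The code below reads
# # self.normalized_output, self.diarization_output and self.audio_file,
# # none of which are set by Diarisation.__init__ any more.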
# def create_temporary_wav(self, location_of_temp_folder : str = '.temp'):
# """
# Create temporary wav file for diarization
# :param location_of_temp_folder: folder to save the temporary wav file
# default: .temp
# :param savename: name of the temporary wav file prefix
# :param audiofile: audio file
# :return: temporary wav file
# """
# print("Linne 84 Diarisation.py create_temporary_wav :" /
# "location_of_temp_folder.split('/')[-1]",location_of_temp_folder.split('/')[-1])
# if location_of_temp_folder.split('/')[-1] != '.temp':
# folder =os.path.join(location_of_temp_folder, '.temp')
# else:
# folder = location_of_temp_folder
# if not os.path.exists(folder):
# os.makedirs(folder)
# folder = os.path.realpath(folder)
# if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'):
# raise AttributeError("You need to run the diarization first")
# speaker = set(self.diarization_output["speakers"])
# num_speak_iter = [0 for _ in range(len(speaker))]
# for count, outp in enumerate(self.normalized_output):
# print(outp)
# print(self.diarization_output["segments"][outp[0]])
# print(self.diarization_output["segments"][outp[1]])
# start = self.diarization_output["segments"][outp[0]].start
# end = self.diarization_output["segments"][outp[1]].end
# print("start: ", start)
# print("end: ", end)
# start_milliseconds = start * 1000
# end_milliseconds = end * 1000
# print("start_milliseconds: ", start_milliseconds)
# print("end_milliseconds: ", end_milliseconds)
# print("cut audio")
# cut_audio = self.audio_file[start_milliseconds:end_milliseconds]
# print("save audio")
# print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav")
# cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav")
# return os.path.realpath(folder)