reworked diarization feature
This commit is contained in:
+140
-92
@@ -1,61 +1,63 @@
|
|||||||
from audio_processor import AudioProcessor
|
from pyannote.audio import Pipeline
|
||||||
from time import time
|
from time import time
|
||||||
import os
|
import os
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
class Diarisation(AudioProcessor):
|
Annotation = TypeVar('Annotation')
|
||||||
def __init__(self, audio_file: str, model,**kwargs) -> None:
|
|
||||||
|
|
||||||
super().__init__(audio_file=audio_file)
|
PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"models", "pyannote",
|
||||||
|
"speaker_diarization", "config.yaml")
|
||||||
|
|
||||||
|
class Diarisation:
|
||||||
|
def __init__(self, model,*args,**kwargs) -> None:
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
def diarization(self, *args, **kwargs):
|
def diarization(self, audiofile : str , *args, **kwargs) -> Annotation:
|
||||||
|
"""
|
||||||
|
Diarization of audio file
|
||||||
|
:param audiofile: path to audio file
|
||||||
|
:param args: args for diarization model
|
||||||
|
:param kwargs: kwargs for diarization model
|
||||||
|
:return: diarization
|
||||||
|
"""
|
||||||
|
|
||||||
if "num_speakers" in kwargs:
|
print(f'Start diarization of audio file: {audiofile}')
|
||||||
num_speakers = kwargs['num_speakers']
|
|
||||||
kwargs.pop('num_speakers')
|
|
||||||
else:
|
|
||||||
num_speakers = 2
|
|
||||||
|
|
||||||
audiofilename = self.coreaudiofile
|
diarization = self.model(audiofile,*args, **kwargs)
|
||||||
|
|
||||||
print(f'Start diarization of audio file: {self.audiofilename}')
|
print('Diarization finished')
|
||||||
|
|
||||||
_stime = time()
|
out = self.format_diarization_output(diarization)
|
||||||
|
|
||||||
diarization = self.model(self.audio_file_path, num_speakers=num_speakers)
|
return out
|
||||||
|
|
||||||
print(f'Diarization finished in {time() - _stime} seconds')
|
@staticmethod
|
||||||
self.diarization = diarization
|
def format_diarization_output(dia : Annotation) -> dict:
|
||||||
|
|
||||||
return diarization
|
|
||||||
|
|
||||||
def format_diarization_output(self, *args, **kwargs):
|
|
||||||
"""
|
"""
|
||||||
Format diarization output to a list of tuples
|
Format diarization output to a list of tuples
|
||||||
:param args:
|
:param dia: diarization output
|
||||||
:param kwargs:
|
:return: dict with speaker names as keys and list of tuples
|
||||||
:return: dict with speaker names as keys and list of tuples as values and list of different speakers
|
as values and list of different speakers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
dia_list = list(dia.itertracks(yield_label=True))
|
||||||
diarization_output = {"speakers": [], "segments": []}
|
diarization_output = {"speakers": [], "segments": []}
|
||||||
|
|
||||||
if not hasattr(self, 'diarization'):
|
|
||||||
# ensure diarization is run before formatting
|
|
||||||
self.diarization = self.diarization()
|
|
||||||
|
|
||||||
|
|
||||||
for segment, _, speaker in self.diarization.itertracks(yield_label=True):
|
|
||||||
diarization_output["speakers"].append(speaker)
|
|
||||||
diarization_output["segments"].append(segment)
|
|
||||||
|
|
||||||
normalized_output = []
|
normalized_output = []
|
||||||
index_start_speaker = 0
|
index_start_speaker = 0
|
||||||
index_end_speaker = 0
|
index_end_speaker = 0
|
||||||
current_speaker = str()
|
current_speaker = str()
|
||||||
|
|
||||||
for i, speaker in enumerate(diarization_output["speakers"]):
|
###
|
||||||
|
# Sometimes two consecutive speakers are the same
|
||||||
|
# This loop removes these duplicates
|
||||||
|
###
|
||||||
|
|
||||||
|
|
||||||
|
for i, (_, _, speaker) in enumerate(dia_list):
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
current_speaker = speaker
|
current_speaker = speaker
|
||||||
@@ -64,7 +66,9 @@ class Diarisation(AudioProcessor):
|
|||||||
|
|
||||||
index_end_speaker = i - 1
|
index_end_speaker = i - 1
|
||||||
|
|
||||||
normalized_output.append([index_start_speaker, index_end_speaker, current_speaker])
|
normalized_output.append([index_start_speaker,
|
||||||
|
index_end_speaker,
|
||||||
|
current_speaker])
|
||||||
|
|
||||||
index_start_speaker = i
|
index_start_speaker = i
|
||||||
current_speaker = speaker
|
current_speaker = speaker
|
||||||
@@ -72,73 +76,117 @@ class Diarisation(AudioProcessor):
|
|||||||
if i == len(diarization_output["speakers"]) - 1:
|
if i == len(diarization_output["speakers"]) - 1:
|
||||||
|
|
||||||
index_end_speaker = i
|
index_end_speaker = i
|
||||||
normalized_output.append([index_start_speaker, index_end_speaker, current_speaker])
|
normalized_output.append([index_start_speaker,
|
||||||
|
index_end_speaker,
|
||||||
|
current_speaker])
|
||||||
|
|
||||||
|
for outp in normalized_output:
|
||||||
|
#convert in milliseconds
|
||||||
|
start = dia_list[outp[0]][0].start * 1000
|
||||||
|
end = dia_list[outp[1]][0].end * 1000
|
||||||
|
|
||||||
self.normalized_output = normalized_output
|
diarization_output["segments"].append([start, end])
|
||||||
self.diarization_output = diarization_output
|
diarization_output["speakers"].append(outp[2])
|
||||||
|
|
||||||
return diarization_output,normalized_output
|
return diarization_output
|
||||||
|
|
||||||
def create_temporary_wav(self,savefolder: str = "", savename: str = "", *args, **kwargs):
|
@classmethod
|
||||||
|
def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH,
|
||||||
|
token: str = "",
|
||||||
|
local : bool = True,
|
||||||
|
*args, **kwargs) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
Create temporary wav file for diarization
|
Load modules from pyannote
|
||||||
:param savefolder: folder to save the temporary wav file
|
|
||||||
:param savename: name of the temporary wav file prefix
|
Parameters
|
||||||
:param audiofile: audio file
|
----------
|
||||||
:return: temporary wav file
|
model : str
|
||||||
|
pyannote model
|
||||||
|
default: /models/pyannote/speaker_diarization/config.yaml
|
||||||
|
token : str
|
||||||
|
HUGGINGFACE_TOKEN
|
||||||
|
local : bool
|
||||||
|
If true, load from local cache
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Pipeline Object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if local:
|
||||||
if savefolder == "":
|
diarization_model = Pipeline.from_pretrained(model,*args, **kwargs)
|
||||||
folder = '.temp'
|
|
||||||
if not os.path.exists(folder):
|
|
||||||
os.makedirs(folder)
|
|
||||||
else:
|
else:
|
||||||
folder = savefolder
|
diarization_model = Pipeline.from_pretrained(model, use_auth_token = token,
|
||||||
|
*args, **kwargs)
|
||||||
|
|
||||||
folder = os.path.realpath(folder)
|
return cls(diarization_model)
|
||||||
|
|
||||||
if savename == "":
|
|
||||||
savename = self.coreaudiofile + '.wav'
|
|
||||||
else:
|
|
||||||
savename = savename
|
|
||||||
|
|
||||||
|
|
||||||
if not os.path.exists(folder):
|
|
||||||
os.makedirs(folder)
|
|
||||||
|
|
||||||
if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'):
|
|
||||||
self.format_diarization_output()
|
|
||||||
|
|
||||||
|
|
||||||
speaker = set(self.diarization_output["speakers"])
|
|
||||||
num_speak_iter = [0 for _ in range(len(speaker))]
|
|
||||||
|
|
||||||
for count, outp in enumerate(self.normalized_output):
|
|
||||||
start = self.diarization_output["segments"][outp[0]].start
|
|
||||||
end = self.diarization_output["segments"][outp[1]].end
|
|
||||||
|
|
||||||
print("start: ", start)
|
|
||||||
print("end: ", end)
|
|
||||||
|
|
||||||
start_milliseconds = start * 1000
|
|
||||||
end_milliseconds = end * 1000
|
|
||||||
|
|
||||||
print("start_milliseconds: ", start_milliseconds)
|
|
||||||
print("end_milliseconds: ", end_milliseconds)
|
|
||||||
|
|
||||||
print("cut audio")
|
|
||||||
|
|
||||||
cut_audio = self.audio_file[start_milliseconds:end_milliseconds]
|
|
||||||
|
|
||||||
print("save audio")
|
|
||||||
print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav")
|
|
||||||
cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav")
|
|
||||||
|
|
||||||
return os.path.realpath(folder)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})"
|
return f"Diarisation(model={self.model})"
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})"
|
return f"Diarisation(model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
model = Diarisation.load_model()
|
||||||
|
print(model)
|
||||||
|
audiofile = "/home/jacob/PycharmProjects/autotranscript/tests/test.wav"
|
||||||
|
out = model.diarization(audiofile)
|
||||||
|
print(out)
|
||||||
|
|
||||||
|
# # deprecated
|
||||||
|
# def create_temporary_wav(self, location_of_temp_folder : str = '.temp'):
|
||||||
|
# """
|
||||||
|
# Create temporary wav file for diarization
|
||||||
|
# :param location_of_temp_folder: folder to save the temporary wav file
|
||||||
|
# default: .temp
|
||||||
|
# :param savename: name of the temporary wav file prefix
|
||||||
|
# :param audiofile: audio file
|
||||||
|
# :return: temporary wav file
|
||||||
|
# """
|
||||||
|
# print("Linne 84 Diarisation.py create_temporary_wav :" /
|
||||||
|
# "location_of_temp_folder.split('/')[-1]",location_of_temp_folder.split('/')[-1])
|
||||||
|
|
||||||
|
# if location_of_temp_folder.split('/')[-1] != '.temp':
|
||||||
|
# folder =os.path.join(location_of_temp_folder, '.temp')
|
||||||
|
# else:
|
||||||
|
# folder = location_of_temp_folder
|
||||||
|
|
||||||
|
# if not os.path.exists(folder):
|
||||||
|
# os.makedirs(folder)
|
||||||
|
|
||||||
|
# folder = os.path.realpath(folder)
|
||||||
|
|
||||||
|
# if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'):
|
||||||
|
# raise AttributeError("You need to run the diarization first")
|
||||||
|
|
||||||
|
# speaker = set(self.diarization_output["speakers"])
|
||||||
|
# num_speak_iter = [0 for _ in range(len(speaker))]
|
||||||
|
|
||||||
|
# for count, outp in enumerate(self.normalized_output):
|
||||||
|
# print(outp)
|
||||||
|
# print(self.diarization_output["segments"][outp[0]])
|
||||||
|
# print(self.diarization_output["segments"][outp[1]])
|
||||||
|
|
||||||
|
# start = self.diarization_output["segments"][outp[0]].start
|
||||||
|
# end = self.diarization_output["segments"][outp[1]].end
|
||||||
|
|
||||||
|
# print("start: ", start)
|
||||||
|
# print("end: ", end)
|
||||||
|
|
||||||
|
# start_milliseconds = start * 1000
|
||||||
|
# end_milliseconds = end * 1000
|
||||||
|
|
||||||
|
# print("start_milliseconds: ", start_milliseconds)
|
||||||
|
# print("end_milliseconds: ", end_milliseconds)
|
||||||
|
|
||||||
|
# print("cut audio")
|
||||||
|
|
||||||
|
# cut_audio = self.audio_file[start_milliseconds:end_milliseconds]
|
||||||
|
|
||||||
|
# print("save audio")
|
||||||
|
# print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav")
|
||||||
|
# cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav")
|
||||||
|
|
||||||
|
# return os.path.realpath(folder)
|
||||||
Reference in New Issue
Block a user