reworked diarization feature
This commit is contained in:
+143
-95
@@ -1,62 +1,64 @@
|
||||
from audio_processor import AudioProcessor
|
||||
from pyannote.audio import Pipeline
|
||||
from time import time
|
||||
import os
|
||||
from typing import TypeVar
|
||||
|
||||
# Placeholder used only in annotations for the object returned by the
# diarization model (anything exposing ``itertracks(yield_label=True)``).
Annotation = TypeVar('Annotation')

# Default pipeline configuration: models/pyannote/speaker_diarization/config.yaml
# located under the parent of the directory that contains this file.
PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                     "models", "pyannote",
                                     "speaker_diarization", "config.yaml")


class Diarisation:
    """Run a diarization model on an audio file and merge its tracks into speaker turns."""

    def __init__(self, model, *args, **kwargs) -> None:
        """
        Store the diarization model.

        :param model: callable invoked as ``model(audiofile, *args, **kwargs)``
        :param args: accepted for call compatibility, not used
        :param kwargs: accepted for call compatibility, not used
        """
        self.model = model

    def diarization(self, audiofile: str, *args, **kwargs) -> dict:
        """
        Diarization of audio file

        :param audiofile: path to audio file
        :param args: args for diarization model
        :param kwargs: kwargs for diarization model
        :return: formatted diarization as produced by
            ``format_diarization_output``
        """
        print(f'Start diarization of audio file: {audiofile}')
        raw_diarization = self.model(audiofile, *args, **kwargs)
        print('Diarization finished')

        return self.format_diarization_output(raw_diarization)

    @staticmethod
    def format_diarization_output(dia: Annotation) -> dict:
        """
        Merge consecutive tracks of the same speaker into speaker turns.

        :param dia: diarization output; must provide
            ``itertracks(yield_label=True)`` yielding
            ``(segment, track, speaker)`` tuples whose segment has ``start``
            and ``end`` attributes
        :return: dict with two parallel lists: ``"speakers"`` holds the
            speaker label of every turn and ``"segments"`` the matching
            ``[start, end]`` pair, each value multiplied by 1000
            (conversion to milliseconds)
        """
        diarization_output = {"speakers": [], "segments": []}
        speakers = diarization_output["speakers"]
        segments = diarization_output["segments"]

        for segment, _, speaker in dia.itertracks(yield_label=True):
            if speakers and speakers[-1] == speaker:
                # Sometimes two consecutive tracks belong to the same speaker:
                # extend the running turn to the end of this track instead of
                # opening a new one.
                segments[-1][1] = segment.end * 1000
            else:
                # Speaker changed (or first track): open a new turn. Building
                # the turn here means the last turn is never lost, which the
                # previous index bookkeeping did because it compared against
                # the length of the still-empty output list.
                speakers.append(speaker)
                segments.append([segment.start * 1000, segment.end * 1000])

        return diarization_output

    @classmethod
    def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH,
                   token: str = "",
                   local: bool = True,
                   *args, **kwargs) -> "Diarisation":
        """
        Load a pyannote pipeline and wrap it in a ``Diarisation`` instance.

        Parameters
        ----------
        model : str
            pyannote model
            default: PYANNOTE_DEFAULT_PATH
            (<parent dir>/models/pyannote/speaker_diarization/config.yaml)
        token : str
            HUGGINGFACE_TOKEN, forwarded as ``use_auth_token`` only when
            ``local`` is false
        local : bool
            If true, ``model`` is handed to ``Pipeline.from_pretrained``
            without a token
        *args, **kwargs
            forwarded to ``Pipeline.from_pretrained``

        Returns
        -------
        Diarisation
            instance whose ``model`` attribute is the loaded Pipeline object
        """
        if local:
            diarization_model = Pipeline.from_pretrained(model, *args, **kwargs)
        else:
            diarization_model = Pipeline.from_pretrained(model, *args,
                                                         use_auth_token=token,
                                                         **kwargs)

        return cls(diarization_model)

    def __repr__(self):
        return f"Diarisation(model={self.model})"

    def __str__(self):
        # Same text as __repr__.
        return repr(self)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: load the default model, diarize one file and print
    # the formatted result.
    import sys

    model = Diarisation.load_model()
    print(model)
    # The audio path can be given as the first command-line argument; without
    # one, the original development path is used.
    audiofile = (sys.argv[1] if len(sys.argv) > 1
                 else "/home/jacob/PycharmProjects/autotranscript/tests/test.wav")
    out = model.diarization(audiofile)
    print(out)
|
||||
|
||||
# # deprecated
|
||||
# def create_temporary_wav(self, location_of_temp_folder : str = '.temp'):
|
||||
# """
|
||||
# Create temporary wav file for diarization
|
||||
# :param location_of_temp_folder: folder to save the temporary wav file
|
||||
# default: .temp
|
||||
# :param savename: name of the temporary wav file prefix
|
||||
# :param audiofile: audio file
|
||||
# :return: temporary wav file
|
||||
# """
|
||||
# print("Linne 84 Diarisation.py create_temporary_wav :" /
|
||||
# "location_of_temp_folder.split('/')[-1]",location_of_temp_folder.split('/')[-1])
|
||||
|
||||
# if location_of_temp_folder.split('/')[-1] != '.temp':
|
||||
# folder =os.path.join(location_of_temp_folder, '.temp')
|
||||
# else:
|
||||
# folder = location_of_temp_folder
|
||||
|
||||
# if not os.path.exists(folder):
|
||||
# os.makedirs(folder)
|
||||
|
||||
# folder = os.path.realpath(folder)
|
||||
|
||||
# if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'):
|
||||
# raise AttributeError("You need to run the diarization first")
|
||||
|
||||
# speaker = set(self.diarization_output["speakers"])
|
||||
# num_speak_iter = [0 for _ in range(len(speaker))]
|
||||
|
||||
# for count, outp in enumerate(self.normalized_output):
|
||||
# print(outp)
|
||||
# print(self.diarization_output["segments"][outp[0]])
|
||||
# print(self.diarization_output["segments"][outp[1]])
|
||||
|
||||
# start = self.diarization_output["segments"][outp[0]].start
|
||||
# end = self.diarization_output["segments"][outp[1]].end
|
||||
|
||||
# print("start: ", start)
|
||||
# print("end: ", end)
|
||||
|
||||
# start_milliseconds = start * 1000
|
||||
# end_milliseconds = end * 1000
|
||||
|
||||
# print("start_milliseconds: ", start_milliseconds)
|
||||
# print("end_milliseconds: ", end_milliseconds)
|
||||
|
||||
# print("cut audio")
|
||||
|
||||
# cut_audio = self.audio_file[start_milliseconds:end_milliseconds]
|
||||
|
||||
# print("save audio")
|
||||
# print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav")
|
||||
# cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav")
|
||||
|
||||
# return os.path.realpath(folder)
|
||||
Reference in New Issue
Block a user