diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index 123c692..55fd0cb 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -1,7 +1,7 @@ from pyannote.audio import Pipeline -from time import time +from torch import Tensor import os -from typing import TypeVar +from typing import TypeVar, Union Annotation = TypeVar('Annotation') @@ -9,15 +9,16 @@ PYANNOTE_DEFAULT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models", "pyannote", "speaker_diarization", "config.yaml") -class Diarisation: +class Diariser: def __init__(self, model,*args,**kwargs) -> None: self.model = model - def diarization(self, audiofile : str , *args, **kwargs) -> Annotation: + def diarization(self, audiofile : Union[str, Tensor] , + *args, **kwargs) -> Annotation: """ Diarization of audio file - :param audiofile: path to audio file + :param audiofile: path to audio file or torch.Tensor :param args: args for diarization model :param kwargs: kwargs for diarization model :return: diarization @@ -83,17 +84,21 @@ class Diarisation: diarization_output["speakers"].append(outp[2]) return diarization_output + @staticmethod def _get_token(): # check ig .pyannotetoken.txt exists - path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.pyannotetoken') + path = os.path.join(os.path.dirname( + os.path.realpath(__file__)), '.pyannotetoken') if os.path.exists(path): with open(path, 'r') as f: token = f.read() else: - raise ValueError('No token found. Please create a token at https://huggingface.co/settings/token' - ' and save it in a file called .pyannotetoken.txt') + raise ValueError('No token found.' \ + 'Please create a token at https://huggingface.co/settings/token' \ + 'and save it in a file called .pyannotetoken.txt') return token + @classmethod def load_model(cls, model: str = PYANNOTE_DEFAULT_PATH, token: str = "", @@ -129,69 +134,7 @@ class Diarisation: def __repr__(self): return f"Diarisation(model={self.model})" + def __str__(self): return f"Diarisation(model={self.model})" - -if __name__ == '__main__': - - model = Diarisation.load_model() - print(model) - audiofile = "/home/jacob/PycharmProjects/autotranscript/tests/test.wav" - out = model.diarization(audiofile) - - # # deprecated - # def create_temporary_wav(self, location_of_temp_folder : str = '.temp'): - # """ - # Create temporary wav file for diarization - # :param location_of_temp_folder: folder to save the temporary wav file - # default: .temp - # :param savename: name of the temporary wav file prefix - # :param audiofile: audio file - # :return: temporary wav file - # """ - # print("Linne 84 Diarisation.py create_temporary_wav :" / - # "location_of_temp_folder.split('/')[-1]",location_of_temp_folder.split('/')[-1]) - - # if location_of_temp_folder.split('/')[-1] != '.temp': - # folder =os.path.join(location_of_temp_folder, '.temp') - # else: - # folder = location_of_temp_folder - - # if not os.path.exists(folder): - # os.makedirs(folder) - - # folder = os.path.realpath(folder) - - # if not hasattr(self, 'normalized_output') or not hasattr(self, 'diarization_output'): - # raise AttributeError("You need to run the diarization first") - - # speaker = set(self.diarization_output["speakers"]) - # num_speak_iter = [0 for _ in range(len(speaker))] - - # for count, outp in enumerate(self.normalized_output): - # print(outp) - # print(self.diarization_output["segments"][outp[0]]) - # print(self.diarization_output["segments"][outp[1]]) - - # start = self.diarization_output["segments"][outp[0]].start - # end = self.diarization_output["segments"][outp[1]].end - - # print("start: ", start) - # print("end: ", end) - - # start_milliseconds = start * 1000 - # end_milliseconds = end * 1000 - - # print("start_milliseconds: ", start_milliseconds) - # print("end_milliseconds: ", end_milliseconds) - - # print("cut audio") - - # cut_audio = self.audio_file[start_milliseconds:end_milliseconds] - - # print("save audio") - # print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav") - # cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav") - - # return os.path.realpath(folder) \ No newline at end of file