Apply automatic PEP 8 formatting fixes and resolve flake8 warnings.

This commit is contained in:
Marko Henning
2024-05-15 15:18:17 +02:00
parent 9f526a8f3b
commit 4bcd28d0ea
15 changed files with 391 additions and 417 deletions
+97 -94
View File
@@ -55,18 +55,19 @@ class Scraibe:
Attributes:
transcriber (Transcriber): The transcriber object to handle transcription.
diariser (Diariser): The diariser object to handle diarization.
Methods:
__init__: Initializes the Scraibe class with appropriate models.
autotranscribe: Transcribes an audio file using the whisper model and pyannote diarization model.
transcribe: Transcribes an audio file using the transcriber only, without diarization.
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
get_audio_file: Gets an audio file as an AudioProcessor object.
"""
def __init__(self,
whisper_model: Union[bool, str, whisper] = None,
whisper_type: str = "whisper",
dia_model : Union[bool, str, DiarisationType] = None,
**kwargs) -> None:
whisper_model: Union[bool, str, whisper] = None,
whisper_type: str = "whisper",
dia_model: Union[bool, str, DiarisationType] = None,
**kwargs) -> None:
"""Initializes the Scraibe class.
Args:
@@ -84,12 +85,13 @@ class Scraibe:
- save_setup: If True, the keyword arguments will be saved
for autotranscribe, so you can unload the class and reload it again.
"""
if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", whisper_type, **kwargs)
self.transcriber = Transcriber.load_model(
"medium", whisper_type, **kwargs)
elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, whisper_type, **kwargs)
self.transcriber = Transcriber.load_model(
whisper_model, whisper_type, **kwargs)
else:
self.transcriber = whisper_model
@@ -98,26 +100,25 @@ class Scraibe:
elif isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs)
else:
self.diariser : Diariser = dia_model
self.diariser: Diariser = dia_model
if kwargs.get("verbose"):
print("Scraibe initialized all models successfully loaded.")
self.verbose = True
else:
self.verbose = False
# Save kwargs for autotranscribe if you want to unload the class and load it again.
if kwargs.get('save_setup'):
self.params = dict(whisper_model = whisper_model,
dia_model = dia_model,
if kwargs.get('save_setup'):
self.params = dict(whisper_model=whisper_model,
dia_model=dia_model,
**kwargs)
else:
self.params = {}
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
remove_original : bool = False,
**kwargs) -> Transcript:
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False,
**kwargs) -> Transcript:
"""
Transcribes an audio file using the whisper model and pyannote diarization model.
@@ -136,60 +137,62 @@ class Scraibe:
if kwargs.get("verbose"):
self.verbose = kwargs.get("verbose")
# Get audio file as an AudioProcessor object
audio_file : AudioProcessor = self.get_audio_file(audio_file)
audio_file: AudioProcessor = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"sample_rate": audio_file.sr
}
}
if self.verbose:
print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs)
if not diarisation["segments"]:
print("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
"segments" : [0, len(audio_file.waveform)],
"text" : transcript}}
transcript = self.transcriber.transcribe(
audio_file.waveform, **kwargs)
final_transcript = {0: {"speakers": 'SPEAKER_01',
"segments": [0, len(audio_file.waveform)],
"text": transcript}}
return Transcript(final_transcript)
if self.verbose:
print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results
final_transcript = dict()
for i in trange(len(diarisation["segments"]), desc= "Transcribing", disable = not self.verbose):
for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose):
seg = diarisation["segments"][i]
audio = audio_file.cut(seg[0], seg[1])
transcript = self.transcriber.transcribe(audio, **kwargs)
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
"segments" : seg,
"text" : transcript}
# Remove original file if needed
final_transcript[i] = {"speakers": diarisation["speakers"][i],
"segments": seg,
"text": transcript}
# Remove original file if needed
if remove_original:
if kwargs.get("shred") is True:
self.remove_audio_file(audio_file, shred=True)
else:
self.remove_audio_file(audio_file, shred=False)
return Transcript(final_transcript)
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
**kwargs) -> dict:
"""
Perform diarization on an audio file using the pyannote diarization model.
@@ -204,24 +207,24 @@ class Scraibe:
dict:
A dictionary containing the results of the diarization process.
"""
# Get audio file as an AudioProcessor object
audio_file : AudioProcessor = self.get_audio_file(audio_file)
audio_file: AudioProcessor = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"sample_rate": audio_file.sr
}
}
print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs)
return diarisation
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
**kwargs):
def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
**kwargs):
"""
Transcribe the provided audio file.
@@ -235,11 +238,11 @@ class Scraibe:
str:
The transcribed text from the audio source.
"""
audio_file : AudioProcessor = self.get_audio_file(audio_file)
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None:
audio_file: AudioProcessor = self.get_audio_file(audio_file)
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None:
"""
Update the transcriber model.
@@ -248,22 +251,23 @@ class Scraibe:
The new whisper model to use for transcription.
**kwargs:
Additional keyword arguments for the transcriber model.
Returns:
None
"""
_old_model = self.transcriber.model_name
if isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
elif isinstance(whisper_model, Transcriber):
self.transcriber = whisper_model
else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
warn(
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
return None
def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None:
def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None:
"""
Update the diariser model.
@@ -272,7 +276,7 @@ class Scraibe:
The new diariser model to use for diarization.
**kwargs:
Additional keyword arguments for the diariser model.
Returns:
None
"""
@@ -281,13 +285,13 @@ class Scraibe:
elif isinstance(dia_model, Diariser):
self.diariser = dia_model
else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
return None
@staticmethod
def remove_audio_file(audio_file : str,
shred : bool = False) -> None:
def remove_audio_file(audio_file: str,
shred: bool = False) -> None:
"""
Removes the original audio file to avoid disk space issues or ensure data privacy.
@@ -298,30 +302,29 @@ class Scraibe:
"""
if not os.path.exists(audio_file):
raise ValueError(f"Audiofile {audio_file} does not exist.")
if shred:
warn("Shredding audiofile can take a long time.", RuntimeWarning)
gen = iglob(f'{audio_file}', recursive=True)
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
if os.path.isdir(audio_file):
raise ValueError(f"Audiofile {audio_file} is a directory.")
for file in gen:
print(f'shredding {file} now\n')
run(cmd , check=True)
run(cmd, check=True)
else:
os.remove(audio_file)
print(f"Audiofile {audio_file} removed.")
@staticmethod
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor:
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor.
Args:
@@ -334,20 +337,20 @@ class Scraibe:
AudioProcessor: An object containing the waveform and sample rate in
torch.Tensor format.
"""
if isinstance(audio_file, str):
audio_file = AudioProcessor.from_file(audio_file)
audio_file = AudioProcessor.from_file(audio_file)
elif isinstance(audio_file, torch.Tensor):
audio_file = AudioProcessor(audio_file[0], audio_file[1])
elif isinstance(audio_file, ndarray):
audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
audio_file[1])
audio_file[1])
if not isinstance(audio_file, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
f'not {type(audio_file)}')
raise ValueError(f'Audiofile must be of type AudioProcessor,'
f'not {type(audio_file)}')
return audio_file
def __repr__(self):