final codebase rework
This commit is contained in:
@@ -125,6 +125,17 @@ class AutoTranscribe:
|
||||
|
||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||
|
||||
if not diarisation["segments"]:
|
||||
warn("No segments found. Try to run transcription without diarisation.")
|
||||
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||
|
||||
final_transcript= {"speakers" : ["speaker01"],
|
||||
"segments" : [0, len(audio_file.waveform)],
|
||||
"text" : transcript}
|
||||
|
||||
return Transcript(final_transcript)
|
||||
|
||||
|
||||
print("Diarisation finished. Starting transcription.")
|
||||
|
||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
||||
@@ -140,8 +151,8 @@ class AutoTranscribe:
|
||||
|
||||
transcript = self.transcriber.transcribe(audio, **kwargs)
|
||||
|
||||
final_transcript[i] = {"speaker" : diarisation["speakers"][i],
|
||||
"segment" : seg,
|
||||
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
|
||||
"segments" : seg,
|
||||
"text" : transcript}
|
||||
|
||||
# Remove original file if needed
|
||||
@@ -233,6 +244,7 @@ def cli():
|
||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||
from .transcriber import WHISPER_DEFAULT_PATH
|
||||
from .diarisation import PYANNOTE_DEFAULT_PATH
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
@@ -242,9 +254,12 @@ def cli():
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument("audio_files", nargs="+", type=str,
|
||||
parser.add_argument("-f","--audio_files", nargs="+", type=str,
|
||||
help="List of audio files to transcribe.")
|
||||
|
||||
|
||||
parser.add_argument('--start_server', action='store_true',
|
||||
help='Start the Gradio app.')
|
||||
|
||||
parser.add_argument("--whisper_model_name", default="medium",
|
||||
help="Name of the Whisper model to use.")
|
||||
|
||||
@@ -299,6 +314,7 @@ def cli():
|
||||
audio_files = args.audio_files
|
||||
spoken_language = args.spoken_language
|
||||
output_format = args.output_format
|
||||
start_server = args.start_server
|
||||
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
@@ -335,6 +351,10 @@ def cli():
|
||||
elif transcription_task == "wtranscribe":
|
||||
# wtranscribe code here
|
||||
pass
|
||||
|
||||
if start_server:
|
||||
from .gradio_app import gradio_app
|
||||
gradio_app(model)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -65,7 +65,7 @@ class Transcript:
|
||||
list: List of unique speaker names in the transcript.
|
||||
"""
|
||||
|
||||
return list(set([self.transcript[id]["speaker"] for id in self.transcript]))
|
||||
return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
|
||||
|
||||
def _extract_segments(self) -> list:
|
||||
"""
|
||||
@@ -75,7 +75,7 @@ class Transcript:
|
||||
list: List of segments, where each segment is represented
|
||||
by the starting and ending times.
|
||||
"""
|
||||
return [self.transcript[id]["segment"] for id in self.transcript]
|
||||
return [self.transcript[id]["segments"] for id in self.transcript]
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
@@ -91,11 +91,11 @@ class Transcript:
|
||||
seq = self.transcript[_id]
|
||||
|
||||
if self.annotation:
|
||||
speaker = self.annotation[seq["speaker"]]
|
||||
speaker = self.annotation[seq["speakers"]]
|
||||
else:
|
||||
speaker = seq["speaker"]
|
||||
speaker = seq["speakers"]
|
||||
|
||||
segm = seq["segment"]
|
||||
segm = seq["segments"]
|
||||
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0]))
|
||||
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
|
||||
|
||||
@@ -172,7 +172,7 @@ class Transcript:
|
||||
|
||||
for id in self.transcript:
|
||||
seq = self.transcript[id]
|
||||
speaker = self.annotation[seq["speaker"]]
|
||||
speaker = self.annotation[seq["speakers"]]
|
||||
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
|
||||
|
||||
fstring += "\n\\end{drama}"
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import subprocess as sp
|
||||
|
||||
MAJOR = 0
|
||||
MINOR = 2
|
||||
MINOR = 1
|
||||
MICRO = 0
|
||||
MICRO_POST = 0
|
||||
ISRELEASED = False
|
||||
|
||||
Reference in New Issue
Block a user