final codebase rework

This commit is contained in:
Jaikinator
2023-08-24 16:12:28 +02:00
parent dc79fed6af
commit e331fe98f3
6 changed files with 128 additions and 17 deletions
+24 -4
View File
@@ -125,6 +125,17 @@ class AutoTranscribe:
diarisation = self.diariser.diarization(dia_audio, **kwargs)
if not diarisation["segments"]:
warn("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
final_transcript= {"speakers" : ["speaker01"],
"segments" : [0, len(audio_file.waveform)],
"text" : transcript}
return Transcript(final_transcript)
print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
@@ -140,8 +151,8 @@ class AutoTranscribe:
transcript = self.transcriber.transcribe(audio, **kwargs)
final_transcript[i] = {"speaker" : diarisation["speakers"][i],
"segment" : seg,
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
"segments" : seg,
"text" : transcript}
# Remove original file if needed
@@ -233,6 +244,7 @@ def cli():
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
from .transcriber import WHISPER_DEFAULT_PATH
from .diarisation import PYANNOTE_DEFAULT_PATH
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
@@ -242,9 +254,12 @@ def cli():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio_files", nargs="+", type=str,
parser.add_argument("-f","--audio_files", nargs="+", type=str,
help="List of audio files to transcribe.")
parser.add_argument('--start_server', action='store_true',
help='Start the Gradio app.')
parser.add_argument("--whisper_model_name", default="medium",
help="Name of the Whisper model to use.")
@@ -299,6 +314,7 @@ def cli():
audio_files = args.audio_files
spoken_language = args.spoken_language
output_format = args.output_format
start_server = args.start_server
os.makedirs(output_directory, exist_ok=True)
@@ -335,6 +351,10 @@ def cli():
elif transcription_task == "wtranscribe":
# wtranscribe code here
pass
if start_server:
from .gradio_app import gradio_app
gradio_app(model)
if __name__ == "__main__":
cli()