final codebase rework
This commit is contained in:
@@ -125,6 +125,17 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||||
|
|
||||||
|
if not diarisation["segments"]:
|
||||||
|
warn("No segments found. Try to run transcription without diarisation.")
|
||||||
|
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||||
|
|
||||||
|
final_transcript= {"speakers" : ["speaker01"],
|
||||||
|
"segments" : [0, len(audio_file.waveform)],
|
||||||
|
"text" : transcript}
|
||||||
|
|
||||||
|
return Transcript(final_transcript)
|
||||||
|
|
||||||
|
|
||||||
print("Diarisation finished. Starting transcription.")
|
print("Diarisation finished. Starting transcription.")
|
||||||
|
|
||||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
||||||
@@ -140,8 +151,8 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
transcript = self.transcriber.transcribe(audio, **kwargs)
|
transcript = self.transcriber.transcribe(audio, **kwargs)
|
||||||
|
|
||||||
final_transcript[i] = {"speaker" : diarisation["speakers"][i],
|
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
|
||||||
"segment" : seg,
|
"segments" : seg,
|
||||||
"text" : transcript}
|
"text" : transcript}
|
||||||
|
|
||||||
# Remove original file if needed
|
# Remove original file if needed
|
||||||
@@ -233,6 +244,7 @@ def cli():
|
|||||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||||
from .transcriber import WHISPER_DEFAULT_PATH
|
from .transcriber import WHISPER_DEFAULT_PATH
|
||||||
from .diarisation import PYANNOTE_DEFAULT_PATH
|
from .diarisation import PYANNOTE_DEFAULT_PATH
|
||||||
|
|
||||||
def str2bool(string):
|
def str2bool(string):
|
||||||
str2val = {"True": True, "False": False}
|
str2val = {"True": True, "False": False}
|
||||||
if string in str2val:
|
if string in str2val:
|
||||||
@@ -242,9 +254,12 @@ def cli():
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
parser.add_argument("audio_files", nargs="+", type=str,
|
parser.add_argument("-f","--audio_files", nargs="+", type=str,
|
||||||
help="List of audio files to transcribe.")
|
help="List of audio files to transcribe.")
|
||||||
|
|
||||||
|
parser.add_argument('--start_server', action='store_true',
|
||||||
|
help='Start the Gradio app.')
|
||||||
|
|
||||||
parser.add_argument("--whisper_model_name", default="medium",
|
parser.add_argument("--whisper_model_name", default="medium",
|
||||||
help="Name of the Whisper model to use.")
|
help="Name of the Whisper model to use.")
|
||||||
|
|
||||||
@@ -299,6 +314,7 @@ def cli():
|
|||||||
audio_files = args.audio_files
|
audio_files = args.audio_files
|
||||||
spoken_language = args.spoken_language
|
spoken_language = args.spoken_language
|
||||||
output_format = args.output_format
|
output_format = args.output_format
|
||||||
|
start_server = args.start_server
|
||||||
|
|
||||||
os.makedirs(output_directory, exist_ok=True)
|
os.makedirs(output_directory, exist_ok=True)
|
||||||
|
|
||||||
@@ -336,5 +352,9 @@ def cli():
|
|||||||
# wtranscribe code here
|
# wtranscribe code here
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if start_server:
|
||||||
|
from .gradio_app import gradio_app
|
||||||
|
gradio_app(model)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
@@ -65,7 +65,7 @@ class Transcript:
|
|||||||
list: List of unique speaker names in the transcript.
|
list: List of unique speaker names in the transcript.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return list(set([self.transcript[id]["speaker"] for id in self.transcript]))
|
return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
|
||||||
|
|
||||||
def _extract_segments(self) -> list:
|
def _extract_segments(self) -> list:
|
||||||
"""
|
"""
|
||||||
@@ -75,7 +75,7 @@ class Transcript:
|
|||||||
list: List of segments, where each segment is represented
|
list: List of segments, where each segment is represented
|
||||||
by the starting and ending times.
|
by the starting and ending times.
|
||||||
"""
|
"""
|
||||||
return [self.transcript[id]["segment"] for id in self.transcript]
|
return [self.transcript[id]["segments"] for id in self.transcript]
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -91,11 +91,11 @@ class Transcript:
|
|||||||
seq = self.transcript[_id]
|
seq = self.transcript[_id]
|
||||||
|
|
||||||
if self.annotation:
|
if self.annotation:
|
||||||
speaker = self.annotation[seq["speaker"]]
|
speaker = self.annotation[seq["speakers"]]
|
||||||
else:
|
else:
|
||||||
speaker = seq["speaker"]
|
speaker = seq["speakers"]
|
||||||
|
|
||||||
segm = seq["segment"]
|
segm = seq["segments"]
|
||||||
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0]))
|
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0]))
|
||||||
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
|
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ class Transcript:
|
|||||||
|
|
||||||
for id in self.transcript:
|
for id in self.transcript:
|
||||||
seq = self.transcript[id]
|
seq = self.transcript[id]
|
||||||
speaker = self.annotation[seq["speaker"]]
|
speaker = self.annotation[seq["speakers"]]
|
||||||
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
|
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
|
||||||
|
|
||||||
fstring += "\n\\end{drama}"
|
fstring += "\n\\end{drama}"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import subprocess as sp
|
import subprocess as sp
|
||||||
|
|
||||||
MAJOR = 0
|
MAJOR = 0
|
||||||
MINOR = 2
|
MINOR = 1
|
||||||
MICRO = 0
|
MICRO = 0
|
||||||
MICRO_POST = 0
|
MICRO_POST = 0
|
||||||
ISRELEASED = False
|
ISRELEASED = False
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
from autotranscript import AutoTranscribe
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
LANGUAGES = [
|
||||||
|
"Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian",
|
||||||
|
"Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian",
|
||||||
|
"Czech", "Danish", "Dutch", "English", "Estonian",
|
||||||
|
"Finnish", "French", "Galician", "German", "Greek",
|
||||||
|
"Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian",
|
||||||
|
"Italian", "Japanese", "Kannada", "Kazakh", "Korean",
|
||||||
|
"Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi",
|
||||||
|
"Maori", "Nepali", "Norwegian", "Persian", "Polish",
|
||||||
|
"Portuguese", "Romanian", "Russian", "Serbian", "Slovak",
|
||||||
|
"Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog",
|
||||||
|
"Tamil", "Thai", "Turkish", "Ukrainian", "Urdu",
|
||||||
|
"Vietnamese", "Welsh"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_server(model : AutoTranscribe):
|
||||||
|
|
||||||
|
def transcribe(audio, microphone, number_of_speakers, language):
|
||||||
|
kwargs = {}
|
||||||
|
if number_of_speakers != 0:
|
||||||
|
kwargs["num_speakers"] = number_of_speakers
|
||||||
|
if language != "None":
|
||||||
|
kwargs["language"] = language
|
||||||
|
|
||||||
|
if audio is not None:
|
||||||
|
out = model.transcribe(audio, **kwargs)
|
||||||
|
elif microphone is not None:
|
||||||
|
out = model.transcribe(microphone , **kwargs)
|
||||||
|
else:
|
||||||
|
out = "Please upload an audio file or record one."
|
||||||
|
|
||||||
|
|
||||||
|
return str(out)
|
||||||
|
|
||||||
|
gr.Interface(
|
||||||
|
fn=transcribe,
|
||||||
|
inputs=[
|
||||||
|
gr.Audio(source= "upload", type="filepath", label="Upload Your Audio File", interactive=True),
|
||||||
|
gr.Audio(source= "microphone", type="filepath", label="Record Your Audio", interactive=True),
|
||||||
|
gr.Number(value=0, label= "Number of speakers",
|
||||||
|
info = "Number of speakers in the audio file. If you don't know, leave it at 0."),
|
||||||
|
# gr.Number(value=0, label= "Minimal number of speakers",
|
||||||
|
# info = "Minimal number of speakers in the audio file. If you don't know or you have specified Numspeakers, leave it at 0."),
|
||||||
|
gr.Dropdown(LANGUAGES,
|
||||||
|
label="Languages", default="None",
|
||||||
|
info="Language of the audio file. If you don't know, leave it at None.")
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
"text"
|
||||||
|
],
|
||||||
|
title="Audio Transcription",
|
||||||
|
thumbnail = "Logo_KIDA.png",
|
||||||
|
description="Upload an audio file to transcribe its content. Powered by AutoTranscribe!",
|
||||||
|
theme="soft", # Example of a more modern theme
|
||||||
|
).launch(share=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
model = AutoTranscribe()
|
||||||
|
gradio_server(model)
|
||||||
@@ -9,10 +9,6 @@ pyannote.pipeline~=2.3
|
|||||||
setuptools~=65.6.3
|
setuptools~=65.6.3
|
||||||
setuptools-rust~=1.5.2
|
setuptools-rust~=1.5.2
|
||||||
|
|
||||||
torch~=1.11.0
|
|
||||||
torchaudio~=0.11.0
|
|
||||||
torchmetrics~=0.11.0
|
|
||||||
torchvision~=0.12.0
|
|
||||||
tqdm>=4.65.0
|
tqdm>=4.65.0
|
||||||
|
|
||||||
#optional:
|
#optional:
|
||||||
|
|||||||
+32
-2
@@ -1,8 +1,38 @@
|
|||||||
from autotranscript.autotranscript import AutoTranscribe
|
# import os
|
||||||
|
# import sys
|
||||||
|
# import traceback
|
||||||
|
|
||||||
|
# class TracePrints(object):
|
||||||
|
# def __init__(self):
|
||||||
|
# self.stdout = sys.stdout
|
||||||
|
# def write(self, s):
|
||||||
|
# self.stdout.write("Writing %r\n" % s)
|
||||||
|
# traceback.print_stack(file=self.stdout)
|
||||||
|
|
||||||
|
# sys.stdout = TracePrints()
|
||||||
|
|
||||||
|
# os.environ["PYANNOTE_CACHE"] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models/pyannote")
|
||||||
|
# import os
|
||||||
|
|
||||||
|
# os.environ['TRANSFORMERS_CACHE'] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models")
|
||||||
|
# os.environ['HF_HOME'] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models")
|
||||||
|
|
||||||
|
|
||||||
|
from autotranscript import AutoTranscribe
|
||||||
|
|
||||||
model = AutoTranscribe()
|
model = AutoTranscribe()
|
||||||
|
|
||||||
text = model.transcribe("tests/test.wav")
|
text = model.transcribe("test.mp4")
|
||||||
|
|
||||||
print("Transcription:\n")
|
print("Transcription:\n")
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
|
|
||||||
|
# from autotranscript.misc import *
|
||||||
|
# import os
|
||||||
|
|
||||||
|
# print(os.path.exists(CACHE_DIR))
|
||||||
|
# print(os.path.exists(WHISPER_DEFAULT_PATH))
|
||||||
|
# print(os.path.exists(PYANNOTE_DEFAULT_PATH))
|
||||||
|
|
||||||
|
# print(os.path.exists(PYANNOTE_DEFAULT_CONFIG))
|
||||||
|
|||||||
Reference in New Issue
Block a user