diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index 612f9e5..e053d6a 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -125,6 +125,17 @@ class AutoTranscribe: diarisation = self.diariser.diarization(dia_audio, **kwargs) + if not diarisation["segments"]: + warn("No segments found. Try to run transcription without diarisation.") + transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) + + final_transcript= {"speakers" : ["speaker01"], + "segments" : [0, len(audio_file.waveform)], + "text" : transcript} + + return Transcript(final_transcript) + + print("Diarisation finished. Starting transcription.") audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) @@ -140,8 +151,8 @@ class AutoTranscribe: transcript = self.transcriber.transcribe(audio, **kwargs) - final_transcript[i] = {"speaker" : diarisation["speakers"][i], - "segment" : seg, + final_transcript[i] = {"speakers" : diarisation["speakers"][i], + "segments" : seg, "text" : transcript} # Remove original file if needed @@ -233,6 +244,7 @@ def cli(): from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from .transcriber import WHISPER_DEFAULT_PATH from .diarisation import PYANNOTE_DEFAULT_PATH + def str2bool(string): str2val = {"True": True, "False": False} if string in str2val: @@ -242,9 +254,12 @@ def cli(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("audio_files", nargs="+", type=str, + parser.add_argument("-f","--audio_files", nargs="+", type=str, help="List of audio files to transcribe.") - + + parser.add_argument('--start_server', action='store_true', + help='Start the Gradio app.') + parser.add_argument("--whisper_model_name", default="medium", help="Name of the Whisper model to use.") @@ -299,6 +314,7 @@ def cli(): audio_files = args.audio_files spoken_language = args.spoken_language output_format = args.output_format + start_server = args.start_server os.makedirs(output_directory, exist_ok=True) @@ -335,6 +351,10 @@ def cli(): elif transcription_task == "wtranscribe": # wtranscribe code here pass + + if start_server: + from .gradio_app import gradio_app + gradio_app(model) if __name__ == "__main__": cli() \ No newline at end of file diff --git a/autotranscript/transcript_exporter.py b/autotranscript/transcript_exporter.py index 42f2680..9262be6 100644 --- a/autotranscript/transcript_exporter.py +++ b/autotranscript/transcript_exporter.py @@ -65,7 +65,7 @@ class Transcript: list: List of unique speaker names in the transcript. """ - return list(set([self.transcript[id]["speaker"] for id in self.transcript])) + return list(set([self.transcript[id]["speakers"] for id in self.transcript])) def _extract_segments(self) -> list: """ @@ -75,7 +75,7 @@ class Transcript: list: List of segments, where each segment is represented by the starting and ending times. """ - return [self.transcript[id]["segment"] for id in self.transcript] + return [self.transcript[id]["segments"] for id in self.transcript] def __str__(self) -> str: """ @@ -91,11 +91,11 @@ class Transcript: seq = self.transcript[_id] if self.annotation: - speaker = self.annotation[seq["speaker"]] + speaker = self.annotation[seq["speakers"]] else: - speaker = seq["speaker"] + speaker = seq["speakers"] - segm = seq["segment"] + segm = seq["segments"] sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0])) eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) @@ -172,7 +172,7 @@ class Transcript: for id in self.transcript: seq = self.transcript[id] - speaker = self.annotation[seq["speaker"]] + speaker = self.annotation[seq["speakers"]] fstring += f"\n\\{speaker}speaks:\n{seq['text']}" fstring += "\n\\end{drama}" diff --git a/autotranscript/version.py b/autotranscript/version.py index 5bc7ffc..0a3730e 100644 --- a/autotranscript/version.py +++ b/autotranscript/version.py @@ -2,7 +2,7 @@ import os import subprocess as sp MAJOR = 0 -MINOR = 2 +MINOR = 1 MICRO = 0 MICRO_POST = 0 ISRELEASED = False diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000..321f8bc --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,65 @@ +from autotranscript import AutoTranscribe +import gradio as gr + +LANGUAGES = [ + "Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian", + "Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian", + "Czech", "Danish", "Dutch", "English", "Estonian", + "Finnish", "French", "Galician", "German", "Greek", + "Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian", + "Italian", "Japanese", "Kannada", "Kazakh", "Korean", + "Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi", + "Maori", "Nepali", "Norwegian", "Persian", "Polish", + "Portuguese", "Romanian", "Russian", "Serbian", "Slovak", + "Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog", + "Tamil", "Thai", "Turkish", "Ukrainian", "Urdu", + "Vietnamese", "Welsh" +] + + +def gradio_server(model : AutoTranscribe): + + def transcribe(audio, microphone, number_of_speakers, language): + kwargs = {} + if number_of_speakers != 0: + kwargs["num_speakers"] = number_of_speakers + if language != "None": + kwargs["language"] = language + + if audio is not None: + out = model.transcribe(audio, **kwargs) + elif microphone is not None: + out = model.transcribe(microphone , **kwargs) + else: + out = "Please upload an audio file or record one." + + + return str(out) + + gr.Interface( + fn=transcribe, + inputs=[ + gr.Audio(source= "upload", type="filepath", label="Upload Your Audio File", interactive=True), + gr.Audio(source= "microphone", type="filepath", label="Record Your Audio", interactive=True), + gr.Number(value=0, label= "Number of speakers", + info = "Number of speakers in the audio file. If you don't know, leave it at 0."), + # gr.Number(value=0, label= "Minimal number of speakers", + # info = "Minimal number of speakers in the audio file. If you don't know or you have specified Numspeakers, leave it at 0."), + gr.Dropdown(LANGUAGES, + label="Languages", default="None", + info="Language of the audio file. If you don't know, leave it at None.") + ], + outputs=[ + "text" + ], + title="Audio Transcription", + thumbnail = "Logo_KIDA.png", + description="Upload an audio file to transcribe its content. Powered by AutoTranscribe!", + theme="soft", # Example of a more modern theme + ).launch(share=True) + + +if __name__ == "__main__": + + model = AutoTranscribe() + gradio_server(model) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 433b3c1..b81b23c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,10 +9,6 @@ pyannote.pipeline~=2.3 setuptools~=65.6.3 setuptools-rust~=1.5.2 -torch~=1.11.0 -torchaudio~=0.11.0 -torchmetrics~=0.11.0 -torchvision~=0.12.0 tqdm>=4.65.0 #optional: diff --git a/transcribe.py b/transcribe.py index fca2532..73d8838 100644 --- a/transcribe.py +++ b/transcribe.py @@ -1,8 +1,38 @@ -from autotranscript.autotranscript import AutoTranscribe +# import os +# import sys +# import traceback + +# class TracePrints(object): +# def __init__(self): +# self.stdout = sys.stdout +# def write(self, s): +# self.stdout.write("Writing %r\n" % s) +# traceback.print_stack(file=self.stdout) + +# sys.stdout = TracePrints() + +# os.environ["PYANNOTE_CACHE"] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models/pyannote") +# import os + +# os.environ['TRANSFORMERS_CACHE'] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models") +# os.environ['HF_HOME'] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models") + + +from autotranscript import AutoTranscribe model = AutoTranscribe() -text = model.transcribe("tests/test.wav") +text = model.transcribe("test.mp4") print("Transcription:\n") print(text) + + +# from autotranscript.misc import * +# import os + +# print(os.path.exists(CACHE_DIR)) +# print(os.path.exists(WHISPER_DEFAULT_PATH)) +# print(os.path.exists(PYANNOTE_DEFAULT_PATH)) + +# print(os.path.exists(PYANNOTE_DEFAULT_CONFIG))