diff --git a/scraibe/app/app_starter.py b/scraibe/app/app_starter.py new file mode 100644 index 0000000..9ed1d0b --- /dev/null +++ b/scraibe/app/app_starter.py @@ -0,0 +1,28 @@ +""" +This script is used to start the Gradio interface for audio transcription. +A configuration file can be passed to the script to configure the interface. +If no configuration file is passed, the default configuration is used. +The main Reason for this script is to allow the use of multiprocessing in the app. +""" + +import multiprocessing +from scraibe.misc import ParseKwargs +from argparse import ArgumentParser + +parser = ArgumentParser() + +parser.add_argument("--server-config", type=str, default= None, + help="Path to the configy.yml file.") + +parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, + help='Keyword arguments for the Gradio app.') + +args = parser.parse_args() + +if __name__ == '__main__': + + multiprocessing.set_start_method('spawn') + + from scraibe.app.app import app + + app(config = args.server_config, **args.server_kwargs) \ No newline at end of file diff --git a/scraibe/cli.py b/scraibe/cli.py index c023f38..f4b49f7 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -5,10 +5,11 @@ The function includes arguments for specifying the audio files, model paths, output formats, and other options necessary for transcription. """ import os -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter +from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, Action import json from .autotranscript import Scraibe +from .misc import ParseKwargs from .app.app import gradio_Interface from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE @@ -41,13 +42,15 @@ def cli(): help="List of audio files to transcribe.") group.add_argument('--start-server', action='store_true', - help='Start the Gradio app.') + help='Start the Gradio app.' \ + 'If set, all other arguments are ignored' \ + 'besides --server-config or --server-kwargs.') - parser.add_argument("--port", type=int, default= None, - help="Port to run the Gradio app on. Defaults to 7860.") + parser.add_argument("--server-config", type=str, default= None, + help="Path to the configy.yml file.") - parser.add_argument("--server-name", type=str, default= None, - help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.") + parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, + help='Keyword arguments for the Gradio app.') parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") @@ -66,7 +69,8 @@ def cli(): help="Device to use for PyTorch inference.") parser.add_argument("--num-threads", type=int, default=0, - help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") + help="Number of threads used by torch for CPU inference; '\ + 'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") parser.add_argument("--output-directory", "-o", type=str, default=".", help="Directory to save the transcription outputs.") @@ -113,55 +117,59 @@ def cli(): if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") - model = Scraibe(**class_kwargs) - - - if arg_dict["audio_files"]: - audio_files = arg_dict.pop("audio_files") + if not start_server: - if task == "autotranscribe" or task == "autotranscribe+translate": - for audio in audio_files: - if task == "autotranscribe+translate": - task = "translate" - else: - task = "transcribe" - - out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - print(f'Saving {basename}.{out_format} to {out_folder}') - out.save(os.path.join(out_folder, f"{basename}.{out_format}")) - - elif task == "diarization": - for audio in audio_files: - if arg_dict.pop("verbose_output"): - print(f"Verbose not implemented for diarization.") - - out = model.diarization(audio) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") - - print(f'Saving {basename}.{out_format} to {out_folder}') - - with open(path, "w") as f: - json.dump(json.dumps(out, indent= 1), f) + model = Scraibe(**class_kwargs) - elif task == "transcribe" or task == "translate": + if arg_dict["audio_files"]: + audio_files = arg_dict.pop("audio_files") - for audio in audio_files: - - out = model.transcribe(audio, task = task, - language= arg_dict.pop("language"), - verbose = arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") - with open(path, "w") as f: - f.write(out) + if task == "autotranscribe" or task == "autotranscribe+translate": + for audio in audio_files: + if task == "autotranscribe+translate": + task = "translate" + else: + task = "transcribe" + + out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + print(f'Saving {basename}.{out_format} to {out_folder}') + out.save(os.path.join(out_folder, f"{basename}.{out_format}")) + + elif task == "diarization": + for audio in audio_files: + if arg_dict.pop("verbose_output"): + print(f"Verbose not implemented for diarization.") + + out = model.diarization(audio) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + + print(f'Saving {basename}.{out_format} to {out_folder}') + + with open(path, "w") as f: + json.dump(json.dumps(out, indent= 1), f) + + elif task == "transcribe" or task == "translate": + for audio in audio_files: - if start_server: # unfinished code + out = model.transcribe(audio, task = task, + language= arg_dict.pop("language"), + verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + with open(path, "w") as f: + f.write(out) + + + else: # unfinished code + import subprocess + import sys - gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name) + execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py") + subprocess.run([sys.executable, execute_path]) if __name__ == "__main__": cli() \ No newline at end of file diff --git a/scraibe/misc.py b/scraibe/misc.py index b1afeea..ae9136e 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -1,6 +1,7 @@ import os import yaml from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR +from argparse import Action CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -38,3 +39,17 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> with open(file_path, "w") as stream: yaml.dump(yml, stream) + +class ParseKwargs(Action): + """ + Custom argparse action to parse keyword arguments. + """ + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for value in values: + key, value = value.split('=') + try: + value = eval(value) + except: + pass + getattr(namespace, self.dest)[key] = value \ No newline at end of file