made cli work with new interface

This commit is contained in:
Jaikinator
2024-01-25 16:08:53 +01:00
parent c65dc51541
commit ef7bd6e15c
3 changed files with 100 additions and 49 deletions
+28
View File
@@ -0,0 +1,28 @@
"""
This script is used to start the Gradio interface for audio transcription.
A configuration file can be passed to the script to configure the interface.
If no configuration file is passed, the default configuration is used.
The main Reason for this script is to allow the use of multiprocessing in the app.
"""
import multiprocessing
from scraibe.misc import ParseKwargs
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--server-config", type=str, default= None,
help="Path to the configy.yml file.")
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
help='Keyword arguments for the Gradio app.')
args = parser.parse_args()
if __name__ == '__main__':
multiprocessing.set_start_method('spawn')
from scraibe.app.app import app
app(config = args.server_config, **args.server_kwargs)
+60 -52
View File
@@ -5,10 +5,11 @@ The function includes arguments for specifying the audio files, model paths,
output formats, and other options necessary for transcription. output formats, and other options necessary for transcription.
""" """
import os import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, Action
import json import json
from .autotranscript import Scraibe from .autotranscript import Scraibe
from .misc import ParseKwargs
from .app.app import gradio_Interface from .app.app import gradio_Interface
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
@@ -41,13 +42,15 @@ def cli():
help="List of audio files to transcribe.") help="List of audio files to transcribe.")
group.add_argument('--start-server', action='store_true', group.add_argument('--start-server', action='store_true',
help='Start the Gradio app.') help='Start the Gradio app.' \
'If set, all other arguments are ignored' \
'besides --server-config or --server-kwargs.')
parser.add_argument("--port", type=int, default= None, parser.add_argument("--server-config", type=str, default= None,
help="Port to run the Gradio app on. Defaults to 7860.") help="Path to the configy.yml file.")
parser.add_argument("--server-name", type=str, default= None, parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.") help='Keyword arguments for the Gradio app.')
parser.add_argument("--whisper-model-name", default="medium", parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.") help="Name of the Whisper model to use.")
@@ -66,7 +69,8 @@ def cli():
help="Device to use for PyTorch inference.") help="Device to use for PyTorch inference.")
parser.add_argument("--num-threads", type=int, default=0, parser.add_argument("--num-threads", type=int, default=0,
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") help="Number of threads used by torch for CPU inference; '\
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
parser.add_argument("--output-directory", "-o", type=str, default=".", parser.add_argument("--output-directory", "-o", type=str, default=".",
help="Directory to save the transcription outputs.") help="Directory to save the transcription outputs.")
@@ -113,55 +117,59 @@ def cli():
if arg_dict["whisper_model_directory"]: if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
model = Scraibe(**class_kwargs) if not start_server:
model = Scraibe(**class_kwargs)
if arg_dict["audio_files"]:
audio_files = arg_dict.pop("audio_files")
if task == "autotranscribe" or task == "autotranscribe+translate":
for audio in audio_files:
if task == "autotranscribe+translate":
task = "translate"
else:
task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
elif task == "diarization":
for audio in audio_files:
if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.")
out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f)
elif task == "transcribe" or task == "translate":
for audio in audio_files:
out = model.transcribe(audio, task = task,
language= arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f:
f.write(out)
if arg_dict["audio_files"]: else: # unfinished code
audio_files = arg_dict.pop("audio_files") import subprocess
import sys
if task == "autotranscribe" or task == "autotranscribe+translate": execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py")
for audio in audio_files:
if task == "autotranscribe+translate":
task = "translate"
else:
task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
elif task == "diarization":
for audio in audio_files:
if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.")
out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f)
elif task == "transcribe" or task == "translate":
for audio in audio_files:
out = model.transcribe(audio, task = task,
language= arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f:
f.write(out)
if start_server: # unfinished code
gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
subprocess.run([sys.executable, execute_path])
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()
+15
View File
@@ -1,6 +1,7 @@
import os import os
import yaml import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
@@ -38,3 +39,17 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
with open(file_path, "w") as stream: with open(file_path, "w") as stream:
yaml.dump(yml, stream) yaml.dump(yml, stream)
class ParseKwargs(Action):
"""
Custom argparse action to parse keyword arguments.
"""
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, dict())
for value in values:
key, value = value.split('=')
try:
value = eval(value)
except:
pass
getattr(namespace, self.dest)[key] = value