Merge pull request #98 from JSchmie/bug_fix_before_v0.2
Bug fix before v0.2
This commit is contained in:
+43
-83
@@ -5,17 +5,12 @@ The function includes arguments for specifying the audio files, model paths,
|
|||||||
output formats, and other options necessary for transcription.
|
output formats, and other options necessary for transcription.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
|
||||||
import json
|
import json
|
||||||
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
from .autotranscript import Scraibe
|
|
||||||
from .misc import ParseKwargs
|
|
||||||
|
|
||||||
|
|
||||||
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
from torch.cuda import is_available
|
from torch.cuda import is_available
|
||||||
from torch import set_num_threads
|
from torch import set_num_threads
|
||||||
|
from .autotranscript import Scraibe
|
||||||
|
|
||||||
def cli():
|
def cli():
|
||||||
"""
|
"""
|
||||||
@@ -37,21 +32,12 @@ def cli():
|
|||||||
|
|
||||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
group = parser.add_mutually_exclusive_group()
|
|
||||||
|
|
||||||
parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None,
|
parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None,
|
||||||
help="List of audio files to transcribe.")
|
help="List of audio files to transcribe.")
|
||||||
|
|
||||||
group.add_argument('--start-server', action='store_true',
|
parser.add_argument("--whisper-type", type=str, default="whisper",
|
||||||
help='Start the Gradio app.'
|
choices=["whisper", "whisperx"],
|
||||||
'If set, all other arguments are ignored'
|
help="Type of Whisper model to use ('whisper' or 'whisperx').")
|
||||||
'besides --server-config or --server-kwargs.')
|
|
||||||
|
|
||||||
parser.add_argument("--server-config", type=str, default=None,
|
|
||||||
help="Path to the configy.yml file.")
|
|
||||||
|
|
||||||
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
|
|
||||||
help='Keyword arguments for the Gradio app.')
|
|
||||||
|
|
||||||
parser.add_argument("--whisper-model-name", default="medium",
|
parser.add_argument("--whisper-model-name", default="medium",
|
||||||
help="Name of the Whisper model to use.")
|
help="Name of the Whisper model to use.")
|
||||||
@@ -83,7 +69,7 @@ def cli():
|
|||||||
parser.add_argument("--verbose-output", type=str2bool, default=True,
|
parser.add_argument("--verbose-output", type=str2bool, default=True,
|
||||||
help="Enable or disable progress and debug messages.")
|
help="Enable or disable progress and debug messages.")
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default='autotranscribe', # unfinished code
|
parser.add_argument("--task", type=str, default='autotranscribe',
|
||||||
choices=["autotranscribe", "diarization",
|
choices=["autotranscribe", "diarization",
|
||||||
"autotranscribe+translate", "translate", 'transcribe'],
|
"autotranscribe+translate", "translate", 'transcribe'],
|
||||||
help="Choose to perform transcription, diarization, or translation. \
|
help="Choose to perform transcription, diarization, or translation. \
|
||||||
@@ -104,91 +90,65 @@ def cli():
|
|||||||
|
|
||||||
out_format = arg_dict.pop("output_format")
|
out_format = arg_dict.pop("output_format")
|
||||||
|
|
||||||
# set up server arg:
|
|
||||||
start_server = arg_dict.pop("start_server")
|
|
||||||
|
|
||||||
task = arg_dict.pop("task")
|
task = arg_dict.pop("task")
|
||||||
|
|
||||||
if args.num_threads > 0:
|
if args.num_threads > 0:
|
||||||
set_num_threads(arg_dict.pop("num_threads"))
|
set_num_threads(arg_dict.pop("num_threads"))
|
||||||
|
|
||||||
class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
|
class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
|
||||||
|
'whisper_type':arg_dict.pop("whisper_type"),
|
||||||
'dia_model': arg_dict.pop("diarization_directory"),
|
'dia_model': arg_dict.pop("diarization_directory"),
|
||||||
'use_auth_token': arg_dict.pop("hf_token")}
|
'use_auth_token': arg_dict.pop("hf_token"),
|
||||||
|
}
|
||||||
|
|
||||||
if arg_dict["whisper_model_directory"]:
|
if arg_dict["whisper_model_directory"]:
|
||||||
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
||||||
|
|
||||||
if not start_server:
|
|
||||||
|
|
||||||
model = Scraibe(**class_kwargs)
|
model = Scraibe(**class_kwargs)
|
||||||
|
|
||||||
if arg_dict["audio_files"]:
|
if arg_dict["audio_files"]:
|
||||||
audio_files = arg_dict.pop("audio_files")
|
audio_files = arg_dict.pop("audio_files")
|
||||||
|
|
||||||
if task == "autotranscribe" or task == "autotranscribe+translate":
|
if task == "autotranscribe" or task == "autotranscribe+translate":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
if task == "autotranscribe+translate":
|
if task == "autotranscribe+translate":
|
||||||
task = "translate"
|
task = "translate"
|
||||||
else:
|
else:
|
||||||
task = "transcribe"
|
task = "transcribe"
|
||||||
|
|
||||||
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
|
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
|
||||||
"language"), verbose=arg_dict.pop("verbose_output"))
|
"language"), verbose=arg_dict.pop("verbose_output"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
out.save(os.path.join(
|
out.save(os.path.join(
|
||||||
out_folder, f"{basename}.{out_format}"))
|
out_folder, f"{basename}.{out_format}"))
|
||||||
|
|
||||||
elif task == "diarization":
|
elif task == "diarization":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
if arg_dict.pop("verbose_output"):
|
if arg_dict.pop("verbose_output"):
|
||||||
print("Verbose not implemented for diarization.")
|
print("Verbose not implemented for diarization.")
|
||||||
|
|
||||||
out = model.diarization(audio)
|
out = model.diarization(audio)
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
|
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
json.dump(json.dumps(out, indent=1), f)
|
json.dump(json.dumps(out, indent=1), f)
|
||||||
|
|
||||||
elif task == "transcribe" or task == "translate":
|
elif task == "transcribe" or task == "translate":
|
||||||
|
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
|
|
||||||
out = model.transcribe(audio, task=task,
|
|
||||||
language=arg_dict.pop("language"),
|
|
||||||
verbose=arg_dict.pop("verbose_output"))
|
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
|
||||||
with open(path, "w") as f:
|
|
||||||
f.write(out)
|
|
||||||
|
|
||||||
else: # unfinished code
|
|
||||||
raise NotImplementedError("Currently not Working")
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
execute_path = os.path.join(
|
|
||||||
os.path.dirname(__file__), "app/app_starter.py")
|
|
||||||
|
|
||||||
config = arg_dict.pop("server_config")
|
|
||||||
server_kwargs = arg_dict.pop("server_kwargs")
|
|
||||||
|
|
||||||
if not config:
|
|
||||||
subprocess.run([sys.executable, execute_path,
|
|
||||||
f"--server-kwargs={server_kwargs}"])
|
|
||||||
elif not server_kwargs:
|
|
||||||
subprocess.run([sys.executable, execute_path,
|
|
||||||
f"--server-config={config}"])
|
|
||||||
elif not config and not server_kwargs:
|
|
||||||
subprocess.run([sys.executable, execute_path])
|
|
||||||
else:
|
|
||||||
subprocess.run([sys.executable, execute_path,
|
|
||||||
f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
|
|
||||||
|
|
||||||
|
out = model.transcribe(audio, task=task,
|
||||||
|
language=arg_dict.pop("language"),
|
||||||
|
verbose=arg_dict.pop("verbose_output"))
|
||||||
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(out)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
Reference in New Issue
Block a user