Changed function name and added an additional function for easier use
This commit is contained in:
@@ -24,9 +24,9 @@ Usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Standard Library Imports
|
# Standard Library Imports
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
from glob import iglob
|
from glob import iglob
|
||||||
|
import re
|
||||||
from subprocess import run
|
from subprocess import run
|
||||||
from typing import TypeVar, Union
|
from typing import TypeVar, Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
@@ -93,7 +93,7 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
print("AutoTranscribe initialized all models successfully loaded.")
|
print("AutoTranscribe initialized all models successfully loaded.")
|
||||||
|
|
||||||
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||||
remove_original : bool = False,
|
remove_original : bool = False,
|
||||||
**kwargs) -> Transcript:
|
**kwargs) -> Transcript:
|
||||||
"""
|
"""
|
||||||
@@ -164,6 +164,55 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
return Transcript(final_transcript)
|
return Transcript(final_transcript)
|
||||||
|
|
||||||
|
def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
                **kwargs) -> dict:
    """
    Run speaker diarization on a single audio source.

    The input is first normalised through ``self.get_audio_file``; its
    waveform is then reshaped to a single leading row and handed, together
    with the sample rate, to ``self.diariser.diarization``.

    Args:
        audio_file (Union[str, torch.Tensor, ndarray]):
            The audio source, either a path to an audio file or an
            in-memory array/tensor representation.
        **kwargs:
            Extra keyword arguments forwarded unchanged to the diariser.

    Returns:
        dict:
            Whatever ``self.diariser.diarization`` returns for this audio.
    """
    # Normalise the input into an object exposing `.waveform` and `.sr`.
    processed = self.get_audio_file(audio_file)

    # The diariser receives the waveform as a (1, n) array plus the rate.
    samples = processed.waveform
    n_samples = len(samples)
    payload = {
        "waveform": samples.reshape(1, n_samples),
        "sample_rate": processed.sr,
    }

    print("Starting diarisation.")

    return self.diariser.diarization(payload, **kwargs)
|
||||||
|
|
||||||
|
def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
               **kwargs):
    """
    Transcribe a single audio source without diarization.

    The input is normalised through ``self.get_audio_file`` and its
    waveform is passed to ``self.transcriber.transcribe``.

    Args:
        audio_file (Union[str, torch.Tensor, ndarray]):
            The audio source, either a path or an in-memory array/tensor
            representation.
        **kwargs:
            Extra keyword arguments forwarded unchanged to the transcriber.

    Returns:
        The result of ``self.transcriber.transcribe`` for the waveform
        (presumably the transcribed text as ``str`` -- confirm against the
        transcriber implementation).
    """
    processed = self.get_audio_file(audio_file)
    waveform = processed.waveform

    return self.transcriber.transcribe(waveform, **kwargs)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_audio_file(audio_file : str,
|
def remove_audio_file(audio_file : str,
|
||||||
shred : bool = False) -> None:
|
shred : bool = False) -> None:
|
||||||
@@ -228,133 +277,4 @@ class AutoTranscribe:
|
|||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||||
f'not {type(audio_file)}')
|
f'not {type(audio_file)}')
|
||||||
return audio_file
|
return audio_file
|
||||||
|
|
||||||
|
|
||||||
def cli():
    """
    Command-Line Interface (CLI) for the AutoTranscribe class.

    Parses the command-line arguments, creates the output directory,
    optionally limits the number of torch CPU threads, builds an
    ``AutoTranscribe`` model and then:

    * for the ``transcribe`` task, transcribes every file given via
      ``-f/--audio_files`` and saves each result as
      ``<output_directory>/<basename>.<output_format>``;
    * for the ``diarize`` and ``wtranscribe`` tasks, currently does nothing
      (placeholders);
    * if ``--start_server`` is given, launches the Gradio app with the
      loaded model afterwards.

    ``--verbose_output`` is parsed but not used by this function.
    """
    from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
    from .transcriber import WHISPER_DEFAULT_PATH
    from .diarisation import PYANNOTE_DEFAULT_PATH

    def str2bool(string):
        """Convert the literal strings "True"/"False" to booleans."""
        str2val = {"True": True, "False": False}
        if string in str2val:
            return str2val[string]
        else:
            raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("-f", "--audio_files", nargs="+", type=str,
                        help="List of audio files to transcribe.")

    parser.add_argument('--start_server', action='store_true',
                        help='Start the Gradio app.')

    parser.add_argument("--whisper_model_name", default="medium",
                        help="Name of the Whisper model to use.")

    parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH,
                        help="Path to save Whisper model files; defaults to ./models/whisper.")

    parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH,
                        help="Path to the diarization model directory.")

    parser.add_argument("--huggingface_token", default="", type=str,
                        help="HuggingFace token for private model download.")

    parser.add_argument("--allow_download", type=str2bool, default=False,
                        help="Allow model download if not found locally.")

    parser.add_argument("--inference_device",
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for PyTorch inference.")

    parser.add_argument("--num_threads", type=int, default=0,
                        help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")

    parser.add_argument("--output_directory", "-o", type=str, default=".",
                        help="Directory to save the transcription outputs.")

    parser.add_argument("--output_format", "-of", type=str, default="txt",
                        choices=["txt", "json", "md", "html"],
                        help="Format of the output file; defaults to txt.")

    parser.add_argument("--verbose_output", type=str2bool, default=True,
                        help="Enable or disable progress and debug messages.")

    parser.add_argument("--transcription_task", type=str, default="transcribe",
                        choices=["transcribe", "diarize", "wtranscribe"],
                        help="Choose to perform transcription, diarization, or Whisper transcription.")

    parser.add_argument("--spoken_language", type=str, default=None,
                        choices=sorted(LANGUAGES) + sorted(k.title() for k in TO_LANGUAGE_CODE),
                        help="Language spoken in the audio. Specify None to perform language detection.")

    args = parser.parse_args()

    output_directory = args.output_directory
    transcription_task = args.transcription_task
    # `-f/--audio_files` is optional, so argparse yields None when it is
    # omitted (e.g. when only `--start_server` is requested). Fall back to an
    # empty list so the default "transcribe" task does not crash on None.
    audio_files = args.audio_files or []

    os.makedirs(output_directory, exist_ok=True)

    if args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    # NOTE(review): "local" is fed directly from --allow_download in both
    # kwargs dicts; the names suggest the flag might need to be negated --
    # confirm against the transcriber/diariser constructors.
    whisper_kwargs = {
        "download_root": args.whisper_model_directory,
        "local": args.allow_download,
        "device": args.inference_device,
    }

    diarisation_kwargs = {
        "local": args.allow_download,
        "token": args.huggingface_token,
    }

    model = AutoTranscribe(whisper_model=args.whisper_model_name,
                           whisper_kwargs=whisper_kwargs,
                           dia_model=args.diarization_directory,
                           dia_kwargs=diarisation_kwargs)

    if transcription_task == "transcribe":
        for audio in audio_files:
            out = model.transcribe(audio, language=args.spoken_language)
            # Strip the directory and only the final extension, so names such
            # as "my.talk.mp3" keep their full stem and the platform's path
            # separator is respected.
            basename = os.path.splitext(os.path.basename(audio))[0]
            spath = os.path.join(output_directory, f"{basename}.{args.output_format}")
            out.save(spath)

    # ... include other tasks here ...
    elif transcription_task == "diarize":
        # diarize code here
        pass
    elif transcription_task == "wtranscribe":
        # wtranscribe code here
        pass

    if args.start_server:
        from .gradio_app import gradio_app
        gradio_app(model)


if __name__ == "__main__":
    cli()
|
|
||||||
Reference in New Issue
Block a user