renamed module

This commit is contained in:
Jaikinator
2023-09-18 15:29:09 +02:00
parent e76b7b51a5
commit 5385e266cc
21 changed files with 399 additions and 86 deletions
+1
View File
@@ -0,0 +1 @@
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
+15
View File
@@ -0,0 +1,15 @@
from .autotranscript import *
from .transcriber import *
from .audio import *
from .transcript_exporter import *
from .diarisation import *
from .version import get_version as _get_version
from .misc import *
from .app.gradio_app import *
from .app.qtfaststart import *
from .cli import *
__version__ = _get_version()
File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 38 KiB

+2
View File
@@ -0,0 +1,2 @@
from .qtfaststart import *
from .gradio_app import *
+340
View File
@@ -0,0 +1,340 @@
"""
Gradio Audio Transcription App.
--------------------------------
This module provides an interface to transcribe audio files using the
AutoTranscribe model. Users can either upload an audio file or record their speech
live for transcription. The application supports multiple languages and provides
options to specify the number of speakers and the language of the audio.
Attributes:
LANGUAGES (list): A list of supported languages for transcription.
Usage:
Run this script to start the Gradio web interface for audio transcription.
"""
"""
Gradio Audio Transcription App.
--------------------------------
This module provides an interface to transcribe audio files using the
AutoTranscribe model. Users can either upload an audio file or record their speech
live for transcription. The application supports multiple languages and provides
options to specify the number of speakers and the language of the audio.
Attributes:
LANGUAGES (list): A list of supported languages for transcription.
Usage:
Run this script to start the Gradio web interface for audio transcription.
"""
import json
import gradio as gr
from scraibe import AutoTranscribe, Transcript
theme = gr.themes.Soft(
primary_hue="green",
secondary_hue='orange',
neutral_hue="gray",
)
LANGUAGES = [
"Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian",
"Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian",
"Czech", "Danish", "Dutch", "English", "Estonian",
"Finnish", "French", "Galician", "German", "Greek",
"Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian",
"Italian", "Japanese", "Kannada", "Kazakh", "Korean",
"Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi",
"Maori", "Nepali", "Norwegian", "Persian", "Polish",
"Portuguese", "Romanian", "Russian", "Serbian", "Slovak",
"Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog",
"Tamil", "Thai", "Turkish", "Ukrainian", "Urdu",
"Vietnamese", "Welsh"
]
class GradioTranscriptionInterface:
"""
Interface handling the interaction between Gradio UI and the Audio Transcription system.
"""
def __init__(self, model: AutoTranscribe):
"""
Initializes the GradioTranscriptionInterface with a transcription model.
Args:
model (AutoTranscribe): Model responsible for audio transcription tasks.
"""
self.model = model
def auto_transcribe(self, source,
num_speakers : int,
translation : bool,
language : str):
"""
Shortcut method for the AutoTranscribe task.
Returns:
tuple: Transcribed text (str), JSON output (dict)
"""
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None,
"task": 'translate' if translation else None
}
try:
result = self.model.autotranscribe(source, **kwargs)
except ValueError:
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
return str(result), result.get_json()
def transcribe(self, source, translation, language):
"""
Shortcut method for the Transcribe task.
Returns:
str: Transcribed text.
"""
kwargs = {
"language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None
}
result = self.model.transcribe(source, **kwargs)
return str(result)
def perform_diarisation(self, source, num_speakers):
"""
Shortcut method for the Diarisation task.
Returns:
str: JSON output of diarisation result.
"""
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
}
try:
result = self.model.diarization(source, **kwargs)
except ValueError:
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
return json.dumps(result, indent=2)
####
# Gradio Interface
####
def gradio_Interface(model : AutoTranscribe = None):
if model is None:
model = AutoTranscribe()
pipe = GradioTranscriptionInterface(model)
def select_task(choice):
if choice == 'Auto Transcribe':
return (gr.update(visible = True),
gr.update(visible = True),
gr.update(visible = True))
elif choice == 'Transcribe':
return (gr.update(visible = False),
gr.update(visible = True),
gr.update(visible = True))
elif choice == 'Diarisation':
return (gr.update(visible = True),
gr.update(visible = False),
gr.update(visible = False))
def select_origin(choice):
if choice == "Upload Audio":
return (gr.update(visible = True),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None))
elif choice == "Record Audio":
return (gr.update(visible = False, value = None),
gr.update(visible = True),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None))
elif choice == "Upload Video":
return (gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = True),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None))
elif choice == "Record Video":
return (gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = True),
gr.update(visible = False, value = None))
elif choice == "File":
return (gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = True))
def run_scribe(task, num_speakers, translate, language, audio1, audio2, video1, video2, file_in, progress = gr.Progress(track_tqdm= True)):
# get *args which are not None
progress(0, desc='Starting task...')
source = audio1 or audio2 or video1 or video2 or file_in
if task == 'Auto Transcribe':
out_str , out_json = pipe.auto_transcribe(source = source,
num_speakers = num_speakers,
translation = translate,
language = language)
return (gr.update(value = out_str, visible = True),
gr.update(value = out_json, visible = True),
gr.update(visible = True),
gr.update(visible = True))
elif task == 'Transcribe':
out = pipe.transcribe(source = source,
translation = translate,
language = language)
return (gr.update(value = out, visible = True),
gr.update(value = None, visible = False),
gr.update(visible = False),
gr.update(visible = False))
elif task == 'Diarisation':
out = pipe.perform_diarisation(source = source,
num_speakers = num_speakers)
return (gr.update(value = None, visible = False),
gr.update(value = out, visible = True),
gr.update(visible = False),
gr.update(visible = False))
def annotate_output(annoation : str, out_json : dict):
# get *args which are not None
trans = Transcript.from_json(out_json)
trans = trans.annotate(*annoation.split(","))
return gr.update(value = str(trans)),gr.update(value = trans.get_json())
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
# Define components
header = open("header.html", "r").read()
gr.HTML(header, visible= True, show_label=False)
with gr.Row():
with gr.Column():
task = gr.Radio(["Auto Transcribe", "Transcribe", "Diarisation"], label="Task",
value= 'Auto Transcribe')
num_speakers = gr.Number(value=0, label= "Number of speakers (optional)",
info = "Number of speakers in the audio file. If you don't know,\
leave it at 0.", visible= True)
translate = gr.Checkbox(label="Translation", choices=[True, False], value = False,
info="Select 'Yes' to have the output translated into English.",
visible= True)
language = gr.Dropdown(LANGUAGES,
label="Language (optional)", value = "None",
info="Language of the audio file. If you don't know,\
leave it at None.", visible= True)
input = gr.Radio(["Upload Audio", "Record Audio", "Upload Video","Record Video"
,"File"], label="Input Type", value="Upload Audio")
audio1 = gr.Audio(source="upload", type="filepath", label="Upload Audio",
interactive= True, visible= True)
audio2 = gr.Audio(source="microphone", label="Record Audio", type="filepath",
interactive= True, visible= False)
video1 = gr.Video(source="upload", type="filepath", label="Upload Video",
interactive= True, visible= False)
video2 = gr.Video(source="webcam", label="Record Video", type="filepath",
interactive= True, visible= False)
file_in = gr.File(label="Upload File", interactive= True, visible= False)
submit = gr.Button()
with gr.Column():
out_txt = gr.Textbox(label="Output",
visible= True, show_copy_button=True)
out_json = gr.JSON(label="JSON Output",
visible= False, show_copy_button=True)
annoation = gr.Textbox(label="Name your speaker's",
info= "Please provide a list of the speakers arranged \
in the order in which they appear in the input. Use comma ',' \
as a seperator. Be aware that the first name is given \
to SPEAKER_00 the second to SPEAKER_01 and so on.",
visible= False, interactive= True)
annotate = gr.Button(value="Annotate", visible= False, interactive= True)
# Define usage of components
input.change(fn=select_origin, inputs=[input],
outputs=[audio1, audio2, video1, video2, file_in])
task.change(fn=select_task, inputs=[task],
outputs=[num_speakers, translate, language])
translate.change(fn= lambda x : gr.update(value = x),
inputs=[translate], outputs=[translate])
num_speakers.change(fn= lambda x : gr.update(value = x),
inputs=[num_speakers], outputs=[num_speakers])
language.change(fn= lambda x : gr.update(value = x),
inputs=[language], outputs=[language])
submit.click(fn = run_scribe,
inputs=[task, num_speakers, translate, language, audio1,
audio2, video1, video2, file_in],
outputs=[out_txt, out_json, annoation, annotate])
annotate.click(fn = annotate_output, inputs=[annoation, out_json],
outputs=[out_txt, out_json])
return demo
if __name__ == "__main__":
gradio_Interface().queue().launch()
+66
View File
@@ -0,0 +1,66 @@
<!-- Importing Cormorant Garamond font from Google Fonts -->
<link href="https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@400;700&display=swap" rel="stylesheet">
<style>
.header-container {
display: flex;
align-items: center;
justify-content: center;
position: relative;
padding-top: 30px;
}
.logo-container {
position: absolute;
top: 50%;
right: 20px;
transform: translateY(-50%);
width: 300px;
}
.logo {
width: 100%;
height: auto;
}
h1 {
font-family: 'Cormorant Garamond', serif;
font-size: 50px !important; /* Increased font size */
font-weight: bold;
color: #50AF31;
margin: 0;
position: relative;
padding: 0.5em 0;
}
h1::before, h1::after {
content: "";
position: absolute;
height: 2px;
width: 80%;
background-color: #50AF31;
left: 10%;
}
h1::before {
top: 0.5em;
}
h1::after {
bottom: 0.5em;
}
p, h2 {
font-size: 16px;
margin: 10px 0;
line-height: 1.4;
}
</style>
<div class="header-container">
<h1>ScrAIbe</h1>
<div class="logo-container">
<a href="https://www.kida-bmel.de/"> <!-- Replace with your actual URL -->
<img src="file/Logo_KIDA_bmel_green.svg" alt="KIDA Logo" class="logo">
</a>
</div>
</div>
<div style="text-align: center; padding: 20px 10%;">
<p>
Upload, record, or provide a video with audio for transcription. Our toolkit is designed to transcribe content from multiple languages accurately. The integrated speaker diarisation feature identifies different speakers, ensuring a smooth transcription experience. For optimal results, indicate the number of speakers and the original language of the content.
</p>
<h2 style="font-weight: bold; color: #50AF31;">What would you like to do next?</h2>
</div>
+319
View File
@@ -0,0 +1,319 @@
"""
This file contains a modified version of qtfaststart by qtfaststart
https://github.com/danielgtaylor/qtfaststart/tree/master
All credit goes to the original author.
Copyright (C) 2008 - 2013 Daniel G. Taylor <dan@programmer-art.org>
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
"""
import logging
import os
import struct
import collections
import io
# define error classes
class FastStartException(Exception):
"""
Raised when something bad happens during processing.
"""
pass
class FastStartSetupError(FastStartException):
"""
Rasised when asked to process a file that does not need processing
"""
pass
class MalformedFileError(FastStartException):
"""
Raised when the input file is setup in an unexpected way
"""
pass
class UnsupportedFormatError(FastStartException):
"""
Raised when a movie file is recognized as a format not supported.
"""
pass
# define constants
CHUNK_SIZE = 8192
log = logging.getLogger("qtfaststart")
# Older versions of Python require this to be defined
if not hasattr(os, 'SEEK_CUR'):
os.SEEK_CUR = 1
Atom = collections.namedtuple('Atom', 'name position size')
def read_atom(datastream):
"""
Read an atom and return a tuple of (size, type) where size is the size
in bytes (including the 8 bytes already read) and type is a "fourcc"
like "ftyp" or "moov".
"""
size, type = struct.unpack(">L4s", datastream.read(8))
type = type.decode('ascii')
return size, type
def _read_atom_ex(datastream):
"""
Read an Atom from datastream
"""
pos = datastream.tell()
atom_size, atom_type = read_atom(datastream)
if atom_size == 1:
atom_size, = struct.unpack(">Q", datastream.read(8))
return Atom(atom_type, pos, atom_size)
def get_index(datastream):
"""
Return an index of top level atoms, their absolute byte-position in the
file and their size in a list:
index = [
("ftyp", 0, 24),
("moov", 25, 2658),
("free", 2683, 8),
...
]
The tuple elements will be in the order that they appear in the file.
"""
log.debug("Getting index of top level atoms...")
index = list(_read_atoms(datastream))
_ensure_valid_index(index)
return index
def _read_atoms(datastream):
"""
Read atoms until an error occurs
"""
while datastream:
try:
atom = _read_atom_ex(datastream)
log.debug("%s: %s" % (atom.name, atom.size))
except:
break
yield atom
if atom.size == 0:
if atom.name == "mdat":
# Some files may end in mdat with no size set, which generally
# means to seek to the end of the file. We can just stop indexing
# as no more entries will be found!
break
else:
# Weird, but just continue to try to find more atoms
continue
datastream.seek(atom.position + atom.size)
def _ensure_valid_index(index):
"""
Ensure the minimum viable atoms are present in the index.
Raise FastStartException if not.
"""
top_level_atoms = set([item.name for item in index])
for key in ["moov", "mdat"]:
if key not in top_level_atoms:
log.error("%s atom not found, is this a valid MOV/MP4 file?" % key)
raise FastStartException()
def find_atoms(size, datastream):
"""
Compatibilty interface for _find_atoms_ex
"""
fake_parent = Atom('fake', datastream.tell()-8, size+8)
for atom in _find_atoms_ex(fake_parent, datastream):
yield atom.name
def _find_atoms_ex(parent_atom, datastream):
"""
Yield either "stco" or "co64" Atoms from datastream.
datastream will be 8 bytes into the stco or co64 atom when the value
is yielded.
It is assumed that datastream will be at the end of the atom after
the value has been yielded and processed.
parent_atom is the parent atom, a 'moov' or other ancestor of CO
atoms in the datastream.
"""
stop = parent_atom.position + parent_atom.size
while datastream.tell() < stop:
try:
atom = _read_atom_ex(datastream)
except:
log.exception("Error reading next atom!")
raise FastStartException()
if atom.name in ["trak", "mdia", "minf", "stbl"]:
# Known ancestor atom of stco or co64, search within it!
for res in _find_atoms_ex(atom, datastream):
yield res
elif atom.name in ["stco", "co64"]:
yield atom
else:
# Ignore this atom, seek to the end of it.
datastream.seek(atom.position + atom.size)
def process(infilename, limit=float('inf')):
"""
Convert a Quicktime/MP4 file for streaming by moving the metadata to
the front of the file. This method writes a new file.
If limit is set to something other than zero it will be used as the
number of bytes to write of the atoms following the moov atom. This
is very useful to create a small sample of a file with full headers,
which can then be used in bug reports and such.
"""
if isinstance(infilename, str):
datastream = open(infilename, "rb")
elif isinstance(infilename, bytes):
datastream = io.BytesIO(infilename)
else:
raise TypeError("infilename must be a filename, bytes or file-like object")
# Get the top level atom index
index = get_index(datastream)
mdat_pos = 999999
free_size = 0
# Make sure moov occurs AFTER mdat, otherwise no need to run!
for atom in index:
# The atoms are guaranteed to exist from get_index above!
if atom.name == "moov":
moov_atom = atom
moov_pos = atom.position
elif atom.name == "mdat":
mdat_pos = atom.position
elif atom.name == "free" and atom.position < mdat_pos:
# This free atom is before the mdat!
free_size += atom.size
log.info("Removing free atom at %d (%d bytes)" % (atom.position, atom.size))
elif atom.name == "\x00\x00\x00\x00" and atom.position < mdat_pos:
# This is some strange zero atom with incorrect size
free_size += 8
log.info("Removing strange zero atom at %s (8 bytes)" % atom.position)
# Offset to shift positions
offset = moov_atom.size - free_size
if moov_pos < mdat_pos:
# moov appears to be in the proper place, don't shift by moov size
offset -= moov_atom.size
if not free_size:
# No free atoms and moov is correct, we are done!
log.error("This file appears to already be setup for streaming!")
# Stupid hack to retrun the non-processed file:
if isinstance(infilename, str):
return open(infilename, "rb").read()
elif isinstance(infilename, bytes):
return io.BytesIO(infilename).read()
# Read and fix moov
moov = _patch_moov(datastream, moov_atom, offset)
log.info("Writing output...")
outfile = b''
# Write ftype
for atom in index:
if atom.name == "ftyp":
log.debug("Writing ftyp... (%d bytes)" % atom.size)
datastream.seek(atom.position)
outfile += datastream.read(atom.size)
# Write moov
_bytes = moov.getvalue()
log.debug("Writing moov... (%d bytes)" % len(_bytes))
outfile += _bytes
# Write the rest
atoms = [item for item in index if item.name not in ["ftyp", "moov", "free"]]
for atom in atoms:
log.debug("Writing %s... (%d bytes)" % (atom.name, atom.size))
datastream.seek(atom.position)
# for compatability, allow '0' to mean no limit
cur_limit = limit or float('inf')
cur_limit = min(cur_limit, atom.size)
for chunk in get_chunks(datastream, CHUNK_SIZE, cur_limit):
outfile += chunk
return outfile
def _patch_moov(datastream, atom, offset):
datastream.seek(atom.position)
moov = io.BytesIO(datastream.read(atom.size))
# reload the atom from the fixed stream
atom = _read_atom_ex(moov)
for atom in _find_atoms_ex(atom, moov):
# Read either 32-bit or 64-bit offsets
ctype, csize = dict(
stco=('L', 4),
co64=('Q', 8),
)[atom.name]
# Get number of entries
version, entry_count = struct.unpack(">2L", moov.read(8))
log.info("Patching %s with %d entries" % (atom.name, entry_count))
entries_pos = moov.tell()
struct_fmt = ">%(entry_count)s%(ctype)s" % vars()
# Read entries
entries = struct.unpack(struct_fmt, moov.read(csize * entry_count))
# Patch and write entries
offset_entries = [entry + offset for entry in entries]
moov.seek(entries_pos)
moov.write(struct.pack(struct_fmt, *offset_entries))
return moov
def get_chunks(stream, chunk_size, limit):
remaining = limit
while remaining:
chunk = stream.read(min(remaining, chunk_size))
if not chunk:
return
remaining -= len(chunk)
yield chunk
+150
View File
@@ -0,0 +1,150 @@
"""
Audio Processor Module
=======================
This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files.
It includes functionalities to load, cut, and manage audio waveforms, offering efficient and
flexible audio processing.
Available Classes:
- AudioProcessor: Processes audio waveforms and provides methods for loading,
cutting, and handling audio.
Usage:
from .audio_import AudioProcessor
processor = AudioProcessor.from_file("path/to/audiofile.wav")
cut_waveform = processor.cut(start=1.0, end=5.0)
Constants:
- SAMPLE_RATE (int): Default sample rate for processing.
- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform.
"""
from subprocess import CalledProcessError, run
import numpy as np
import torch
SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768.0
class AudioProcessor:
"""
Audio Processor class that leverages PyTorchaudio to provide functionalities
for loading, cutting, and handling audio waveforms.
Attributes:
waveform: torch.Tensor
The audio waveform tensor.
sr: int
The sample rate of the audio.
"""
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
*args, **kwargs) -> None:
"""
Initialize the AudioProcessor object.
Args:
waveform (torch.Tensor): The audio waveform tensor.
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
args: Additional arguments.
kwargs: Additional keyword arguments, e.g., device to use for processing.
If CUDA is available, it defaults to CUDA.
Raises:
ValueError: If the provided sample rate is not of type int.
"""
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device)
self.sr = sr
if not isinstance(self.sr, int):
raise ValueError("Sample rate should be a single value of type int," \
f"not {len(self.sr)} and type {type(self.sr)}")
@classmethod
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
"""
Create an AudioProcessor instance from an audio file.
Args:
file (str): The audio file path.
Returns:
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
"""
audio, sr = cls.load_audio(file , *args, **kwargs)
audio = torch.from_numpy(audio)
return cls(audio, sr)
def cut(self, start: float, end: float) -> torch.Tensor:
"""
Cut a segment from the audio waveform between the specified start and end times.
Args:
start (float): Start time in seconds.
end (float): End time in seconds.
Returns:
torch.Tensor: The cut waveform segment.
"""
start = int(start * self.sr)
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
end = int(np.ceil(end * self.sr))
else:
end = int(torch.ceil(end * self.sr))
return self.waveform[start:end]
@staticmethod
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read it as a mono waveform, resampling if necessary.
This method ensures compatibility with pyannote.audio
and requires the ffmpeg CLI in PATH.
Args:
file (str): The audio file to open.
sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE.
Returns:
tuple: A NumPy array containing the audio waveform in float32 dtype
and the sample rate.
Raises:
RuntimeError: If failed to load audio.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
return out , sr
def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
+283
View File
@@ -0,0 +1,283 @@
"""
AutoTranscribe Class
--------------------
This class serves as the core of the transcription system, responsible for handling
transcription and diarization of audio files. It leverages pretrained models for
speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
providing an accessible interface for audio processing tasks such as transcription,
speaker separation, and timestamping.
By encapsulating the complexities of underlying models, it allows for straightforward
integration into various applications, ranging from transcription services to voice assistants.
Available Classes:
- AutoTranscribe: Main class for performing transcription and diarization.
Includes methods for loading models, processing audio files,
and formatting the transcription output.
Usage:
from .autotranscribe import AutoTranscribe
model = AutoTranscribe(whisper_model="path/to/whisper/model", dia_model="path/to/diarisation/model")
transcript = model.transcribe("path/to/audiofile.wav")
"""
# Standard Library Imports
import os
from glob import iglob
from subprocess import run
from typing import TypeVar, Union
from warnings import warn
# Third-Party Imports
import torch
from numpy import ndarray
from tqdm import trange
# Application-Specific Imports
from .audio import AudioProcessor
from .diarisation import Diariser
from .transcriber import Transcriber, whisper
from .transcript_exporter import Transcript
DiarisationType = TypeVar('DiarisationType')
class AutoTranscribe:
"""
AutoTranscribe is a class responsible for managing the transcription and diarization of audio files.
It serves as the core of the transcription system, incorporating pretrained models
for speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
allowing for comprehensive audio processing.
Attributes:
transcriber (Transcriber): The transcriber object to handle transcription.
diariser (Diariser): The diariser object to handle diarization.
Methods:
__init__: Initializes the AutoTranscribe class with appropriate models.
transcribe: Transcribes an audio file using the whisper model and pyannote diarization model.
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
get_audio_file: Gets an audio file as an AudioProcessor object.
"""
def __init__(self,
whisper_model: Union[bool, str, whisper] = None,
dia_model : Union[bool, str, DiarisationType] = None,
**kwargs) -> None:
"""Initializes the AutoTranscribe class.
Args:
whisper_model (Union[bool, str, whisper], optional):
Path to whisper model or whisper model itself.
diarisation_model (Union[bool, str, DiarisationType], optional):
Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper
and pyannote diarization models.
"""
if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", **kwargs)
elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
else:
self.transcriber = whisper_model
if dia_model is None:
self.diariser = Diariser.load_model(**kwargs)
elif isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs)
else:
self.diariser = dia_model
print("AutoTranscribe initialized all models successfully loaded.")
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
remove_original : bool = False,
**kwargs) -> Transcript:
"""
Transcribes an audio file using the whisper model and pyannote diarization model.
Args:
audio_file (Union[str, torch.Tensor, ndarray]):
Path to audio file or a tensor representing the audio.
remove_original (bool, optional): If True, the original audio file will
be removed after transcription.
*args: Additional positional arguments for diarization and transcription.
**kwargs: Additional keyword arguments for diarization and transcription.
Returns:
Transcript: A Transcript object containing the transcription,
which can be exported to different formats.
"""
# Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
"sample_rate": audio_file.sr
}
print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs)
if not diarisation["segments"]:
print("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
"segments" : [0, len(audio_file.waveform)],
"text" : transcript}}
return Transcript(final_transcript)
print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
# Transcribe each segment and store the results
final_transcript = dict()
for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
seg = diarisation["segments"][i]
audio = audio_file.cut(seg[0], seg[1])
transcript = self.transcriber.transcribe(audio, **kwargs)
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
"segments" : seg,
"text" : transcript}
# Remove original file if needed
if remove_original:
if kwargs.get("shred") is True:
self.remove_audio_file(audio_file, shred=True)
else:
self.remove_audio_file(audio_file, shred=False)
return Transcript(final_transcript)
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
**kwargs) -> dict:
"""
Perform diarization on an audio file using the pyannote diarization model.
Args:
audio_file (Union[str, torch.Tensor, ndarray]):
The audio source which can either be a path to the audio file or a tensor representation.
**kwargs:
Additional keyword arguments for diarization.
Returns:
dict:
A dictionary containing the results of the diarization process.
"""
# Get audio file as an AudioProcessor object
audio_file = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization
dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
"sample_rate": audio_file.sr
}
print("Starting diarisation.")
diarisation = self.diariser.diarization(dia_audio, **kwargs)
return diarisation
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
**kwargs):
"""
Transcribe the provided audio file.
Args:
audio_file (Union[str, torch.Tensor, ndarray]):
The audio source, which can either be a path or a tensor representation.
**kwargs:
Additional keyword arguments for transcription.
Returns:
str:
The transcribed text from the audio source.
"""
audio_file = self.get_audio_file(audio_file)
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
@staticmethod
def remove_audio_file(audio_file : str,
shred : bool = False) -> None:
"""
Removes the original audio file to avoid disk space issues or ensure data privacy.
Args:
audio_file_path (str): Path to the audio file.
shred (bool, optional): If True, the audio file will be shredded,
not just removed.
"""
if not os.path.exists(audio_file):
raise ValueError(f"Audiofile {audio_file} does not exist.")
if shred:
warn("Shredding audiofile can take a long time.", RuntimeWarning)
gen = iglob(f'{audio_file}', recursive=True)
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
if os.path.isdir(audio_file):
raise ValueError(f"Audiofile {audio_file} is a directory.")
for file in gen:
print(f'shredding {file} now\n')
run(cmd , check=True)
else:
os.remove(audio_file)
print(f"Audiofile {audio_file} removed.")
@staticmethod
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor.
Args:
audio_file (Union[str, torch.Tensor, ndarray]): Path to the audio file or
a tensor representing the audio.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
AudioProcessor: An object containing the waveform and sample rate in
torch.Tensor format.
"""
if isinstance(audio_file, str):
audio_file = AudioProcessor.from_file(audio_file)
elif isinstance(audio_file, torch.Tensor):
audio_file = AudioProcessor(audio_file[0], audio_file[1])
elif isinstance(audio_file, ndarray):
audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
audio_file[1])
if not isinstance(audio_file, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
f'not {type(audio_file)}')
return audio_file
def __repr__(self):
return f"AutoTranscribe(transcriber={self.transcriber}, diariser={self.diariser})"
+169
View File
@@ -0,0 +1,169 @@
"""
Command-Line Interface (CLI) for the AutoTranscribe class,
allowing for user interaction to transcribe and diarize audio files.
The function includes arguments for specifying the audio files, model paths,
output formats, and other options necessary for transcription.
"""
import os
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import json
from sympy import use
from .autotranscript import AutoTranscribe
from .app.gradio_app import gradio_Interface
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
from torch.cuda import is_available
from torch import set_num_threads
def cli():
"""
Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe
and diarize audio files. The function includes arguments for specifying the audio files, model paths,
output formats, and other options necessary for transcription.
This function can be executed from the command line to perform transcription tasks, providing a
user-friendly way to access the AutoTranscribe class functionalities.
"""
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter)
group = parser.add_mutually_exclusive_group()
parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None,
help="List of audio files to transcribe.")
group.add_argument('--start-server', action='store_true',
help='Start the Gradio app.')
parser.add_argument("--port", type=int, default= None,
help="Port to run the Gradio app on. Defaults to 7860.")
parser.add_argument("--server-name", type=str, default= None,
help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.")
parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.")
parser.add_argument("--whisper-model-directory", type=str, default= None,
help="Path to save Whisper model files; defaults to ./models/whisper.")
parser.add_argument("--diarization-directory", type=str, default= None,
help="Path to the diarization model directory.")
parser.add_argument("--hf-token", default= None, type=str,
help="HuggingFace token for private model download.")
parser.add_argument("--inference-device",
default="cuda" if is_available() else "cpu",
help="Device to use for PyTorch inference.")
parser.add_argument("--num-threads", type=int, default=0,
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
parser.add_argument("--output-directory", "-o", type=str, default=".",
help="Directory to save the transcription outputs.")
parser.add_argument("--output-format", "-of", type=str, default="txt",
choices=["txt", "json", "md", "html"],
help="Format of the output file; defaults to txt.")
parser.add_argument("--verbose-output", type=str2bool, default=True,
help="Enable or disable progress and debug messages.")
parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code
choices=["autotranscribe", "diarization",
"autotranscribe+translate", "translate", 'transcribe'],
help="Choose to perform transcription, diarization, or translation. \
If set to translate, the output will be translated to English.")
parser.add_argument("--language", type=str, default=None,
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.")
args = parser.parse_args()
arg_dict = vars(args)
# configure output
out_folder = arg_dict.pop("output_directory")
os.makedirs(out_folder, exist_ok=True)
out_format = arg_dict.pop("output_format")
# seup server arg:
start_server = arg_dict.pop("start_server")
task = arg_dict.pop("task")
if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads"))
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
'dia_model': arg_dict.pop("diarization_directory"),
'use_auth_token' : arg_dict.pop("hf_token")}
if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
model = AutoTranscribe(**class_kwargs)
if arg_dict["audio_files"]:
audio_files = arg_dict.pop("audio_files")
if task == "autotranscribe" or task == "autotranscribe+translate":
for audio in audio_files:
if task == "autotranscribe+translate":
task = "translate"
else:
task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
elif task == "diarization":
for audio in audio_files:
if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.")
out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f)
elif task == "transcribe" or task == "translate":
for audio in audio_files:
out = model.transcribe(audio, task = task,
language= arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f:
f.write(out)
if start_server: # unfinished code
gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
if __name__ == "__main__":
cli()
+247
View File
@@ -0,0 +1,247 @@
"""
Diarisation Class
------------------
This class serves as the heart of the speaker diarization system, responsible for identifying
and segmenting individual speakers from a given audio file. It leverages a pretrained model
from pyannote.audio, providing an accessible interface for audio processing tasks such as
speaker separation, and timestamping.
By encapsulating the complexities of the underlying model, it allows for straightforward
integration into various applications, ranging from transcription services to voice assistants.
Available Classes:
- Diariser: Main class for performing speaker diarization.
Includes methods for loading models, processing audio files,
and formatting the diarization output.
Constants:
- TOKEN_PATH (str): Path to the Pyannote token.
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
Usage:
from .diarisation import Diariser
model = Diariser.load_model(model="path/to/model/config.yaml")
diarisation_output = model.diarization("path/to/audiofile.wav")
"""
import os
from pathlib import Path
from typing import TypeVar, Union
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname(
os.path.realpath(__file__)), '.pyannotetoken')
class Diariser:
"""
Handles the diarization process of an audio file using a pretrained model
from pyannote.audio. Diarization is the task of determining "who spoke when."
Args:
model: The pretrained model to use for diarization.
"""
def __init__(self, model) -> None:
self.model = model
def diarization(self, audiofile : Union[str, Tensor, dict] ,
*args, **kwargs) -> Annotation:
"""
Perform speaker diarization on the provided audio file,
effectively separating different speakers
and providing a timestamp for each segment.
Args:
audiofile: The path to the audio file or a torch.Tensor
containing the audio data.
args: Additional arguments for the diarization model.
kwargs: Additional keyword arguments for the diarization model.
Returns:
dict: A dictionary containing speaker names,
segments, and other information related
to the diarization process.
"""
kwargs = self._get_diarisation_kwargs(**kwargs)
diarization = self.model(audiofile,*args, **kwargs)
out = self.format_diarization_output(diarization)
return out
@staticmethod
def format_diarization_output(dia : Annotation) -> dict:
"""
Formats the raw diarization output into a more usable structure for this project.
Args:
dia: Raw diarization output.
Returns:
dict: A structured representation of the diarization, with speaker names
as keys and a list of tuples representing segments as values.
"""
dia_list = list(dia.itertracks(yield_label=True))
diarization_output = {"speakers": [], "segments": []}
normalized_output = []
index_start_speaker = 0
index_end_speaker = 0
current_speaker = str()
###
# Sometimes two consecutive speakers are the same
# This loop removes these duplicates
###
if len(dia_list) == 1:
normalized_output.append([0, 0, dia_list[0][2]])
else:
for i, (_, _, speaker) in enumerate(dia_list):
if i == 0:
current_speaker = speaker
if speaker != current_speaker:
index_end_speaker = i - 1
normalized_output.append([index_start_speaker,
index_end_speaker,
current_speaker])
index_start_speaker = i
current_speaker = speaker
if i == len(dia_list) - 1:
index_end_speaker = i
normalized_output.append([index_start_speaker,
index_end_speaker,
current_speaker])
for outp in normalized_output:
start = dia_list[outp[0]][0].start
end = dia_list[outp[1]][0].end
diarization_output["segments"].append([start, end])
diarization_output["speakers"].append(outp[2])
return diarization_output
@staticmethod
def _get_token():
"""
Retrieves the Huggingface token from a local file. This token is required
for accessing certain online resources.
Raises:
ValueError: If the token is not found.
Returns:
str: The Huggingface token.
"""
if os.path.exists(TOKEN_PATH):
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
token = file.read()
else:
raise ValueError('No token found.' \
'Please create a token at https://huggingface.co/settings/token' \
f'and save it in a file called {TOKEN_PATH}')
return token
@staticmethod
def _save_token(token):
"""
Saves the provided Huggingface token to a local file. This facilitates future
access to online resources without needing to repeatedly authenticate.
Args:
token: The Huggingface token to save.
"""
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
file.write(token)
@classmethod
def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None,
cache_token: bool = True,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None,
*args, **kwargs
) -> Pipeline:
"""
Loads a pretrained model from pyannote.audio,
either from a local cache or online repository.
Args:
model: Path or identifier for the pyannote model.
default: /models/pyannote/speaker_diarization/config.yaml
token: Optional HUGGINGFACE_TOKEN for authenticated access.
cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
"""
if cache_token and use_auth_token is not None:
cls._save_token(use_auth_token)
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
model = 'pyannote/speaker-diarization'
elif not os.path.exists(model) and use_auth_token is not None:
model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
hparams_file = hparams_file,)
if _model is None:
raise ValueError('Unable to load model either from local cache' \
'or from huggingface.co models. Please check your token' \
'or your local model path')
return cls(_model)
@staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict:
"""
Validates and extracts the keyword arguments for the pyannote diarization model.
Ensures that the provided keyword arguments match the expected parameters,
filtering out any invalid or unnecessary arguments.
Returns:
dict: A dictionary containing the validated keyword arguments.
"""
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
return diarisation_kwargs
def __repr__(self):
return f"Diarisation(model={self.model})"
+41
View File
@@ -0,0 +1,41 @@
import os
import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
CACHE_DIR = os.getenv(
"AUTOT_CACHE",
os.path.expanduser("~/.cache/torch/models"),
)
if CACHE_DIR != PYANNOTE_CACHE_DIR:
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
This function updates the YAML file to use the given segmentation model
offline, and avoids manual file manipulation.
Args:
file_path (str): Path to the YAML file.
path_to_segmentation (str, optional): Optional path to the segmentation model.
Raises:
FileNotFoundError: If the segmentation model file is not found.
"""
with open(file_path, "r") as stream:
yml = yaml.safe_load(stream)
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
yml["pipeline"]["params"]["segmentation"] = segmentation_path
if not os.path.exists(segmentation_path):
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
with open(file_path, "w") as stream:
yaml.dump(yml, stream)
+179
View File
@@ -0,0 +1,179 @@
"""
Transcriber Module
------------------
This module provides the Transcriber class, a comprehensive tool for working with Whisper models.
The Transcriber class offers functionalities such as loading different Whisper models, transcribing audio files,
and saving transcriptions to text files. It acts as an interface between various Whisper models and the user,
simplifying the process of audio transcription.
Main Features:
- Loading different sizes and versions of Whisper models.
- Transcribing audio in various formats including str, Tensor, and nparray.
- Saving the transcriptions to the specified paths.
- Adaptable to various language specifications.
- Options to control the verbosity of the transcription process.
Constants:
WHISPER_DEFAULT_PATH: Default path for downloading and loading Whisper models.
Usage:
>>> from your_package import Transcriber
>>> transcriber = Transcriber.load_model(model="medium")
>>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
"""
from whisper import Whisper, load_model
from typing import TypeVar , Union , Optional
from torch import Tensor, device
from numpy import ndarray
from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper')
class Transcriber:
"""
Transcriber Class
-----------------
The Transcriber class serves as a wrapper around Whisper models for efficient audio
transcription. By encapsulating the intricacies of loading models, processing audio,
and saving transcripts, it offers an easy-to-use interface
for users to transcribe audio files.
Attributes:
model (whisper): The Whisper model used for transcription.
Methods:
transcribe: Transcribes the given audio file.
save_transcript: Saves the transcript to a file.
load_model: Loads a specific Whisper model.
_get_whisper_kwargs: Private method to get valid keyword arguments for the whisper model.
Examples:
>>> transcriber = Transcriber.load_model(model="medium")
>>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
Note:
The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options.
"""
def __init__(self, model: whisper ) -> None:
"""
Initialize the Transcriber class with a Whisper model.
Args:
model (whisper): The Whisper model to use for transcription.
"""
self.model = model
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
kwargs = self._get_whisper_kwargs(**kwargs)
if "verbose" not in kwargs:
kwargs["verbose"] = False
result = self.model.transcribe(audio, *args, **kwargs)
return result["text"]
@staticmethod
def save_transcript(transcript : str , save_path : str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
in_memory: bool = False,
*args, **kwargs
) -> 'Transcriber':
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Transcriber: A Transcriber object initialized with the specified model.
"""
_model = load_model(model, download_root=download_root,
device=device, in_memory=in_memory)
return cls(_model)
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
_possible_kwargs = Whisper.transcribe.__code__.co_varnames
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")):
whisper_kwargs["task"] = task
return whisper_kwargs
def __repr__(self) -> str:
return f"Transcriber(model={self.model})"
+303
View File
@@ -0,0 +1,303 @@
import json
import time
from traceback import print_stack
from typing import Union
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
class Transcript:
"""
Class for storing transcript data, including speaker information and text segments,
and exporting it to various file formats such as JSON, HTML, and LaTeX.
"""
def __init__(self, transcript: dict) -> None:
"""
Initializes the Transcript object with the given transcript data.
Args:
transcript (dict): A dictionary containing the formatted transcript string.
Keys should correspond to segment IDs, and values should
contain speaker and segment information.
"""
self.transcript = transcript
self.speakers = self._extract_speakers()
self.segments = self._extract_segments()
self.annotation = {}
def annotate(self, *args, **kwargs) -> dict:
"""
Annotates the transcript to associate specific names with speakers.
Args:
args (list): List of speaker names. These will be mapped sequentially to the speakers.
kwargs (dict): Dictionary with speaker names as keys and list of segments as values.
Returns:
dict: Dictionary with speaker names as keys and list of segments as values.
Raises:
ValueError: If the number of speaker names does not match the number
of speakers, or if an unknown speaker is found.
"""
annotations = {}
if args and len(args) != len(self.speakers):
raise ValueError("Number of speaker names does not match number of speakers")
if args:
for arg, speaker in zip(args, sorted(self.speakers)):
annotations[speaker] = arg
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
if invalid_speakers:
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}")
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
self.annotation = annotations
return self
def _extract_speakers(self) -> list:
"""
Extracts the unique speaker names from the transcript.
Returns:
list: List of unique speaker names in the transcript.
"""
return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
def _extract_segments(self) -> list:
"""
Extracts all the text segments from the transcript.
Returns:
list: List of segments, where each segment is represented
by the starting and ending times.
"""
return [self.transcript[id]["segments"] for id in self.transcript]
def __str__(self) -> str:
"""
Converts the transcript to a string representation.
Returns:
str: String representation of the transcript, including speaker names and
time stamps for each segment.
"""
fstring = ""
for _id in self.transcript:
seq = self.transcript[_id]
if self.annotation:
speaker = self.annotation[seq["speakers"]]
else:
speaker = seq["speakers"]
segm = seq["segments"]
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0]))
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
return fstring
def __repr__(self) -> str:
"""Return a string representation of the Transcript object.
Returns:
str: A string that provides an informative description of the object.
"""
return f"Transcript(speakers = {self.speakers},"\
f"segments = {self.segments}, annotation = {self.annotation})"
def get_dict(self) -> dict:
"""
Get transcript as dict
:return: transcript as dict
:rtype: dict
"""
return self.transcript
def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str:
"""
Get transcript as json string
:return: transcript as json string
:rtype: str
"""
if "indent" not in kwargs:
kwargs["indent"] = 3
if use_annotation and self.annotation:
for _id in self.transcript:
seq = self.transcript[_id]
seq["speakers"] = self.annotation[seq["speakers"]]
return json.dumps(self.transcript, *args, **kwargs)
def get_html(self) -> str:
"""
Get transcript as html string
:return: transcript as html string
:rtype: str
"""
html = "<p>" + self.__str__().replace("\n", "<br>") + "</p>"
html = "<html><body>" + html + "</body></html>"
html = html.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
return html
def get_md(self) -> str:
"""Get transcript as Markdown string, using HTML formatting.
Returns:
str: Transcript as a Markdown string.
"""
return self.get_html()
def get_tex(self) -> str:
"""Get transcript as LaTeX string. If no annotations are present, the speakers will
be annotated with the first letters of the alphabet.
Returns:
str: Transcript as LaTeX string.
"""
if not self.annotation:
self.annotate(*ALPHABET[:len(self.speakers)])
fstring ="\\begin{drama}"
for speaker in self.speakers:
fstring += "\n\t\\Character{"+ str(self.annotation[speaker]) + "}" \
"{"+ str(self.annotation[speaker]) + "}"
for id in self.transcript:
seq = self.transcript[id]
speaker = self.annotation[seq["speakers"]]
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
fstring += "\n\\end{drama}"
return fstring
def to_json(self,path, *args, **kwargs) -> None:
"""Save transcript as json file
Args:
path (str): path to save file
"""
with open(path, "w") as f:
json.dump(self.transcript, f, *args, **kwargs)
def to_txt(self, path: str) -> None:
"""Save transcript as a LaTeX file (placeholder function, implementation needed).
Args:
path (str): Path to save the LaTeX file.
"""
with open(path, "w") as f:
f.write(self.__str__())
def to_md(self, path: str) -> None:
"""Get transcript as Markdown string, using HTML formatting.
Returns:
str: Transcript as a Markdown string.
"""
return self.to_html(path)
def to_html(self, path: str) -> None:
"""
Save transcript as html file
:param path: path to save file
:type path: str
"""
with open(path, "w") as file:
file.write(self.get_html())
def to_tex(self, path: str) -> None:
"""Save transcript as a LaTeX file (placeholder function, implementation needed).
Args:
path (str): Path to save the LaTeX file.
"""
pass
def to_pdf(self, path: str) -> None:
"""Save transcript as a PDF file (placeholder function, implementation needed).
Args:
path (str): Path to save the PDF file.
"""
pass
def save(self, path: str, *args, **kwargs) -> None:
"""Save transcript to file with the given path and file format.
This method can save the transcript in various formats including JSON, TXT,
MD, HTML, TEX, and PDF. The file format is determined by the extension of
the path.
Args:
path (str): Path to save the file, including the desired file extension.
*args: Additional positional arguments to be passed to the specific save methods.
**kwargs: Additional keyword arguments to be passed to the specific save methods.
Raises:
ValueError: If the file format specified in the path is unknown.
"""
if path.endswith(".json"):
self.to_json(path, *args, **kwargs)
elif path.endswith(".txt"):
self.to_txt(path, *args, **kwargs)
elif path.endswith(".md"):
self.to_md(path, *args, **kwargs)
elif path.endswith(".html"):
self.to_html(path, *args, **kwargs)
elif path.endswith(".tex"):
self.to_tex(path, *args, **kwargs)
elif path.endswith(".pdf"):
self.to_pdf(path, *args, **kwargs)
else:
raise ValueError("Unknown file format")
@classmethod
def from_json(cls, json: Union[dict, str]) -> "Transcript":
"""Load transcript from json file
Args:
path (str): path to json file
Returns:
Transcript: Transcript object
"""
if isinstance(json, dict):
return cls(json)
else:
try:
transcript = json.loads(json)
except:
with open(json, "r") as f:
transcript = json.load(f)
return cls(transcript)
+69
View File
@@ -0,0 +1,69 @@
import os
import subprocess as sp
MAJOR = 0
MINOR = 1
MICRO = 0
MICRO_POST = 0
ISRELEASED = False
VERSION = '%d.%d.%d.%d' % (MAJOR, MINOR, MICRO, MICRO_POST)
# Return the git revision as a string
# taken from numpy/numpy
def git_version():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
GIT_REVISION = out.strip().decode('ascii')
except OSError:
GIT_REVISION = "Unknown"
return GIT_REVISION
def _get_git_version():
cwd = os.getcwd()
# go to the main directory
fdir = os.path.dirname(os.path.abspath(__file__))
maindir = os.path.abspath(os.path.join(fdir, ".."))
# maindir = fdir # os.path.join(fdir, "..")
os.chdir(maindir)
# get git version
res = git_version()
# restore the cwd
os.chdir(cwd)
return res
def get_version(build_version=False):
if ISRELEASED:
return VERSION
# unreleased version
GIT_REVISION = _get_git_version()
if build_version:
import datetime as dt
date = dt.date.strftime(dt.datetime.now(), "%Y%m%d%H%M%S")
return VERSION + ".dev" + date
else:
return VERSION + ".dev0+" + GIT_REVISION[:7]