renamed module
This commit is contained in:
@@ -0,0 +1 @@
|
||||
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
|
||||
@@ -0,0 +1,15 @@
|
||||
from .autotranscript import *
|
||||
from .transcriber import *
|
||||
from .audio import *
|
||||
from .transcript_exporter import *
|
||||
from .diarisation import *
|
||||
|
||||
from .version import get_version as _get_version
|
||||
from .misc import *
|
||||
|
||||
from .app.gradio_app import *
|
||||
from .app.qtfaststart import *
|
||||
|
||||
from .cli import *
|
||||
|
||||
__version__ = _get_version()
|
||||
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 38 KiB |
@@ -0,0 +1,2 @@
|
||||
from .qtfaststart import *
|
||||
from .gradio_app import *
|
||||
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
Gradio Audio Transcription App.
|
||||
--------------------------------
|
||||
|
||||
This module provides an interface to transcribe audio files using the
|
||||
AutoTranscribe model. Users can either upload an audio file or record their speech
|
||||
live for transcription. The application supports multiple languages and provides
|
||||
options to specify the number of speakers and the language of the audio.
|
||||
|
||||
Attributes:
|
||||
LANGUAGES (list): A list of supported languages for transcription.
|
||||
|
||||
Usage:
|
||||
Run this script to start the Gradio web interface for audio transcription.
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
Gradio Audio Transcription App.
|
||||
--------------------------------
|
||||
|
||||
This module provides an interface to transcribe audio files using the
|
||||
AutoTranscribe model. Users can either upload an audio file or record their speech
|
||||
live for transcription. The application supports multiple languages and provides
|
||||
options to specify the number of speakers and the language of the audio.
|
||||
|
||||
Attributes:
|
||||
LANGUAGES (list): A list of supported languages for transcription.
|
||||
|
||||
Usage:
|
||||
Run this script to start the Gradio web interface for audio transcription.
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import gradio as gr
|
||||
from scraibe import AutoTranscribe, Transcript
|
||||
|
||||
|
||||
theme = gr.themes.Soft(
|
||||
primary_hue="green",
|
||||
secondary_hue='orange',
|
||||
neutral_hue="gray",
|
||||
)
|
||||
|
||||
LANGUAGES = [
|
||||
"Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian",
|
||||
"Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian",
|
||||
"Czech", "Danish", "Dutch", "English", "Estonian",
|
||||
"Finnish", "French", "Galician", "German", "Greek",
|
||||
"Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian",
|
||||
"Italian", "Japanese", "Kannada", "Kazakh", "Korean",
|
||||
"Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi",
|
||||
"Maori", "Nepali", "Norwegian", "Persian", "Polish",
|
||||
"Portuguese", "Romanian", "Russian", "Serbian", "Slovak",
|
||||
"Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog",
|
||||
"Tamil", "Thai", "Turkish", "Ukrainian", "Urdu",
|
||||
"Vietnamese", "Welsh"
|
||||
]
|
||||
|
||||
class GradioTranscriptionInterface:
|
||||
"""
|
||||
Interface handling the interaction between Gradio UI and the Audio Transcription system.
|
||||
"""
|
||||
|
||||
def __init__(self, model: AutoTranscribe):
|
||||
"""
|
||||
Initializes the GradioTranscriptionInterface with a transcription model.
|
||||
|
||||
Args:
|
||||
model (AutoTranscribe): Model responsible for audio transcription tasks.
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def auto_transcribe(self, source,
|
||||
num_speakers : int,
|
||||
translation : bool,
|
||||
language : str):
|
||||
"""
|
||||
Shortcut method for the AutoTranscribe task.
|
||||
|
||||
Returns:
|
||||
tuple: Transcribed text (str), JSON output (dict)
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||
"language": language if language != "None" else None,
|
||||
"task": 'translate' if translation else None
|
||||
}
|
||||
|
||||
try:
|
||||
result = self.model.autotranscribe(source, **kwargs)
|
||||
except ValueError:
|
||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||
Please try again!")
|
||||
return str(result), result.get_json()
|
||||
|
||||
|
||||
def transcribe(self, source, translation, language):
|
||||
"""
|
||||
Shortcut method for the Transcribe task.
|
||||
|
||||
Returns:
|
||||
str: Transcribed text.
|
||||
"""
|
||||
kwargs = {
|
||||
"language": language if language != "None" else None,
|
||||
"task": 'translate' if translation == "Yes" else None
|
||||
}
|
||||
|
||||
result = self.model.transcribe(source, **kwargs)
|
||||
return str(result)
|
||||
|
||||
def perform_diarisation(self, source, num_speakers):
|
||||
"""
|
||||
Shortcut method for the Diarisation task.
|
||||
|
||||
Returns:
|
||||
str: JSON output of diarisation result.
|
||||
"""
|
||||
kwargs = {
|
||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
result = self.model.diarization(source, **kwargs)
|
||||
except ValueError:
|
||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||
Please try again!")
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
####
|
||||
# Gradio Interface
|
||||
####
|
||||
|
||||
def gradio_Interface(model : AutoTranscribe = None):
|
||||
|
||||
if model is None:
|
||||
model = AutoTranscribe()
|
||||
|
||||
pipe = GradioTranscriptionInterface(model)
|
||||
|
||||
def select_task(choice):
|
||||
if choice == 'Auto Transcribe':
|
||||
|
||||
return (gr.update(visible = True),
|
||||
gr.update(visible = True),
|
||||
gr.update(visible = True))
|
||||
|
||||
|
||||
elif choice == 'Transcribe':
|
||||
|
||||
return (gr.update(visible = False),
|
||||
gr.update(visible = True),
|
||||
gr.update(visible = True))
|
||||
|
||||
|
||||
elif choice == 'Diarisation':
|
||||
|
||||
return (gr.update(visible = True),
|
||||
gr.update(visible = False),
|
||||
gr.update(visible = False))
|
||||
|
||||
def select_origin(choice):
|
||||
if choice == "Upload Audio":
|
||||
|
||||
return (gr.update(visible = True),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None))
|
||||
|
||||
elif choice == "Record Audio":
|
||||
|
||||
return (gr.update(visible = False, value = None),
|
||||
gr.update(visible = True),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None))
|
||||
|
||||
elif choice == "Upload Video":
|
||||
|
||||
return (gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = True),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None))
|
||||
|
||||
elif choice == "Record Video":
|
||||
|
||||
return (gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = True),
|
||||
gr.update(visible = False, value = None))
|
||||
|
||||
elif choice == "File":
|
||||
|
||||
return (gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = False, value = None),
|
||||
gr.update(visible = True))
|
||||
|
||||
def run_scribe(task, num_speakers, translate, language, audio1, audio2, video1, video2, file_in, progress = gr.Progress(track_tqdm= True)):
|
||||
# get *args which are not None
|
||||
progress(0, desc='Starting task...')
|
||||
source = audio1 or audio2 or video1 or video2 or file_in
|
||||
|
||||
if task == 'Auto Transcribe':
|
||||
|
||||
out_str , out_json = pipe.auto_transcribe(source = source,
|
||||
num_speakers = num_speakers,
|
||||
translation = translate,
|
||||
language = language)
|
||||
|
||||
return (gr.update(value = out_str, visible = True),
|
||||
gr.update(value = out_json, visible = True),
|
||||
gr.update(visible = True),
|
||||
gr.update(visible = True))
|
||||
|
||||
elif task == 'Transcribe':
|
||||
|
||||
out = pipe.transcribe(source = source,
|
||||
translation = translate,
|
||||
language = language)
|
||||
|
||||
return (gr.update(value = out, visible = True),
|
||||
gr.update(value = None, visible = False),
|
||||
gr.update(visible = False),
|
||||
gr.update(visible = False))
|
||||
|
||||
elif task == 'Diarisation':
|
||||
|
||||
out = pipe.perform_diarisation(source = source,
|
||||
num_speakers = num_speakers)
|
||||
|
||||
return (gr.update(value = None, visible = False),
|
||||
gr.update(value = out, visible = True),
|
||||
gr.update(visible = False),
|
||||
gr.update(visible = False))
|
||||
|
||||
def annotate_output(annoation : str, out_json : dict):
|
||||
# get *args which are not None
|
||||
|
||||
trans = Transcript.from_json(out_json)
|
||||
trans = trans.annotate(*annoation.split(","))
|
||||
|
||||
return gr.update(value = str(trans)),gr.update(value = trans.get_json())
|
||||
|
||||
|
||||
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
|
||||
|
||||
# Define components
|
||||
header = open("header.html", "r").read()
|
||||
gr.HTML(header, visible= True, show_label=False)
|
||||
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
|
||||
task = gr.Radio(["Auto Transcribe", "Transcribe", "Diarisation"], label="Task",
|
||||
value= 'Auto Transcribe')
|
||||
|
||||
num_speakers = gr.Number(value=0, label= "Number of speakers (optional)",
|
||||
info = "Number of speakers in the audio file. If you don't know,\
|
||||
leave it at 0.", visible= True)
|
||||
|
||||
translate = gr.Checkbox(label="Translation", choices=[True, False], value = False,
|
||||
info="Select 'Yes' to have the output translated into English.",
|
||||
visible= True)
|
||||
|
||||
language = gr.Dropdown(LANGUAGES,
|
||||
label="Language (optional)", value = "None",
|
||||
info="Language of the audio file. If you don't know,\
|
||||
leave it at None.", visible= True)
|
||||
|
||||
input = gr.Radio(["Upload Audio", "Record Audio", "Upload Video","Record Video"
|
||||
,"File"], label="Input Type", value="Upload Audio")
|
||||
|
||||
audio1 = gr.Audio(source="upload", type="filepath", label="Upload Audio",
|
||||
interactive= True, visible= True)
|
||||
audio2 = gr.Audio(source="microphone", label="Record Audio", type="filepath",
|
||||
interactive= True, visible= False)
|
||||
video1 = gr.Video(source="upload", type="filepath", label="Upload Video",
|
||||
interactive= True, visible= False)
|
||||
video2 = gr.Video(source="webcam", label="Record Video", type="filepath",
|
||||
interactive= True, visible= False)
|
||||
file_in = gr.File(label="Upload File", interactive= True, visible= False)
|
||||
|
||||
submit = gr.Button()
|
||||
|
||||
with gr.Column():
|
||||
|
||||
out_txt = gr.Textbox(label="Output",
|
||||
visible= True, show_copy_button=True)
|
||||
|
||||
out_json = gr.JSON(label="JSON Output",
|
||||
visible= False, show_copy_button=True)
|
||||
|
||||
annoation = gr.Textbox(label="Name your speaker's",
|
||||
info= "Please provide a list of the speakers arranged \
|
||||
in the order in which they appear in the input. Use comma ',' \
|
||||
as a seperator. Be aware that the first name is given \
|
||||
to SPEAKER_00 the second to SPEAKER_01 and so on.",
|
||||
visible= False, interactive= True)
|
||||
|
||||
annotate = gr.Button(value="Annotate", visible= False, interactive= True)
|
||||
|
||||
# Define usage of components
|
||||
input.change(fn=select_origin, inputs=[input],
|
||||
outputs=[audio1, audio2, video1, video2, file_in])
|
||||
|
||||
task.change(fn=select_task, inputs=[task],
|
||||
outputs=[num_speakers, translate, language])
|
||||
|
||||
translate.change(fn= lambda x : gr.update(value = x),
|
||||
inputs=[translate], outputs=[translate])
|
||||
num_speakers.change(fn= lambda x : gr.update(value = x),
|
||||
inputs=[num_speakers], outputs=[num_speakers])
|
||||
language.change(fn= lambda x : gr.update(value = x),
|
||||
inputs=[language], outputs=[language])
|
||||
|
||||
submit.click(fn = run_scribe,
|
||||
inputs=[task, num_speakers, translate, language, audio1,
|
||||
audio2, video1, video2, file_in],
|
||||
outputs=[out_txt, out_json, annoation, annotate])
|
||||
|
||||
annotate.click(fn = annotate_output, inputs=[annoation, out_json],
|
||||
outputs=[out_txt, out_json])
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
gradio_Interface().queue().launch()
|
||||
@@ -0,0 +1,66 @@
|
||||
<!-- Importing Cormorant Garamond font from Google Fonts -->
|
||||
<link href="https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@400;700&display=swap" rel="stylesheet">
|
||||
|
||||
<style>
|
||||
.header-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
padding-top: 30px;
|
||||
}
|
||||
.logo-container {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
right: 20px;
|
||||
transform: translateY(-50%);
|
||||
width: 300px;
|
||||
}
|
||||
.logo {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
}
|
||||
h1 {
|
||||
font-family: 'Cormorant Garamond', serif;
|
||||
font-size: 50px !important; /* Increased font size */
|
||||
font-weight: bold;
|
||||
color: #50AF31;
|
||||
margin: 0;
|
||||
position: relative;
|
||||
padding: 0.5em 0;
|
||||
}
|
||||
h1::before, h1::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
height: 2px;
|
||||
width: 80%;
|
||||
background-color: #50AF31;
|
||||
left: 10%;
|
||||
}
|
||||
h1::before {
|
||||
top: 0.5em;
|
||||
}
|
||||
h1::after {
|
||||
bottom: 0.5em;
|
||||
}
|
||||
p, h2 {
|
||||
font-size: 16px;
|
||||
margin: 10px 0;
|
||||
line-height: 1.4;
|
||||
}
|
||||
</style>
|
||||
|
||||
<div class="header-container">
|
||||
<h1>ScrAIbe</h1>
|
||||
<div class="logo-container">
|
||||
<a href="https://www.kida-bmel.de/"> <!-- Replace with your actual URL -->
|
||||
<img src="file/Logo_KIDA_bmel_green.svg" alt="KIDA Logo" class="logo">
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
<div style="text-align: center; padding: 20px 10%;">
|
||||
<p>
|
||||
Upload, record, or provide a video with audio for transcription. Our toolkit is designed to transcribe content from multiple languages accurately. The integrated speaker diarisation feature identifies different speakers, ensuring a smooth transcription experience. For optimal results, indicate the number of speakers and the original language of the content.
|
||||
</p>
|
||||
<h2 style="font-weight: bold; color: #50AF31;">What would you like to do next?</h2>
|
||||
</div>
|
||||
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
This file contains a modified version of qtfaststart by qtfaststart
|
||||
https://github.com/danielgtaylor/qtfaststart/tree/master
|
||||
|
||||
All credit goes to the original author.
|
||||
Copyright (C) 2008 - 2013 Daniel G. Taylor <dan@programmer-art.org>
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this
|
||||
software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
|
||||
Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies
|
||||
or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
||||
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||
IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import collections
|
||||
import io
|
||||
|
||||
# define error classes
|
||||
class FastStartException(Exception):
|
||||
"""
|
||||
Raised when something bad happens during processing.
|
||||
"""
|
||||
pass
|
||||
|
||||
class FastStartSetupError(FastStartException):
|
||||
"""
|
||||
Rasised when asked to process a file that does not need processing
|
||||
"""
|
||||
pass
|
||||
|
||||
class MalformedFileError(FastStartException):
|
||||
"""
|
||||
Raised when the input file is setup in an unexpected way
|
||||
"""
|
||||
pass
|
||||
|
||||
class UnsupportedFormatError(FastStartException):
|
||||
"""
|
||||
Raised when a movie file is recognized as a format not supported.
|
||||
"""
|
||||
pass
|
||||
|
||||
# define constants
|
||||
CHUNK_SIZE = 8192
|
||||
|
||||
log = logging.getLogger("qtfaststart")
|
||||
|
||||
# Older versions of Python require this to be defined
|
||||
if not hasattr(os, 'SEEK_CUR'):
|
||||
os.SEEK_CUR = 1
|
||||
|
||||
Atom = collections.namedtuple('Atom', 'name position size')
|
||||
|
||||
def read_atom(datastream):
|
||||
"""
|
||||
Read an atom and return a tuple of (size, type) where size is the size
|
||||
in bytes (including the 8 bytes already read) and type is a "fourcc"
|
||||
like "ftyp" or "moov".
|
||||
"""
|
||||
size, type = struct.unpack(">L4s", datastream.read(8))
|
||||
type = type.decode('ascii')
|
||||
return size, type
|
||||
|
||||
|
||||
def _read_atom_ex(datastream):
|
||||
"""
|
||||
Read an Atom from datastream
|
||||
"""
|
||||
pos = datastream.tell()
|
||||
atom_size, atom_type = read_atom(datastream)
|
||||
if atom_size == 1:
|
||||
atom_size, = struct.unpack(">Q", datastream.read(8))
|
||||
return Atom(atom_type, pos, atom_size)
|
||||
|
||||
|
||||
def get_index(datastream):
|
||||
"""
|
||||
Return an index of top level atoms, their absolute byte-position in the
|
||||
file and their size in a list:
|
||||
|
||||
index = [
|
||||
("ftyp", 0, 24),
|
||||
("moov", 25, 2658),
|
||||
("free", 2683, 8),
|
||||
...
|
||||
]
|
||||
|
||||
The tuple elements will be in the order that they appear in the file.
|
||||
"""
|
||||
log.debug("Getting index of top level atoms...")
|
||||
|
||||
index = list(_read_atoms(datastream))
|
||||
_ensure_valid_index(index)
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def _read_atoms(datastream):
|
||||
"""
|
||||
Read atoms until an error occurs
|
||||
"""
|
||||
while datastream:
|
||||
try:
|
||||
atom = _read_atom_ex(datastream)
|
||||
log.debug("%s: %s" % (atom.name, atom.size))
|
||||
except:
|
||||
break
|
||||
|
||||
yield atom
|
||||
|
||||
if atom.size == 0:
|
||||
if atom.name == "mdat":
|
||||
# Some files may end in mdat with no size set, which generally
|
||||
# means to seek to the end of the file. We can just stop indexing
|
||||
# as no more entries will be found!
|
||||
break
|
||||
else:
|
||||
# Weird, but just continue to try to find more atoms
|
||||
continue
|
||||
|
||||
datastream.seek(atom.position + atom.size)
|
||||
|
||||
|
||||
def _ensure_valid_index(index):
|
||||
"""
|
||||
Ensure the minimum viable atoms are present in the index.
|
||||
|
||||
Raise FastStartException if not.
|
||||
"""
|
||||
top_level_atoms = set([item.name for item in index])
|
||||
for key in ["moov", "mdat"]:
|
||||
if key not in top_level_atoms:
|
||||
log.error("%s atom not found, is this a valid MOV/MP4 file?" % key)
|
||||
raise FastStartException()
|
||||
|
||||
|
||||
def find_atoms(size, datastream):
|
||||
"""
|
||||
Compatibilty interface for _find_atoms_ex
|
||||
"""
|
||||
fake_parent = Atom('fake', datastream.tell()-8, size+8)
|
||||
for atom in _find_atoms_ex(fake_parent, datastream):
|
||||
yield atom.name
|
||||
|
||||
|
||||
def _find_atoms_ex(parent_atom, datastream):
|
||||
"""
|
||||
Yield either "stco" or "co64" Atoms from datastream.
|
||||
datastream will be 8 bytes into the stco or co64 atom when the value
|
||||
is yielded.
|
||||
|
||||
It is assumed that datastream will be at the end of the atom after
|
||||
the value has been yielded and processed.
|
||||
|
||||
parent_atom is the parent atom, a 'moov' or other ancestor of CO
|
||||
atoms in the datastream.
|
||||
"""
|
||||
stop = parent_atom.position + parent_atom.size
|
||||
|
||||
while datastream.tell() < stop:
|
||||
try:
|
||||
atom = _read_atom_ex(datastream)
|
||||
except:
|
||||
log.exception("Error reading next atom!")
|
||||
raise FastStartException()
|
||||
|
||||
if atom.name in ["trak", "mdia", "minf", "stbl"]:
|
||||
# Known ancestor atom of stco or co64, search within it!
|
||||
for res in _find_atoms_ex(atom, datastream):
|
||||
yield res
|
||||
elif atom.name in ["stco", "co64"]:
|
||||
yield atom
|
||||
else:
|
||||
# Ignore this atom, seek to the end of it.
|
||||
datastream.seek(atom.position + atom.size)
|
||||
|
||||
|
||||
def process(infilename, limit=float('inf')):
|
||||
"""
|
||||
Convert a Quicktime/MP4 file for streaming by moving the metadata to
|
||||
the front of the file. This method writes a new file.
|
||||
|
||||
If limit is set to something other than zero it will be used as the
|
||||
number of bytes to write of the atoms following the moov atom. This
|
||||
is very useful to create a small sample of a file with full headers,
|
||||
which can then be used in bug reports and such.
|
||||
"""
|
||||
if isinstance(infilename, str):
|
||||
datastream = open(infilename, "rb")
|
||||
elif isinstance(infilename, bytes):
|
||||
datastream = io.BytesIO(infilename)
|
||||
else:
|
||||
raise TypeError("infilename must be a filename, bytes or file-like object")
|
||||
# Get the top level atom index
|
||||
index = get_index(datastream)
|
||||
|
||||
mdat_pos = 999999
|
||||
free_size = 0
|
||||
|
||||
# Make sure moov occurs AFTER mdat, otherwise no need to run!
|
||||
for atom in index:
|
||||
# The atoms are guaranteed to exist from get_index above!
|
||||
if atom.name == "moov":
|
||||
moov_atom = atom
|
||||
moov_pos = atom.position
|
||||
elif atom.name == "mdat":
|
||||
mdat_pos = atom.position
|
||||
elif atom.name == "free" and atom.position < mdat_pos:
|
||||
# This free atom is before the mdat!
|
||||
free_size += atom.size
|
||||
log.info("Removing free atom at %d (%d bytes)" % (atom.position, atom.size))
|
||||
elif atom.name == "\x00\x00\x00\x00" and atom.position < mdat_pos:
|
||||
# This is some strange zero atom with incorrect size
|
||||
free_size += 8
|
||||
log.info("Removing strange zero atom at %s (8 bytes)" % atom.position)
|
||||
|
||||
# Offset to shift positions
|
||||
offset = moov_atom.size - free_size
|
||||
|
||||
if moov_pos < mdat_pos:
|
||||
# moov appears to be in the proper place, don't shift by moov size
|
||||
offset -= moov_atom.size
|
||||
if not free_size:
|
||||
# No free atoms and moov is correct, we are done!
|
||||
log.error("This file appears to already be setup for streaming!")
|
||||
# Stupid hack to retrun the non-processed file:
|
||||
if isinstance(infilename, str):
|
||||
return open(infilename, "rb").read()
|
||||
elif isinstance(infilename, bytes):
|
||||
return io.BytesIO(infilename).read()
|
||||
|
||||
# Read and fix moov
|
||||
moov = _patch_moov(datastream, moov_atom, offset)
|
||||
|
||||
log.info("Writing output...")
|
||||
outfile = b''
|
||||
|
||||
# Write ftype
|
||||
for atom in index:
|
||||
if atom.name == "ftyp":
|
||||
log.debug("Writing ftyp... (%d bytes)" % atom.size)
|
||||
datastream.seek(atom.position)
|
||||
outfile += datastream.read(atom.size)
|
||||
|
||||
# Write moov
|
||||
_bytes = moov.getvalue()
|
||||
log.debug("Writing moov... (%d bytes)" % len(_bytes))
|
||||
outfile += _bytes
|
||||
|
||||
# Write the rest
|
||||
atoms = [item for item in index if item.name not in ["ftyp", "moov", "free"]]
|
||||
for atom in atoms:
|
||||
log.debug("Writing %s... (%d bytes)" % (atom.name, atom.size))
|
||||
datastream.seek(atom.position)
|
||||
|
||||
# for compatability, allow '0' to mean no limit
|
||||
cur_limit = limit or float('inf')
|
||||
cur_limit = min(cur_limit, atom.size)
|
||||
|
||||
for chunk in get_chunks(datastream, CHUNK_SIZE, cur_limit):
|
||||
outfile += chunk
|
||||
|
||||
return outfile
|
||||
|
||||
|
||||
def _patch_moov(datastream, atom, offset):
|
||||
datastream.seek(atom.position)
|
||||
moov = io.BytesIO(datastream.read(atom.size))
|
||||
|
||||
# reload the atom from the fixed stream
|
||||
atom = _read_atom_ex(moov)
|
||||
|
||||
for atom in _find_atoms_ex(atom, moov):
|
||||
# Read either 32-bit or 64-bit offsets
|
||||
ctype, csize = dict(
|
||||
stco=('L', 4),
|
||||
co64=('Q', 8),
|
||||
)[atom.name]
|
||||
|
||||
# Get number of entries
|
||||
version, entry_count = struct.unpack(">2L", moov.read(8))
|
||||
|
||||
log.info("Patching %s with %d entries" % (atom.name, entry_count))
|
||||
|
||||
entries_pos = moov.tell()
|
||||
|
||||
struct_fmt = ">%(entry_count)s%(ctype)s" % vars()
|
||||
|
||||
# Read entries
|
||||
entries = struct.unpack(struct_fmt, moov.read(csize * entry_count))
|
||||
|
||||
# Patch and write entries
|
||||
offset_entries = [entry + offset for entry in entries]
|
||||
moov.seek(entries_pos)
|
||||
moov.write(struct.pack(struct_fmt, *offset_entries))
|
||||
return moov
|
||||
|
||||
def get_chunks(stream, chunk_size, limit):
|
||||
remaining = limit
|
||||
while remaining:
|
||||
chunk = stream.read(min(remaining, chunk_size))
|
||||
if not chunk:
|
||||
return
|
||||
remaining -= len(chunk)
|
||||
yield chunk
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Audio Processor Module
|
||||
=======================
|
||||
|
||||
This module provides the AudioProcessor class, utilizing PyTorchaudio for handling audio files.
|
||||
It includes functionalities to load, cut, and manage audio waveforms, offering efficient and
|
||||
flexible audio processing.
|
||||
|
||||
Available Classes:
|
||||
- AudioProcessor: Processes audio waveforms and provides methods for loading,
|
||||
cutting, and handling audio.
|
||||
|
||||
Usage:
|
||||
from .audio_import AudioProcessor
|
||||
|
||||
processor = AudioProcessor.from_file("path/to/audiofile.wav")
|
||||
cut_waveform = processor.cut(start=1.0, end=5.0)
|
||||
|
||||
Constants:
|
||||
- SAMPLE_RATE (int): Default sample rate for processing.
|
||||
- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform.
|
||||
"""
|
||||
|
||||
from subprocess import CalledProcessError, run
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
NORMALIZATION_FACTOR = 32768.0
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
||||
for loading, cutting, and handling audio waveforms.
|
||||
|
||||
Attributes:
|
||||
waveform: torch.Tensor
|
||||
The audio waveform tensor.
|
||||
sr: int
|
||||
The sample rate of the audio.
|
||||
"""
|
||||
|
||||
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
|
||||
*args, **kwargs) -> None:
|
||||
|
||||
"""
|
||||
Initialize the AudioProcessor object.
|
||||
|
||||
Args:
|
||||
waveform (torch.Tensor): The audio waveform tensor.
|
||||
sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
|
||||
args: Additional arguments.
|
||||
kwargs: Additional keyword arguments, e.g., device to use for processing.
|
||||
If CUDA is available, it defaults to CUDA.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided sample rate is not of type int.
|
||||
"""
|
||||
|
||||
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.waveform = waveform.to(device)
|
||||
self.sr = sr
|
||||
|
||||
if not isinstance(self.sr, int):
|
||||
raise ValueError("Sample rate should be a single value of type int," \
|
||||
f"not {len(self.sr)} and type {type(self.sr)}")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
||||
"""
|
||||
Create an AudioProcessor instance from an audio file.
|
||||
|
||||
Args:
|
||||
file (str): The audio file path.
|
||||
|
||||
Returns:
|
||||
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
|
||||
"""
|
||||
|
||||
audio, sr = cls.load_audio(file , *args, **kwargs)
|
||||
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
return cls(audio, sr)
|
||||
|
||||
|
||||
def cut(self, start: float, end: float) -> torch.Tensor:
|
||||
"""
|
||||
Cut a segment from the audio waveform between the specified start and end times.
|
||||
|
||||
Args:
|
||||
start (float): Start time in seconds.
|
||||
end (float): End time in seconds.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The cut waveform segment.
|
||||
"""
|
||||
|
||||
start = int(start * self.sr)
|
||||
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
|
||||
end = int(np.ceil(end * self.sr))
|
||||
else:
|
||||
end = int(torch.ceil(end * self.sr))
|
||||
return self.waveform[start:end]
|
||||
|
||||
@staticmethod
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
Open an audio file and read it as a mono waveform, resampling if necessary.
|
||||
This method ensures compatibility with pyannote.audio
|
||||
and requires the ffmpeg CLI in PATH.
|
||||
|
||||
Args:
|
||||
file (str): The audio file to open.
|
||||
sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE.
|
||||
|
||||
Returns:
|
||||
tuple: A NumPy array containing the audio waveform in float32 dtype
|
||||
and the sample rate.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If failed to load audio.
|
||||
"""
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# fmt: off
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr),
|
||||
"-"
|
||||
]
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
except CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
|
||||
|
||||
return out , sr
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
AutoTranscribe Class
|
||||
--------------------
|
||||
|
||||
This class serves as the core of the transcription system, responsible for handling
|
||||
transcription and diarization of audio files. It leverages pretrained models for
|
||||
speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
|
||||
providing an accessible interface for audio processing tasks such as transcription,
|
||||
speaker separation, and timestamping.
|
||||
|
||||
By encapsulating the complexities of underlying models, it allows for straightforward
|
||||
integration into various applications, ranging from transcription services to voice assistants.
|
||||
|
||||
Available Classes:
|
||||
- AutoTranscribe: Main class for performing transcription and diarization.
|
||||
Includes methods for loading models, processing audio files,
|
||||
and formatting the transcription output.
|
||||
|
||||
Usage:
|
||||
from .autotranscribe import AutoTranscribe
|
||||
|
||||
model = AutoTranscribe(whisper_model="path/to/whisper/model", dia_model="path/to/diarisation/model")
|
||||
transcript = model.transcribe("path/to/audiofile.wav")
|
||||
"""
|
||||
|
||||
# Standard Library Imports
|
||||
import os
|
||||
from glob import iglob
|
||||
from subprocess import run
|
||||
from typing import TypeVar, Union
|
||||
from warnings import warn
|
||||
|
||||
# Third-Party Imports
|
||||
import torch
|
||||
from numpy import ndarray
|
||||
from tqdm import trange
|
||||
|
||||
# Application-Specific Imports
|
||||
from .audio import AudioProcessor
|
||||
from .diarisation import Diariser
|
||||
from .transcriber import Transcriber, whisper
|
||||
from .transcript_exporter import Transcript
|
||||
|
||||
|
||||
DiarisationType = TypeVar('DiarisationType')
|
||||
|
||||
|
||||
class AutoTranscribe:
|
||||
"""
|
||||
AutoTranscribe is a class responsible for managing the transcription and diarization of audio files.
|
||||
It serves as the core of the transcription system, incorporating pretrained models
|
||||
for speech-to-text (such as Whisper) and speaker diarization (such as pyannote.audio),
|
||||
allowing for comprehensive audio processing.
|
||||
|
||||
Attributes:
|
||||
transcriber (Transcriber): The transcriber object to handle transcription.
|
||||
diariser (Diariser): The diariser object to handle diarization.
|
||||
|
||||
Methods:
|
||||
__init__: Initializes the AutoTranscribe class with appropriate models.
|
||||
transcribe: Transcribes an audio file using the whisper model and pyannote diarization model.
|
||||
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
|
||||
get_audio_file: Gets an audio file as an AudioProcessor object.
|
||||
"""
|
||||
def __init__(self,
|
||||
whisper_model: Union[bool, str, whisper] = None,
|
||||
dia_model : Union[bool, str, DiarisationType] = None,
|
||||
**kwargs) -> None:
|
||||
"""Initializes the AutoTranscribe class.
|
||||
|
||||
Args:
|
||||
whisper_model (Union[bool, str, whisper], optional):
|
||||
Path to whisper model or whisper model itself.
|
||||
diarisation_model (Union[bool, str, DiarisationType], optional):
|
||||
Path to pyannote diarization model or model itself.
|
||||
**kwargs: Additional keyword arguments for whisper
|
||||
and pyannote diarization models.
|
||||
"""
|
||||
|
||||
|
||||
if whisper_model is None:
|
||||
self.transcriber = Transcriber.load_model("medium", **kwargs)
|
||||
elif isinstance(whisper_model, str):
|
||||
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
||||
else:
|
||||
self.transcriber = whisper_model
|
||||
|
||||
if dia_model is None:
|
||||
self.diariser = Diariser.load_model(**kwargs)
|
||||
elif isinstance(dia_model, str):
|
||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||
else:
|
||||
self.diariser = dia_model
|
||||
|
||||
print("AutoTranscribe initialized all models successfully loaded.")
|
||||
|
||||
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
remove_original : bool = False,
|
||||
**kwargs) -> Transcript:
|
||||
"""
|
||||
Transcribes an audio file using the whisper model and pyannote diarization model.
|
||||
|
||||
Args:
|
||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
||||
Path to audio file or a tensor representing the audio.
|
||||
remove_original (bool, optional): If True, the original audio file will
|
||||
be removed after transcription.
|
||||
*args: Additional positional arguments for diarization and transcription.
|
||||
**kwargs: Additional keyword arguments for diarization and transcription.
|
||||
|
||||
Returns:
|
||||
Transcript: A Transcript object containing the transcription,
|
||||
which can be exported to different formats.
|
||||
"""
|
||||
|
||||
# Get audio file as an AudioProcessor object
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
|
||||
"sample_rate": audio_file.sr
|
||||
}
|
||||
|
||||
print("Starting diarisation.")
|
||||
|
||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||
|
||||
if not diarisation["segments"]:
|
||||
print("No segments found. Try to run transcription without diarisation.")
|
||||
|
||||
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||
|
||||
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
|
||||
"segments" : [0, len(audio_file.waveform)],
|
||||
"text" : transcript}}
|
||||
|
||||
return Transcript(final_transcript)
|
||||
|
||||
print("Diarisation finished. Starting transcription.")
|
||||
|
||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
||||
|
||||
# Transcribe each segment and store the results
|
||||
final_transcript = dict()
|
||||
|
||||
for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
|
||||
|
||||
seg = diarisation["segments"][i]
|
||||
|
||||
audio = audio_file.cut(seg[0], seg[1])
|
||||
|
||||
transcript = self.transcriber.transcribe(audio, **kwargs)
|
||||
|
||||
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
|
||||
"segments" : seg,
|
||||
"text" : transcript}
|
||||
|
||||
# Remove original file if needed
|
||||
if remove_original:
|
||||
if kwargs.get("shred") is True:
|
||||
self.remove_audio_file(audio_file, shred=True)
|
||||
else:
|
||||
self.remove_audio_file(audio_file, shred=False)
|
||||
|
||||
return Transcript(final_transcript)
|
||||
|
||||
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
**kwargs) -> dict:
|
||||
"""
|
||||
Perform diarization on an audio file using the pyannote diarization model.
|
||||
|
||||
Args:
|
||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
||||
The audio source which can either be a path to the audio file or a tensor representation.
|
||||
**kwargs:
|
||||
Additional keyword arguments for diarization.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
A dictionary containing the results of the diarization process.
|
||||
"""
|
||||
|
||||
# Get audio file as an AudioProcessor object
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
|
||||
"sample_rate": audio_file.sr
|
||||
}
|
||||
|
||||
print("Starting diarisation.")
|
||||
|
||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||
|
||||
return diarisation
|
||||
|
||||
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
||||
**kwargs):
|
||||
"""
|
||||
Transcribe the provided audio file.
|
||||
|
||||
Args:
|
||||
audio_file (Union[str, torch.Tensor, ndarray]):
|
||||
The audio source, which can either be a path or a tensor representation.
|
||||
**kwargs:
|
||||
Additional keyword arguments for transcription.
|
||||
|
||||
Returns:
|
||||
str:
|
||||
The transcribed text from the audio source.
|
||||
"""
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
|
||||
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||
@staticmethod
|
||||
def remove_audio_file(audio_file : str,
|
||||
shred : bool = False) -> None:
|
||||
"""
|
||||
Removes the original audio file to avoid disk space issues or ensure data privacy.
|
||||
|
||||
Args:
|
||||
audio_file_path (str): Path to the audio file.
|
||||
shred (bool, optional): If True, the audio file will be shredded,
|
||||
not just removed.
|
||||
"""
|
||||
if not os.path.exists(audio_file):
|
||||
raise ValueError(f"Audiofile {audio_file} does not exist.")
|
||||
|
||||
if shred:
|
||||
|
||||
warn("Shredding audiofile can take a long time.", RuntimeWarning)
|
||||
|
||||
gen = iglob(f'{audio_file}', recursive=True)
|
||||
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
|
||||
|
||||
if os.path.isdir(audio_file):
|
||||
raise ValueError(f"Audiofile {audio_file} is a directory.")
|
||||
|
||||
for file in gen:
|
||||
print(f'shredding {file} now\n')
|
||||
|
||||
run(cmd , check=True)
|
||||
|
||||
else:
|
||||
os.remove(audio_file)
|
||||
print(f"Audiofile {audio_file} removed.")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
|
||||
*args, **kwargs) -> AudioProcessor:
|
||||
"""Gets an audio file as TorchAudioProcessor.
|
||||
|
||||
Args:
|
||||
audio_file (Union[str, torch.Tensor, ndarray]): Path to the audio file or
|
||||
a tensor representing the audio.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
AudioProcessor: An object containing the waveform and sample rate in
|
||||
torch.Tensor format.
|
||||
"""
|
||||
|
||||
if isinstance(audio_file, str):
|
||||
audio_file = AudioProcessor.from_file(audio_file)
|
||||
|
||||
elif isinstance(audio_file, torch.Tensor):
|
||||
audio_file = AudioProcessor(audio_file[0], audio_file[1])
|
||||
elif isinstance(audio_file, ndarray):
|
||||
audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
|
||||
audio_file[1])
|
||||
|
||||
if not isinstance(audio_file, AudioProcessor):
|
||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||
f'not {type(audio_file)}')
|
||||
return audio_file
|
||||
|
||||
def __repr__(self):
|
||||
return f"AutoTranscribe(transcriber={self.transcriber}, diariser={self.diariser})"
|
||||
+169
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Command-Line Interface (CLI) for the AutoTranscribe class,
|
||||
allowing for user interaction to transcribe and diarize audio files.
|
||||
The function includes arguments for specifying the audio files, model paths,
|
||||
output formats, and other options necessary for transcription.
|
||||
"""
|
||||
import os
|
||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||
import json
|
||||
|
||||
from sympy import use
|
||||
|
||||
from .autotranscript import AutoTranscribe
|
||||
from .app.gradio_app import gradio_Interface
|
||||
|
||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||
from torch.cuda import is_available
|
||||
from torch import set_num_threads
|
||||
|
||||
|
||||
def cli():
|
||||
"""
|
||||
Command-Line Interface (CLI) for the AutoTranscribe class, allowing for user interaction to transcribe
|
||||
and diarize audio files. The function includes arguments for specifying the audio files, model paths,
|
||||
output formats, and other options necessary for transcription.
|
||||
|
||||
This function can be executed from the command line to perform transcription tasks, providing a
|
||||
user-friendly way to access the AutoTranscribe class functionalities.
|
||||
"""
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
else:
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter)
|
||||
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
|
||||
parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None,
|
||||
help="List of audio files to transcribe.")
|
||||
|
||||
group.add_argument('--start-server', action='store_true',
|
||||
help='Start the Gradio app.')
|
||||
|
||||
parser.add_argument("--port", type=int, default= None,
|
||||
help="Port to run the Gradio app on. Defaults to 7860.")
|
||||
|
||||
parser.add_argument("--server-name", type=str, default= None,
|
||||
help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.")
|
||||
|
||||
parser.add_argument("--whisper-model-name", default="medium",
|
||||
help="Name of the Whisper model to use.")
|
||||
|
||||
parser.add_argument("--whisper-model-directory", type=str, default= None,
|
||||
help="Path to save Whisper model files; defaults to ./models/whisper.")
|
||||
|
||||
parser.add_argument("--diarization-directory", type=str, default= None,
|
||||
help="Path to the diarization model directory.")
|
||||
|
||||
parser.add_argument("--hf-token", default= None, type=str,
|
||||
help="HuggingFace token for private model download.")
|
||||
|
||||
parser.add_argument("--inference-device",
|
||||
default="cuda" if is_available() else "cpu",
|
||||
help="Device to use for PyTorch inference.")
|
||||
|
||||
parser.add_argument("--num-threads", type=int, default=0,
|
||||
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
|
||||
|
||||
parser.add_argument("--output-directory", "-o", type=str, default=".",
|
||||
help="Directory to save the transcription outputs.")
|
||||
|
||||
parser.add_argument("--output-format", "-of", type=str, default="txt",
|
||||
choices=["txt", "json", "md", "html"],
|
||||
help="Format of the output file; defaults to txt.")
|
||||
|
||||
parser.add_argument("--verbose-output", type=str2bool, default=True,
|
||||
help="Enable or disable progress and debug messages.")
|
||||
|
||||
parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code
|
||||
choices=["autotranscribe", "diarization",
|
||||
"autotranscribe+translate", "translate", 'transcribe'],
|
||||
help="Choose to perform transcription, diarization, or translation. \
|
||||
If set to translate, the output will be translated to English.")
|
||||
|
||||
parser.add_argument("--language", type=str, default=None,
|
||||
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||
help="Language spoken in the audio. Specify None to perform language detection.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
arg_dict = vars(args)
|
||||
|
||||
# configure output
|
||||
out_folder = arg_dict.pop("output_directory")
|
||||
os.makedirs(out_folder, exist_ok=True)
|
||||
|
||||
out_format = arg_dict.pop("output_format")
|
||||
|
||||
# seup server arg:
|
||||
start_server = arg_dict.pop("start_server")
|
||||
|
||||
task = arg_dict.pop("task")
|
||||
|
||||
if args.num_threads > 0:
|
||||
set_num_threads(arg_dict.pop("num_threads"))
|
||||
|
||||
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
|
||||
'dia_model': arg_dict.pop("diarization_directory"),
|
||||
'use_auth_token' : arg_dict.pop("hf_token")}
|
||||
|
||||
if arg_dict["whisper_model_directory"]:
|
||||
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
||||
|
||||
model = AutoTranscribe(**class_kwargs)
|
||||
|
||||
|
||||
if arg_dict["audio_files"]:
|
||||
audio_files = arg_dict.pop("audio_files")
|
||||
|
||||
if task == "autotranscribe" or task == "autotranscribe+translate":
|
||||
for audio in audio_files:
|
||||
if task == "autotranscribe+translate":
|
||||
task = "translate"
|
||||
else:
|
||||
task = "transcribe"
|
||||
|
||||
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
||||
|
||||
elif task == "diarization":
|
||||
for audio in audio_files:
|
||||
if arg_dict.pop("verbose_output"):
|
||||
print(f"Verbose not implemented for diarization.")
|
||||
|
||||
out = model.diarization(audio)
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||
|
||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||
|
||||
with open(path, "w") as f:
|
||||
json.dump(json.dumps(out, indent= 1), f)
|
||||
|
||||
elif task == "transcribe" or task == "translate":
|
||||
|
||||
for audio in audio_files:
|
||||
|
||||
out = model.transcribe(audio, task = task,
|
||||
language= arg_dict.pop("language"),
|
||||
verbose = arg_dict.pop("verbose_output"))
|
||||
basename = audio.split("/")[-1].split(".")[0]
|
||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||
with open(path, "w") as f:
|
||||
f.write(out)
|
||||
|
||||
|
||||
if start_server: # unfinished code
|
||||
|
||||
gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Diarisation Class
|
||||
------------------
|
||||
|
||||
This class serves as the heart of the speaker diarization system, responsible for identifying
|
||||
and segmenting individual speakers from a given audio file. It leverages a pretrained model
|
||||
from pyannote.audio, providing an accessible interface for audio processing tasks such as
|
||||
speaker separation, and timestamping.
|
||||
|
||||
By encapsulating the complexities of the underlying model, it allows for straightforward
|
||||
integration into various applications, ranging from transcription services to voice assistants.
|
||||
|
||||
Available Classes:
|
||||
- Diariser: Main class for performing speaker diarization.
|
||||
Includes methods for loading models, processing audio files,
|
||||
and formatting the diarization output.
|
||||
|
||||
Constants:
|
||||
- TOKEN_PATH (str): Path to the Pyannote token.
|
||||
- PYANNOTE_DEFAULT_PATH (str): Default path to Pyannote models.
|
||||
- PYANNOTE_DEFAULT_CONFIG (str): Default configuration for Pyannote models.
|
||||
|
||||
Usage:
|
||||
from .diarisation import Diariser
|
||||
|
||||
model = Diariser.load_model(model="path/to/model/config.yaml")
|
||||
diarisation_output = model.diarization("path/to/audiofile.wav")
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TypeVar, Union
|
||||
|
||||
from pyannote.audio import Pipeline
|
||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
|
||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||
Annotation = TypeVar('Annotation')
|
||||
|
||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), '.pyannotetoken')
|
||||
|
||||
class Diariser:
|
||||
"""
|
||||
Handles the diarization process of an audio file using a pretrained model
|
||||
from pyannote.audio. Diarization is the task of determining "who spoke when."
|
||||
|
||||
Args:
|
||||
model: The pretrained model to use for diarization.
|
||||
"""
|
||||
|
||||
def __init__(self, model) -> None:
|
||||
|
||||
self.model = model
|
||||
|
||||
def diarization(self, audiofile : Union[str, Tensor, dict] ,
|
||||
*args, **kwargs) -> Annotation:
|
||||
"""
|
||||
Perform speaker diarization on the provided audio file,
|
||||
effectively separating different speakers
|
||||
and providing a timestamp for each segment.
|
||||
|
||||
Args:
|
||||
audiofile: The path to the audio file or a torch.Tensor
|
||||
containing the audio data.
|
||||
args: Additional arguments for the diarization model.
|
||||
kwargs: Additional keyword arguments for the diarization model.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing speaker names,
|
||||
segments, and other information related
|
||||
to the diarization process.
|
||||
"""
|
||||
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||
|
||||
diarization = self.model(audiofile,*args, **kwargs)
|
||||
|
||||
out = self.format_diarization_output(diarization)
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def format_diarization_output(dia : Annotation) -> dict:
|
||||
"""
|
||||
Formats the raw diarization output into a more usable structure for this project.
|
||||
|
||||
Args:
|
||||
dia: Raw diarization output.
|
||||
|
||||
Returns:
|
||||
dict: A structured representation of the diarization, with speaker names
|
||||
as keys and a list of tuples representing segments as values.
|
||||
"""
|
||||
|
||||
dia_list = list(dia.itertracks(yield_label=True))
|
||||
diarization_output = {"speakers": [], "segments": []}
|
||||
|
||||
normalized_output = []
|
||||
index_start_speaker = 0
|
||||
index_end_speaker = 0
|
||||
current_speaker = str()
|
||||
|
||||
###
|
||||
# Sometimes two consecutive speakers are the same
|
||||
# This loop removes these duplicates
|
||||
###
|
||||
|
||||
if len(dia_list) == 1:
|
||||
normalized_output.append([0, 0, dia_list[0][2]])
|
||||
else:
|
||||
|
||||
for i, (_, _, speaker) in enumerate(dia_list):
|
||||
|
||||
if i == 0:
|
||||
current_speaker = speaker
|
||||
|
||||
if speaker != current_speaker:
|
||||
|
||||
index_end_speaker = i - 1
|
||||
|
||||
normalized_output.append([index_start_speaker,
|
||||
index_end_speaker,
|
||||
current_speaker])
|
||||
|
||||
index_start_speaker = i
|
||||
current_speaker = speaker
|
||||
|
||||
|
||||
if i == len(dia_list) - 1:
|
||||
|
||||
index_end_speaker = i
|
||||
|
||||
normalized_output.append([index_start_speaker,
|
||||
index_end_speaker,
|
||||
current_speaker])
|
||||
|
||||
for outp in normalized_output:
|
||||
start = dia_list[outp[0]][0].start
|
||||
end = dia_list[outp[1]][0].end
|
||||
|
||||
diarization_output["segments"].append([start, end])
|
||||
diarization_output["speakers"].append(outp[2])
|
||||
return diarization_output
|
||||
|
||||
@staticmethod
|
||||
def _get_token():
|
||||
"""
|
||||
Retrieves the Huggingface token from a local file. This token is required
|
||||
for accessing certain online resources.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token is not found.
|
||||
|
||||
Returns:
|
||||
str: The Huggingface token.
|
||||
"""
|
||||
|
||||
if os.path.exists(TOKEN_PATH):
|
||||
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
||||
token = file.read()
|
||||
else:
|
||||
raise ValueError('No token found.' \
|
||||
'Please create a token at https://huggingface.co/settings/token' \
|
||||
f'and save it in a file called {TOKEN_PATH}')
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def _save_token(token):
|
||||
"""
|
||||
Saves the provided Huggingface token to a local file. This facilitates future
|
||||
access to online resources without needing to repeatedly authenticate.
|
||||
|
||||
Args:
|
||||
token: The Huggingface token to save.
|
||||
"""
|
||||
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
|
||||
file.write(token)
|
||||
|
||||
@classmethod
|
||||
def load_model(cls,
|
||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||
use_auth_token: str = None,
|
||||
cache_token: bool = True,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None,
|
||||
*args, **kwargs
|
||||
) -> Pipeline:
|
||||
|
||||
"""
|
||||
Loads a pretrained model from pyannote.audio,
|
||||
either from a local cache or online repository.
|
||||
|
||||
Args:
|
||||
model: Path or identifier for the pyannote model.
|
||||
default: /models/pyannote/speaker_diarization/config.yaml
|
||||
token: Optional HUGGINGFACE_TOKEN for authenticated access.
|
||||
cache_token: Whether to cache the token locally for future use.
|
||||
cache_dir: Directory for caching models.
|
||||
hparams_file: Path to a YAML file containing hyperparameters.
|
||||
args: Additional arguments only to avoid errors.
|
||||
kwargs: Additional keyword arguments only to avoid errors.
|
||||
|
||||
Returns:
|
||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||
"""
|
||||
|
||||
if cache_token and use_auth_token is not None:
|
||||
cls._save_token(use_auth_token)
|
||||
|
||||
if not os.path.exists(model) and use_auth_token is None:
|
||||
use_auth_token = cls._get_token()
|
||||
model = 'pyannote/speaker-diarization'
|
||||
elif not os.path.exists(model) and use_auth_token is not None:
|
||||
model = 'pyannote/speaker-diarization'
|
||||
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token = use_auth_token,
|
||||
cache_dir = cache_dir,
|
||||
hparams_file = hparams_file,)
|
||||
|
||||
if _model is None:
|
||||
raise ValueError('Unable to load model either from local cache' \
|
||||
'or from huggingface.co models. Please check your token' \
|
||||
'or your local model path')
|
||||
|
||||
return cls(_model)
|
||||
|
||||
@staticmethod
|
||||
def _get_diarisation_kwargs(**kwargs) -> dict:
|
||||
"""
|
||||
Validates and extracts the keyword arguments for the pyannote diarization model.
|
||||
|
||||
Ensures that the provided keyword arguments match the expected parameters,
|
||||
filtering out any invalid or unnecessary arguments.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the validated keyword arguments.
|
||||
"""
|
||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||
|
||||
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
||||
|
||||
return diarisation_kwargs
|
||||
|
||||
def __repr__(self):
|
||||
return f"Diarisation(model={self.model})"
|
||||
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
import yaml
|
||||
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
|
||||
|
||||
CACHE_DIR = os.getenv(
|
||||
"AUTOT_CACHE",
|
||||
os.path.expanduser("~/.cache/torch/models"),
|
||||
)
|
||||
|
||||
if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
||||
os.environ["PYANNOTE_CACHE"] = os.path.join(CACHE_DIR, "pyannote")
|
||||
|
||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")
|
||||
|
||||
|
||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||
"""Configure diarization pipeline from a YAML file.
|
||||
|
||||
This function updates the YAML file to use the given segmentation model
|
||||
offline, and avoids manual file manipulation.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the YAML file.
|
||||
path_to_segmentation (str, optional): Optional path to the segmentation model.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the segmentation model file is not found.
|
||||
"""
|
||||
with open(file_path, "r") as stream:
|
||||
yml = yaml.safe_load(stream)
|
||||
|
||||
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
||||
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
||||
|
||||
if not os.path.exists(segmentation_path):
|
||||
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
|
||||
|
||||
with open(file_path, "w") as stream:
|
||||
yaml.dump(yml, stream)
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Transcriber Module
|
||||
------------------
|
||||
|
||||
This module provides the Transcriber class, a comprehensive tool for working with Whisper models.
|
||||
The Transcriber class offers functionalities such as loading different Whisper models, transcribing audio files,
|
||||
and saving transcriptions to text files. It acts as an interface between various Whisper models and the user,
|
||||
simplifying the process of audio transcription.
|
||||
|
||||
Main Features:
|
||||
- Loading different sizes and versions of Whisper models.
|
||||
- Transcribing audio in various formats including str, Tensor, and nparray.
|
||||
- Saving the transcriptions to the specified paths.
|
||||
- Adaptable to various language specifications.
|
||||
- Options to control the verbosity of the transcription process.
|
||||
|
||||
Constants:
|
||||
WHISPER_DEFAULT_PATH: Default path for downloading and loading Whisper models.
|
||||
|
||||
Usage:
|
||||
>>> from your_package import Transcriber
|
||||
>>> transcriber = Transcriber.load_model(model="medium")
|
||||
>>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
|
||||
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
|
||||
"""
|
||||
|
||||
from whisper import Whisper, load_model
|
||||
from typing import TypeVar , Union , Optional
|
||||
from torch import Tensor, device
|
||||
from numpy import ndarray
|
||||
|
||||
|
||||
from .misc import WHISPER_DEFAULT_PATH
|
||||
whisper = TypeVar('whisper')
|
||||
|
||||
|
||||
|
||||
|
||||
class Transcriber:
|
||||
"""
|
||||
Transcriber Class
|
||||
-----------------
|
||||
|
||||
The Transcriber class serves as a wrapper around Whisper models for efficient audio
|
||||
transcription. By encapsulating the intricacies of loading models, processing audio,
|
||||
and saving transcripts, it offers an easy-to-use interface
|
||||
for users to transcribe audio files.
|
||||
|
||||
Attributes:
|
||||
model (whisper): The Whisper model used for transcription.
|
||||
|
||||
Methods:
|
||||
transcribe: Transcribes the given audio file.
|
||||
save_transcript: Saves the transcript to a file.
|
||||
load_model: Loads a specific Whisper model.
|
||||
_get_whisper_kwargs: Private method to get valid keyword arguments for the whisper model.
|
||||
|
||||
Examples:
|
||||
>>> transcriber = Transcriber.load_model(model="medium")
|
||||
>>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
|
||||
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
|
||||
|
||||
Note:
|
||||
The class supports various sizes and versions of Whisper models. Please refer to
|
||||
the load_model method for available options.
|
||||
"""
|
||||
def __init__(self, model: whisper ) -> None:
|
||||
"""
|
||||
Initialize the Transcriber class with a Whisper model.
|
||||
|
||||
Args:
|
||||
model (whisper): The Whisper model to use for transcription.
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
|
||||
*args, **kwargs) -> str:
|
||||
"""
|
||||
Transcribe an audio file.
|
||||
|
||||
Args:
|
||||
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
|
||||
*args: Additional arguments.
|
||||
**kwargs: Additional keyword arguments,
|
||||
such as the language of the audio file.
|
||||
|
||||
Returns:
|
||||
str: The transcript as a string.
|
||||
"""
|
||||
|
||||
kwargs = self._get_whisper_kwargs(**kwargs)
|
||||
|
||||
if "verbose" not in kwargs:
|
||||
kwargs["verbose"] = False
|
||||
|
||||
result = self.model.transcribe(audio, *args, **kwargs)
|
||||
return result["text"]
|
||||
|
||||
@staticmethod
|
||||
def save_transcript(transcript : str , save_path : str) -> None:
|
||||
"""
|
||||
Save a transcript to a file.
|
||||
|
||||
Args:
|
||||
transcript (str): The transcript as a string.
|
||||
save_path (str): The path to save the transcript.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
with open(save_path, 'w') as f:
|
||||
f.write(transcript)
|
||||
|
||||
print(f'Transcript saved to {save_path}')
|
||||
|
||||
@classmethod
|
||||
def load_model(cls,
|
||||
model: str = "medium",
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> 'Transcriber':
|
||||
"""
|
||||
Load whisper model.
|
||||
|
||||
Args:
|
||||
model (str): Whisper model. Available models include:
|
||||
- 'tiny.en'
|
||||
- 'tiny'
|
||||
- 'base.en'
|
||||
- 'base'
|
||||
- 'small.en'
|
||||
- 'small'
|
||||
- 'medium.en'
|
||||
- 'medium'
|
||||
- 'large-v1'
|
||||
- 'large-v2'
|
||||
- 'large'
|
||||
|
||||
download_root (str, optional): Path to download the model.
|
||||
Defaults to WHISPER_DEFAULT_PATH.
|
||||
|
||||
device (Optional[Union[str, torch.device]], optional):
|
||||
Device to load model on. Defaults to None.
|
||||
in_memory (bool, optional): Whether to load model in memory.
|
||||
Defaults to False.
|
||||
args: Additional arguments only to avoid errors.
|
||||
kwargs: Additional keyword arguments only to avoid errors.
|
||||
|
||||
Returns:
|
||||
Transcriber: A Transcriber object initialized with the specified model.
|
||||
"""
|
||||
|
||||
_model = load_model(model, download_root=download_root,
|
||||
device=device, in_memory=in_memory)
|
||||
|
||||
return cls(_model)
|
||||
|
||||
@staticmethod
|
||||
def _get_whisper_kwargs(**kwargs) -> dict:
|
||||
"""
|
||||
Get kwargs for whisper model. Ensure that kwargs are valid.
|
||||
|
||||
Returns:
|
||||
dict: Keyword arguments for whisper model.
|
||||
"""
|
||||
_possible_kwargs = Whisper.transcribe.__code__.co_varnames
|
||||
|
||||
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
||||
|
||||
if (task := kwargs.get("task")):
|
||||
whisper_kwargs["task"] = task
|
||||
|
||||
return whisper_kwargs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Transcriber(model={self.model})"
|
||||
@@ -0,0 +1,303 @@
|
||||
import json
|
||||
import time
|
||||
from traceback import print_stack
|
||||
|
||||
|
||||
from typing import Union
|
||||
|
||||
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
|
||||
|
||||
|
||||
class Transcript:
|
||||
"""
|
||||
Class for storing transcript data, including speaker information and text segments,
|
||||
and exporting it to various file formats such as JSON, HTML, and LaTeX.
|
||||
"""
|
||||
|
||||
def __init__(self, transcript: dict) -> None:
|
||||
"""
|
||||
Initializes the Transcript object with the given transcript data.
|
||||
|
||||
Args:
|
||||
transcript (dict): A dictionary containing the formatted transcript string.
|
||||
Keys should correspond to segment IDs, and values should
|
||||
contain speaker and segment information.
|
||||
"""
|
||||
|
||||
self.transcript = transcript
|
||||
self.speakers = self._extract_speakers()
|
||||
self.segments = self._extract_segments()
|
||||
self.annotation = {}
|
||||
|
||||
def annotate(self, *args, **kwargs) -> dict:
|
||||
"""
|
||||
Annotates the transcript to associate specific names with speakers.
|
||||
|
||||
Args:
|
||||
args (list): List of speaker names. These will be mapped sequentially to the speakers.
|
||||
kwargs (dict): Dictionary with speaker names as keys and list of segments as values.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with speaker names as keys and list of segments as values.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of speaker names does not match the number
|
||||
of speakers, or if an unknown speaker is found.
|
||||
"""
|
||||
|
||||
annotations = {}
|
||||
if args and len(args) != len(self.speakers):
|
||||
raise ValueError("Number of speaker names does not match number of speakers")
|
||||
|
||||
if args:
|
||||
for arg, speaker in zip(args, sorted(self.speakers)):
|
||||
|
||||
annotations[speaker] = arg
|
||||
|
||||
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
|
||||
if invalid_speakers:
|
||||
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}")
|
||||
|
||||
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
|
||||
|
||||
self.annotation = annotations
|
||||
|
||||
return self
|
||||
|
||||
def _extract_speakers(self) -> list:
|
||||
"""
|
||||
Extracts the unique speaker names from the transcript.
|
||||
|
||||
Returns:
|
||||
list: List of unique speaker names in the transcript.
|
||||
"""
|
||||
|
||||
return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
|
||||
|
||||
def _extract_segments(self) -> list:
|
||||
"""
|
||||
Extracts all the text segments from the transcript.
|
||||
|
||||
Returns:
|
||||
list: List of segments, where each segment is represented
|
||||
by the starting and ending times.
|
||||
"""
|
||||
return [self.transcript[id]["segments"] for id in self.transcript]
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Converts the transcript to a string representation.
|
||||
|
||||
Returns:
|
||||
str: String representation of the transcript, including speaker names and
|
||||
time stamps for each segment.
|
||||
"""
|
||||
fstring = ""
|
||||
|
||||
for _id in self.transcript:
|
||||
seq = self.transcript[_id]
|
||||
|
||||
if self.annotation:
|
||||
speaker = self.annotation[seq["speakers"]]
|
||||
else:
|
||||
speaker = seq["speakers"]
|
||||
|
||||
segm = seq["segments"]
|
||||
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0]))
|
||||
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
|
||||
|
||||
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
|
||||
|
||||
return fstring
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the Transcript object.
|
||||
|
||||
Returns:
|
||||
str: A string that provides an informative description of the object.
|
||||
"""
|
||||
return f"Transcript(speakers = {self.speakers},"\
|
||||
f"segments = {self.segments}, annotation = {self.annotation})"
|
||||
|
||||
def get_dict(self) -> dict:
|
||||
"""
|
||||
Get transcript as dict
|
||||
|
||||
:return: transcript as dict
|
||||
:rtype: dict
|
||||
"""
|
||||
|
||||
return self.transcript
|
||||
|
||||
def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str:
|
||||
"""
|
||||
Get transcript as json string
|
||||
:return: transcript as json string
|
||||
:rtype: str
|
||||
"""
|
||||
if "indent" not in kwargs:
|
||||
kwargs["indent"] = 3
|
||||
|
||||
if use_annotation and self.annotation:
|
||||
for _id in self.transcript:
|
||||
seq = self.transcript[_id]
|
||||
seq["speakers"] = self.annotation[seq["speakers"]]
|
||||
|
||||
return json.dumps(self.transcript, *args, **kwargs)
|
||||
|
||||
def get_html(self) -> str:
|
||||
"""
|
||||
Get transcript as html string
|
||||
|
||||
:return: transcript as html string
|
||||
:rtype: str
|
||||
"""
|
||||
html = "<p>" + self.__str__().replace("\n", "<br>") + "</p>"
|
||||
html = "<html><body>" + html + "</body></html>"
|
||||
html = html.replace("\t", " ")
|
||||
|
||||
return html
|
||||
|
||||
def get_md(self) -> str:
|
||||
"""Get transcript as Markdown string, using HTML formatting.
|
||||
|
||||
Returns:
|
||||
str: Transcript as a Markdown string.
|
||||
"""
|
||||
return self.get_html()
|
||||
|
||||
def get_tex(self) -> str:
|
||||
"""Get transcript as LaTeX string. If no annotations are present, the speakers will
|
||||
be annotated with the first letters of the alphabet.
|
||||
|
||||
Returns:
|
||||
str: Transcript as LaTeX string.
|
||||
"""
|
||||
if not self.annotation:
|
||||
|
||||
self.annotate(*ALPHABET[:len(self.speakers)])
|
||||
|
||||
fstring ="\\begin{drama}"
|
||||
|
||||
for speaker in self.speakers:
|
||||
|
||||
fstring += "\n\t\\Character{"+ str(self.annotation[speaker]) + "}" \
|
||||
"{"+ str(self.annotation[speaker]) + "}"
|
||||
|
||||
for id in self.transcript:
|
||||
seq = self.transcript[id]
|
||||
speaker = self.annotation[seq["speakers"]]
|
||||
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
|
||||
|
||||
fstring += "\n\\end{drama}"
|
||||
|
||||
return fstring
|
||||
|
||||
|
||||
def to_json(self,path, *args, **kwargs) -> None:
|
||||
"""Save transcript as json file
|
||||
|
||||
Args:
|
||||
path (str): path to save file
|
||||
"""
|
||||
with open(path, "w") as f:
|
||||
json.dump(self.transcript, f, *args, **kwargs)
|
||||
|
||||
def to_txt(self, path: str) -> None:
|
||||
"""Save transcript as a LaTeX file (placeholder function, implementation needed).
|
||||
|
||||
Args:
|
||||
path (str): Path to save the LaTeX file.
|
||||
"""
|
||||
|
||||
with open(path, "w") as f:
|
||||
f.write(self.__str__())
|
||||
|
||||
def to_md(self, path: str) -> None:
|
||||
"""Get transcript as Markdown string, using HTML formatting.
|
||||
|
||||
Returns:
|
||||
str: Transcript as a Markdown string.
|
||||
"""
|
||||
return self.to_html(path)
|
||||
|
||||
def to_html(self, path: str) -> None:
|
||||
"""
|
||||
Save transcript as html file
|
||||
|
||||
:param path: path to save file
|
||||
:type path: str
|
||||
"""
|
||||
|
||||
with open(path, "w") as file:
|
||||
file.write(self.get_html())
|
||||
|
||||
def to_tex(self, path: str) -> None:
|
||||
"""Save transcript as a LaTeX file (placeholder function, implementation needed).
|
||||
|
||||
Args:
|
||||
path (str): Path to save the LaTeX file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def to_pdf(self, path: str) -> None:
|
||||
"""Save transcript as a PDF file (placeholder function, implementation needed).
|
||||
|
||||
Args:
|
||||
path (str): Path to save the PDF file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save(self, path: str, *args, **kwargs) -> None:
|
||||
"""Save transcript to file with the given path and file format.
|
||||
|
||||
This method can save the transcript in various formats including JSON, TXT,
|
||||
MD, HTML, TEX, and PDF. The file format is determined by the extension of
|
||||
the path.
|
||||
|
||||
Args:
|
||||
path (str): Path to save the file, including the desired file extension.
|
||||
*args: Additional positional arguments to be passed to the specific save methods.
|
||||
**kwargs: Additional keyword arguments to be passed to the specific save methods.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file format specified in the path is unknown.
|
||||
"""
|
||||
|
||||
if path.endswith(".json"):
|
||||
self.to_json(path, *args, **kwargs)
|
||||
elif path.endswith(".txt"):
|
||||
self.to_txt(path, *args, **kwargs)
|
||||
elif path.endswith(".md"):
|
||||
self.to_md(path, *args, **kwargs)
|
||||
elif path.endswith(".html"):
|
||||
self.to_html(path, *args, **kwargs)
|
||||
elif path.endswith(".tex"):
|
||||
self.to_tex(path, *args, **kwargs)
|
||||
elif path.endswith(".pdf"):
|
||||
self.to_pdf(path, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unknown file format")
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json: Union[dict, str]) -> "Transcript":
|
||||
"""Load transcript from json file
|
||||
|
||||
Args:
|
||||
path (str): path to json file
|
||||
|
||||
Returns:
|
||||
Transcript: Transcript object
|
||||
"""
|
||||
if isinstance(json, dict):
|
||||
return cls(json)
|
||||
else:
|
||||
try:
|
||||
transcript = json.loads(json)
|
||||
except:
|
||||
with open(json, "r") as f:
|
||||
transcript = json.load(f)
|
||||
|
||||
return cls(transcript)
|
||||
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
import os
|
||||
import subprocess as sp
|
||||
|
||||
MAJOR = 0
|
||||
MINOR = 1
|
||||
MICRO = 0
|
||||
MICRO_POST = 0
|
||||
ISRELEASED = False
|
||||
VERSION = '%d.%d.%d.%d' % (MAJOR, MINOR, MICRO, MICRO_POST)
|
||||
|
||||
# Return the git revision as a string
|
||||
# taken from numpy/numpy
|
||||
def git_version():
|
||||
def _minimal_ext_cmd(cmd):
|
||||
# construct minimal environment
|
||||
env = {}
|
||||
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
||||
v = os.environ.get(k)
|
||||
if v is not None:
|
||||
env[k] = v
|
||||
|
||||
# LANGUAGE is used on win32
|
||||
env['LANGUAGE'] = 'C'
|
||||
env['LANG'] = 'C'
|
||||
env['LC_ALL'] = 'C'
|
||||
|
||||
out = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE, env=env).communicate()[0]
|
||||
return out
|
||||
|
||||
try:
|
||||
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
|
||||
GIT_REVISION = out.strip().decode('ascii')
|
||||
except OSError:
|
||||
GIT_REVISION = "Unknown"
|
||||
|
||||
return GIT_REVISION
|
||||
|
||||
def _get_git_version():
|
||||
cwd = os.getcwd()
|
||||
|
||||
# go to the main directory
|
||||
fdir = os.path.dirname(os.path.abspath(__file__))
|
||||
maindir = os.path.abspath(os.path.join(fdir, ".."))
|
||||
# maindir = fdir # os.path.join(fdir, "..")
|
||||
os.chdir(maindir)
|
||||
|
||||
# get git version
|
||||
res = git_version()
|
||||
|
||||
# restore the cwd
|
||||
os.chdir(cwd)
|
||||
return res
|
||||
|
||||
def get_version(build_version=False):
|
||||
if ISRELEASED:
|
||||
return VERSION
|
||||
|
||||
# unreleased version
|
||||
GIT_REVISION = _get_git_version()
|
||||
|
||||
if build_version:
|
||||
import datetime as dt
|
||||
date = dt.date.strftime(dt.datetime.now(), "%Y%m%d%H%M%S")
|
||||
return VERSION + ".dev" + date
|
||||
else:
|
||||
return VERSION + ".dev0+" + GIT_REVISION[:7]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user