Merge pull request #37 from JSchmie/develop_gradio_app

Develop gradio app
This commit is contained in:
Jacob Schmieder
2024-02-12 12:48:41 +01:00
committed by GitHub
13 changed files with 149 additions and 919 deletions
+6
View File
@@ -0,0 +1,6 @@
scraibe/*__pycache__
scraibe/app/*__pycache__
scraibe/.pyannotetoken
.git
.gitignore
.github
+6
View File
@@ -0,0 +1,6 @@
transcibe.py
scraibe/*__pycache__
scraibe/app/*__pycache__
scraibe/.pyannotetoken
-3
View File
@@ -7,9 +7,6 @@ from .diarisation import *
from .version import get_version as _get_version
from .misc import *
from .app.gradio_app import *
from .app.qtfaststart import *
from .cli import *
__version__ = _get_version()
-2
View File
@@ -1,2 +0,0 @@
from .qtfaststart import *
from .gradio_app import *
-441
View File
@@ -1,441 +0,0 @@
"""
Gradio Audio Transcription App.
--------------------------------
This module provides an interface to transcribe audio files using the
Scraibe model. Users can either upload an audio file or record their speech
live for transcription. The application supports multiple languages and provides
options to specify the number of speakers and the language of the audio.
Attributes:
LANGUAGES (list): A list of supported languages for transcription.
Usage:
Run this script to start the Gradio web interface for audio transcription.
"""
"""
Gradio Audio Transcription App.
--------------------------------
This module provides an interface to transcribe audio files using the
Scraibe model. Users can either upload an audio file or record their speech
live for transcription. The application supports multiple languages and provides
options to specify the number of speakers and the language of the audio.
Attributes:
LANGUAGES (list): A list of supported languages for transcription.
Usage:
Run this script to start the Gradio web interface for audio transcription.
"""
import json
import os
import gradio as gr
from tqdm import tqdm
from scraibe import Scraibe, Transcript
theme = gr.themes.Soft(
primary_hue="green",
secondary_hue='orange',
neutral_hue="gray",
)
LANGUAGES = [
"Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian",
"Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian",
"Czech", "Danish", "Dutch", "English", "Estonian",
"Finnish", "French", "Galician", "German", "Greek",
"Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian",
"Italian", "Japanese", "Kannada", "Kazakh", "Korean",
"Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi",
"Maori", "Nepali", "Norwegian", "Persian", "Polish",
"Portuguese", "Romanian", "Russian", "Serbian", "Slovak",
"Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog",
"Tamil", "Thai", "Turkish", "Ukrainian", "Urdu",
"Vietnamese", "Welsh"
]
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
class GradioTranscriptionInterface:
"""
Interface handling the interaction between Gradio UI and the Audio Transcription system.
"""
def __init__(self, model: Scraibe):
"""
Initializes the GradioTranscriptionInterface with a transcription model.
Args:
model (Scraibe): Model responsible for audio transcription tasks.
"""
self.model = model
def auto_transcribe(self, source,
num_speakers : int,
translation : bool,
language : str):
"""
Shortcut method for the Scraibe task.
Returns:
tuple: Transcribed text (str), JSON output (dict)
"""
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None,
"task": 'translate' if translation else None
}
if isinstance(source, str):
try:
result = self.model.autotranscribe(source, **kwargs)
except ValueError:
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
return str(result), result.get_json()
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
try:
res = self.model.autotranscribe(s, **kwargs)
except ValueError:
_name = s.split("/")[-1]
res = f"NO TRANSCRIPT FOUND FOR {_name}"
gr.Warning(f"Couldn't detect any speech in {_name} will skip this file.")
result.append(res)
out = ''
out_dict = {}
for i, r in enumerate(result):
out += f"TRANSCRIPT FOR {source_names[i]}:\n\n"
out += str(r)
out += "\n\n"
if isinstance(r, str):
out_dict[source_names[i]] = r
else:
out_dict[source_names[i]] = r.get_dict()
return out, json.dumps(out_dict, indent=4)
else:
raise gr.Error("Please provide a valid audio file.")
def transcribe(self, source, translation, language):
"""
Shortcut method for the Transcribe task.
Returns:
str: Transcribed text.
"""
kwargs = {
"language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None
}
if isinstance(source, str):
result = self.model.transcribe(source, **kwargs)
return str(result)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
res = self.model.transcribe(s, **kwargs)
result.append(res)
out = ''
for i, res in enumerate(result):
out += f"TRANSCRIPT FOR {source_names[i]}:\n\n"
out += str(res)
out += "\n\n"
return out
else:
raise gr.Error("Please provide a valid audio file.")
def perform_diarisation(self, source, num_speakers):
"""
Shortcut method for the Diarisation task.
Returns:
str: JSON output of diarisation result.
"""
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
}
if isinstance(source, str):
try:
result = self.model.diarization(source, **kwargs)
except ValueError:
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
return json.dumps(result, indent=2)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
try:
res = self.model.diarization(s, **kwargs)
except ValueError:
res = f"NO DIARISATION FOUND FOR {s}"
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
result.append(res)
out = {}
for i, res in enumerate(result):
out[source_names[i]] = res
return json.dumps(out, indent=4)
else:
gr.Error("Please provide a valid audio file.")
####
# Gradio Interface
####
def gradio_Interface(model : Scraibe = None):
if model is None:
model = Scraibe()
pipe = GradioTranscriptionInterface(model)
def select_task(choice):
if choice == 'Auto Transcribe':
return (gr.update(visible = True),
gr.update(visible = True),
gr.update(visible = True))
elif choice == 'Transcribe':
return (gr.update(visible = False),
gr.update(visible = True),
gr.update(visible = True))
elif choice == 'Diarisation':
return (gr.update(visible = True),
gr.update(visible = False),
gr.update(visible = False))
def select_origin(choice):
if choice == "Upload Audio":
return (gr.update(visible = True),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None))
elif choice == "Record Audio":
return (gr.update(visible = False, value = None),
gr.update(visible = True),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None))
elif choice == "Upload Video":
return (gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = True),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None))
elif choice == "Record Video":
return (gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = True),
gr.update(visible = False, value = None))
elif choice == "File or Files":
return (gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = False, value = None),
gr.update(visible = True))
def run_scribe(task,
num_speakers,
translate,
language,
audio1,
audio2,
video1,
video2,
file_in,
progress = gr.Progress(track_tqdm= True)):
# get *args which are not None
progress(0, desc='Starting task...')
source = audio1 or audio2 or video1 or video2 or file_in
if isinstance(source, list):
source = [s.name for s in source]
if len(source) == 1:
source = source[0]
if task == 'Auto Transcribe':
out_str , out_json = pipe.auto_transcribe(source = source,
num_speakers = num_speakers,
translation = translate,
language = language)
if isinstance(source, str):
return (gr.update(value = out_str, visible = True),
gr.update(value = out_json, visible = True),
gr.update(visible = True),
gr.update(visible = True))
else:
return (gr.update(value = out_str, visible = True),
gr.update(value = out_json, visible = True),
gr.update(visible = False),
gr.update(visible = False))
elif task == 'Transcribe':
out = pipe.transcribe(source = source,
translation = translate,
language = language)
return (gr.update(value = out, visible = True),
gr.update(value = None, visible = False),
gr.update(visible = False),
gr.update(visible = False))
elif task == 'Diarisation':
out = pipe.perform_diarisation(source = source,
num_speakers = num_speakers)
return (gr.update(value = None, visible = False),
gr.update(value = out, visible = True),
gr.update(visible = False),
gr.update(visible = False))
def annotate_output(annoation : str, out_json : dict):
# get *args which are not None
trans = Transcript.from_json(out_json)
trans = trans.annotate(*annoation.split(","))
return gr.update(value = str(trans)),gr.update(value = trans.get_json())
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
# Define components
hname = os.path.join(CURRENT_PATH, "header.html")
header = open(hname, "r").read()
# ugly hack to get the logo to work
header = header.replace("/file=logo.svg", f"/file={CURRENT_PATH}/logo.svg" )
gr.HTML(header, visible= True, show_label=False)
with gr.Row():
with gr.Column():
task = gr.Radio(["Auto Transcribe", "Transcribe", "Diarisation"], label="Task",
value= 'Auto Transcribe')
num_speakers = gr.Number(value=0, label= "Number of speakers (optional)",
info = "Number of speakers in the audio file. If you don't know,\
leave it at 0.", visible= True)
translate = gr.Checkbox(label="Translation", choices=[True, False], value = False,
info="Select 'Yes' to have the output translated into English.",
visible= True)
language = gr.Dropdown(LANGUAGES,
label="Language (optional)", value = "None",
info="Language of the audio file. If you don't know,\
leave it at None.", visible= True)
input = gr.Radio(["Upload Audio", "Record Audio", "Upload Video","Record Video"
,"File or Files"], label="Input Type", value="Upload Audio")
audio1 = gr.Audio(source="upload", type="filepath", label="Upload Audio",
interactive= True, visible= True)
audio2 = gr.Audio(source="microphone", label="Record Audio", type="filepath",
interactive= True, visible= False)
video1 = gr.Video(source="upload", type="filepath", label="Upload Video",
interactive= True, visible= False)
video2 = gr.Video(source="webcam", label="Record Video", type="filepath",include_audio= True,
interactive= True, visible= False)
file_in = gr.Files(label="Upload File or Files", interactive= True, visible= False)
submit = gr.Button()
with gr.Column():
out_txt = gr.Textbox(label="Output",
visible= True, show_copy_button=True)
out_json = gr.JSON(label="JSON Output",
visible= False, show_copy_button=True)
annoation = gr.Textbox(label="Name your speaker's",
info= "Please provide a list of the speakers arranged \
in the order in which they appear in the input. Use comma ',' \
as a seperator. Be aware that the first name is given \
to SPEAKER_00 the second to SPEAKER_01 and so on.",
visible= False, interactive= True)
annotate = gr.Button(value="Annotate", visible= False, interactive= True)
# Define usage of components
input.change(fn=select_origin, inputs=[input],
outputs=[audio1, audio2, video1, video2, file_in])
task.change(fn=select_task, inputs=[task],
outputs=[num_speakers, translate, language])
translate.change(fn= lambda x : gr.update(value = x),
inputs=[translate], outputs=[translate])
num_speakers.change(fn= lambda x : gr.update(value = x),
inputs=[num_speakers], outputs=[num_speakers])
language.change(fn= lambda x : gr.update(value = x),
inputs=[language], outputs=[language])
submit.click(fn = run_scribe,
inputs=[task, num_speakers, translate, language, audio1,
audio2, video1, video2, file_in],
outputs=[out_txt, out_json, annoation, annotate])
annotate.click(fn = annotate_output, inputs=[annoation, out_json],
outputs=[out_txt, out_json])
return demo
if __name__ == "__main__":
gradio_Interface().queue().launch()
-66
View File
@@ -1,66 +0,0 @@
<!-- Importing Cormorant Garamond font from Google Fonts -->
<link href="https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@400;700&display=swap" rel="stylesheet">
<style>
.header-container {
display: flex;
align-items: center;
justify-content: center;
position: relative;
padding-top: 30px;
}
.logo-container {
position: absolute;
top: 50%;
right: 20px;
transform: translateY(-50%);
width: 300px;
}
.logo {
width: 100%;
height: auto;
}
h1 {
font-family: 'Cormorant Garamond', serif;
font-size: 50px !important; /* Increased font size */
font-weight: bold;
color: #50AF31;
margin: 0;
position: relative;
padding: 0.5em 0;
}
h1::before, h1::after {
content: "";
position: absolute;
height: 2px;
width: 80%;
background-color: #50AF31;
left: 10%;
}
h1::before {
top: 0.5em;
}
h1::after {
bottom: 0.5em;
}
p, h2 {
font-size: 16px;
margin: 10px 0;
line-height: 1.4;
}
</style>
<div class="header-container">
<h1>ScrAIbe</h1>
<div class="logo-container">
<a href="https://www.kida-bmel.de/"> <!-- Replace with your actual URL -->
<img src="/file=logo.svg" alt="KIDA Logo" class="logo">
</a>
</div>
</div>
<div style="text-align: center; padding: 20px 10%;">
<p>
Upload, record, or provide a video with audio for transcription. Our toolkit is designed to transcribe content from multiple languages accurately. The integrated speaker diarisation feature identifies different speakers, ensuring a smooth transcription experience. For optimal results, indicate the number of speakers and the original language of the content.
</p>
<h2 style="font-weight: bold; color: #50AF31;">What would you like to do next?</h2>
</div>
File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 29 KiB

-319
View File
@@ -1,319 +0,0 @@
"""
This file contains a modified version of qtfaststart by qtfaststart
https://github.com/danielgtaylor/qtfaststart/tree/master
All credit goes to the original author.
Copyright (C) 2008 - 2013 Daniel G. Taylor <dan@programmer-art.org>
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
"""
import logging
import os
import struct
import collections
import io
# define error classes
class FastStartException(Exception):
"""
Raised when something bad happens during processing.
"""
pass
class FastStartSetupError(FastStartException):
"""
Rasised when asked to process a file that does not need processing
"""
pass
class MalformedFileError(FastStartException):
"""
Raised when the input file is setup in an unexpected way
"""
pass
class UnsupportedFormatError(FastStartException):
"""
Raised when a movie file is recognized as a format not supported.
"""
pass
# define constants
CHUNK_SIZE = 8192
log = logging.getLogger("qtfaststart")
# Older versions of Python require this to be defined
if not hasattr(os, 'SEEK_CUR'):
os.SEEK_CUR = 1
Atom = collections.namedtuple('Atom', 'name position size')
def read_atom(datastream):
"""
Read an atom and return a tuple of (size, type) where size is the size
in bytes (including the 8 bytes already read) and type is a "fourcc"
like "ftyp" or "moov".
"""
size, type = struct.unpack(">L4s", datastream.read(8))
type = type.decode('ascii')
return size, type
def _read_atom_ex(datastream):
"""
Read an Atom from datastream
"""
pos = datastream.tell()
atom_size, atom_type = read_atom(datastream)
if atom_size == 1:
atom_size, = struct.unpack(">Q", datastream.read(8))
return Atom(atom_type, pos, atom_size)
def get_index(datastream):
"""
Return an index of top level atoms, their absolute byte-position in the
file and their size in a list:
index = [
("ftyp", 0, 24),
("moov", 25, 2658),
("free", 2683, 8),
...
]
The tuple elements will be in the order that they appear in the file.
"""
log.debug("Getting index of top level atoms...")
index = list(_read_atoms(datastream))
_ensure_valid_index(index)
return index
def _read_atoms(datastream):
"""
Read atoms until an error occurs
"""
while datastream:
try:
atom = _read_atom_ex(datastream)
log.debug("%s: %s" % (atom.name, atom.size))
except:
break
yield atom
if atom.size == 0:
if atom.name == "mdat":
# Some files may end in mdat with no size set, which generally
# means to seek to the end of the file. We can just stop indexing
# as no more entries will be found!
break
else:
# Weird, but just continue to try to find more atoms
continue
datastream.seek(atom.position + atom.size)
def _ensure_valid_index(index):
"""
Ensure the minimum viable atoms are present in the index.
Raise FastStartException if not.
"""
top_level_atoms = set([item.name for item in index])
for key in ["moov", "mdat"]:
if key not in top_level_atoms:
log.error("%s atom not found, is this a valid MOV/MP4 file?" % key)
raise FastStartException()
def find_atoms(size, datastream):
"""
Compatibilty interface for _find_atoms_ex
"""
fake_parent = Atom('fake', datastream.tell()-8, size+8)
for atom in _find_atoms_ex(fake_parent, datastream):
yield atom.name
def _find_atoms_ex(parent_atom, datastream):
"""
Yield either "stco" or "co64" Atoms from datastream.
datastream will be 8 bytes into the stco or co64 atom when the value
is yielded.
It is assumed that datastream will be at the end of the atom after
the value has been yielded and processed.
parent_atom is the parent atom, a 'moov' or other ancestor of CO
atoms in the datastream.
"""
stop = parent_atom.position + parent_atom.size
while datastream.tell() < stop:
try:
atom = _read_atom_ex(datastream)
except:
log.exception("Error reading next atom!")
raise FastStartException()
if atom.name in ["trak", "mdia", "minf", "stbl"]:
# Known ancestor atom of stco or co64, search within it!
for res in _find_atoms_ex(atom, datastream):
yield res
elif atom.name in ["stco", "co64"]:
yield atom
else:
# Ignore this atom, seek to the end of it.
datastream.seek(atom.position + atom.size)
def process(infilename, limit=float('inf')):
"""
Convert a Quicktime/MP4 file for streaming by moving the metadata to
the front of the file. This method writes a new file.
If limit is set to something other than zero it will be used as the
number of bytes to write of the atoms following the moov atom. This
is very useful to create a small sample of a file with full headers,
which can then be used in bug reports and such.
"""
if isinstance(infilename, str):
datastream = open(infilename, "rb")
elif isinstance(infilename, bytes):
datastream = io.BytesIO(infilename)
else:
raise TypeError("infilename must be a filename, bytes or file-like object")
# Get the top level atom index
index = get_index(datastream)
mdat_pos = 999999
free_size = 0
# Make sure moov occurs AFTER mdat, otherwise no need to run!
for atom in index:
# The atoms are guaranteed to exist from get_index above!
if atom.name == "moov":
moov_atom = atom
moov_pos = atom.position
elif atom.name == "mdat":
mdat_pos = atom.position
elif atom.name == "free" and atom.position < mdat_pos:
# This free atom is before the mdat!
free_size += atom.size
log.info("Removing free atom at %d (%d bytes)" % (atom.position, atom.size))
elif atom.name == "\x00\x00\x00\x00" and atom.position < mdat_pos:
# This is some strange zero atom with incorrect size
free_size += 8
log.info("Removing strange zero atom at %s (8 bytes)" % atom.position)
# Offset to shift positions
offset = moov_atom.size - free_size
if moov_pos < mdat_pos:
# moov appears to be in the proper place, don't shift by moov size
offset -= moov_atom.size
if not free_size:
# No free atoms and moov is correct, we are done!
log.error("This file appears to already be setup for streaming!")
# Stupid hack to retrun the non-processed file:
if isinstance(infilename, str):
return open(infilename, "rb").read()
elif isinstance(infilename, bytes):
return io.BytesIO(infilename).read()
# Read and fix moov
moov = _patch_moov(datastream, moov_atom, offset)
log.info("Writing output...")
outfile = b''
# Write ftype
for atom in index:
if atom.name == "ftyp":
log.debug("Writing ftyp... (%d bytes)" % atom.size)
datastream.seek(atom.position)
outfile += datastream.read(atom.size)
# Write moov
_bytes = moov.getvalue()
log.debug("Writing moov... (%d bytes)" % len(_bytes))
outfile += _bytes
# Write the rest
atoms = [item for item in index if item.name not in ["ftyp", "moov", "free"]]
for atom in atoms:
log.debug("Writing %s... (%d bytes)" % (atom.name, atom.size))
datastream.seek(atom.position)
# for compatability, allow '0' to mean no limit
cur_limit = limit or float('inf')
cur_limit = min(cur_limit, atom.size)
for chunk in get_chunks(datastream, CHUNK_SIZE, cur_limit):
outfile += chunk
return outfile
def _patch_moov(datastream, atom, offset):
datastream.seek(atom.position)
moov = io.BytesIO(datastream.read(atom.size))
# reload the atom from the fixed stream
atom = _read_atom_ex(moov)
for atom in _find_atoms_ex(atom, moov):
# Read either 32-bit or 64-bit offsets
ctype, csize = dict(
stco=('L', 4),
co64=('Q', 8),
)[atom.name]
# Get number of entries
version, entry_count = struct.unpack(">2L", moov.read(8))
log.info("Patching %s with %d entries" % (atom.name, entry_count))
entries_pos = moov.tell()
struct_fmt = ">%(entry_count)s%(ctype)s" % vars()
# Read entries
entries = struct.unpack(struct_fmt, moov.read(csize * entry_count))
# Patch and write entries
offset_entries = [entry + offset for entry in entries]
moov.seek(entries_pos)
moov.write(struct.pack(struct_fmt, *offset_entries))
return moov
def get_chunks(stream, chunk_size, limit):
remaining = limit
while remaining:
chunk = stream.read(min(remaining, chunk_size))
if not chunk:
return
remaining -= len(chunk)
yield chunk
+14
View File
@@ -75,6 +75,11 @@ class Scraibe:
Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper
and pyannote diarization models.
e.g.:
- verbose: If True, the class will print additional information.
- save_kwargs: If True, the keyword arguments will be saved
for autotranscribe. So you can unload the class and reload it again.
"""
@@ -98,6 +103,15 @@ class Scraibe:
else:
self.verbose = False
# Save kwargs for autotranscribe if you want to unload the class and load it again.
if kwargs.get('save_setup'):
self.params = dict(whisper_model = whisper_model,
dia_model = dia_model,
**kwargs)
else:
self.params = {}
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
remove_original : bool = False,
**kwargs) -> Transcript:
+70 -51
View File
@@ -9,7 +9,8 @@ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import json
from .autotranscript import Scraibe
from .app.gradio_app import gradio_Interface
from .misc import ParseKwargs
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
from torch.cuda import is_available
@@ -41,13 +42,15 @@ def cli():
help="List of audio files to transcribe.")
group.add_argument('--start-server', action='store_true',
help='Start the Gradio app.')
help='Start the Gradio app.' \
'If set, all other arguments are ignored' \
'besides --server-config or --server-kwargs.')
parser.add_argument("--port", type=int, default= None,
help="Port to run the Gradio app on. Defaults to 7860.")
parser.add_argument("--server-config", type=str, default= None,
help="Path to the configy.yml file.")
parser.add_argument("--server-name", type=str, default= None,
help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.")
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
help='Keyword arguments for the Gradio app.')
parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.")
@@ -66,7 +69,8 @@ def cli():
help="Device to use for PyTorch inference.")
parser.add_argument("--num-threads", type=int, default=0,
help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
help="Number of threads used by torch for CPU inference; '\
'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")
parser.add_argument("--output-directory", "-o", type=str, default=".",
help="Directory to save the transcription outputs.")
@@ -113,55 +117,70 @@ def cli():
if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
model = Scraibe(**class_kwargs)
if not start_server:
model = Scraibe(**class_kwargs)
if arg_dict["audio_files"]:
audio_files = arg_dict.pop("audio_files")
if task == "autotranscribe" or task == "autotranscribe+translate":
for audio in audio_files:
if task == "autotranscribe+translate":
task = "translate"
else:
task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
elif task == "diarization":
for audio in audio_files:
if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.")
out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f)
elif task == "transcribe" or task == "translate":
for audio in audio_files:
out = model.transcribe(audio, task = task,
language= arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f:
f.write(out)
if arg_dict["audio_files"]:
audio_files = arg_dict.pop("audio_files")
else: # unfinished code
raise NotImplementedError("Currently not Working")
import subprocess
import sys
if task == "autotranscribe" or task == "autotranscribe+translate":
for audio in audio_files:
if task == "autotranscribe+translate":
task = "translate"
else:
task = "transcribe"
execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py")
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
elif task == "diarization":
for audio in audio_files:
if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.")
out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f)
elif task == "transcribe" or task == "translate":
for audio in audio_files:
out = model.transcribe(audio, task = task,
language= arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f:
f.write(out)
if start_server: # unfinished code
gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name)
config = arg_dict.pop("server_config")
server_kwargs = arg_dict.pop("server_kwargs")
if not config:
subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"])
elif not server_kwargs:
subprocess.run([sys.executable, execute_path, f"--server-config={config}"])
elif not config and not server_kwargs:
subprocess.run([sys.executable, execute_path])
else:
subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
if __name__ == "__main__":
cli()
+37
View File
@@ -27,7 +27,9 @@ Usage:
diarisation_output = model.diarization("path/to/audiofile.wav")
"""
import warnings
import os
import yaml
from pathlib import Path
from typing import TypeVar, Union
@@ -216,6 +218,41 @@ class Diariser:
if not os.path.exists(model) and use_auth_token is None:
use_auth_token = cls._get_token()
elif os.path.exists(model) and not use_auth_token:
# check if model can be found locally nearby the config file
with open(model, 'r') as file:
config = yaml.safe_load(file)
path_to_model = config['pipeline']['params']['segmentation']
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. "\
"Trying to find it nearby the config file.")
pwd = model.split("/")[:-1]
pwd = "/".join(pwd)
path_to_model = os.path.join(pwd, "pytorch_model.bin")
if not os.path.exists(path_to_model):
warnings.warn(f"Model not found at {path_to_model}. \
'Trying to find it nearby .bin files instead.")
# list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0])
else:
warnings.warn("Found more than one .bin file. "\
"or none. Please specify the path to the model " \
"or setup a huggingface token.")
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
config['pipeline']['params']['segmentation'] = path_to_model
with open(model, 'w') as file:
yaml.dump(config, file)
_model = Pipeline.from_pretrained(model,
use_auth_token = use_auth_token,
cache_dir = cache_dir,
+15
View File
@@ -1,6 +1,7 @@
import os
import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action
CACHE_DIR = os.getenv(
"AUTOT_CACHE",
@@ -40,3 +41,17 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
with open(file_path, "w") as stream:
yaml.dump(yml, stream)
class ParseKwargs(Action):
"""
Custom argparse action to parse keyword arguments.
"""
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, dict())
for value in values:
key, value = value.split('=')
try:
value = eval(value)
except:
pass
getattr(namespace, self.dest)[key] = value
+3 -2
View File
@@ -1,4 +1,3 @@
from calendar import c
import pkg_resources
import os
from setuptools import setup, find_packages
@@ -21,6 +20,8 @@ with open(verfile, "r") as fp:
build_version = "SCRAIBE_BUILD" in os.environ
version["ISRELEASED"] = True if "ISRELEASED" in os.environ else False
if __name__ == "__main__":
setup(
@@ -53,7 +54,7 @@ if __name__ == "__main__":
keywords = ['transcription', 'speech recognition', 'whisper', 'pyannote', 'audio', 'ScrAIbe', 'scraibe',
'speech-to-text', 'speech-to-text transcription', 'speech-to-text recognition',
'voice-to-speech'],
package_data={'scraibe.app' : ["*.html", "*.svg"]},
package_data={'scraibe.app' : ["*.html", "*.svg","*.yml"]},
entry_points={'console_scripts':
['scraibe = scraibe.cli:cli']}