removed App here
This commit is contained in:
@@ -1,7 +0,0 @@
|
|||||||
from .multi import *
|
|
||||||
from .interface import *
|
|
||||||
from .stg import *
|
|
||||||
from .interactions import *
|
|
||||||
from .global_var import *
|
|
||||||
from .utils import *
|
|
||||||
from .app import *
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
"""
|
|
||||||
Gradio App
|
|
||||||
----------
|
|
||||||
|
|
||||||
This module provides an interface to transcribe audio files using the
|
|
||||||
Scraibe model. Users can either upload an audio file or record their speech
|
|
||||||
live for transcription. The application supports multiple languages and provides
|
|
||||||
options to specify the number of speakers and the language of the audio. It also
|
|
||||||
enables efficient management of resources by loading and unloading AI models
|
|
||||||
based on usage.
|
|
||||||
|
|
||||||
The configuration is managed via a 'config.yml' file, which allows customization
|
|
||||||
of various aspects of the application, including the Gradio interface, queue
|
|
||||||
management, and model parameters.
|
|
||||||
|
|
||||||
Configuration Sections in 'config.yml':
|
|
||||||
- launch: Settings for launching the interface, such as server port, authentication, SSL configuration.
|
|
||||||
- queue: Configuration for managing request handling and concurrency.
|
|
||||||
- layout: Customization options for the interface layout, like headers, footers, and logos.
|
|
||||||
- model: Specifications for different AI models used in transcription.
|
|
||||||
- advanced: Advanced settings, including session timeout duration.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The .queue function of the Gradio interface is currently experiencing issues
|
|
||||||
and might not work as expected.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
Run this script to start the Gradio web interface for audio transcription.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
####
|
|
||||||
# Gradio Interface
|
|
||||||
####
|
|
||||||
|
|
||||||
from threading import Thread
|
|
||||||
|
|
||||||
import scraibe.app.global_var as gv
|
|
||||||
from .interface import gradio_Interface
|
|
||||||
from .multi import *
|
|
||||||
from .utils import *
|
|
||||||
|
|
||||||
|
|
||||||
def app(config : str = None, **kwargs):
|
|
||||||
"""
|
|
||||||
Launches the Gradio interface for audio transcription.
|
|
||||||
|
|
||||||
Initializes the Gradio web interface with settings from a YAML configuration file
|
|
||||||
and/or keyword arguments. The function manages AI models, handling their loading
|
|
||||||
into RAM and unloading after a session or specified timeout.
|
|
||||||
|
|
||||||
The `kwargs` are used to override or supplement values from the `config.yml` file.
|
|
||||||
They should follow the structure of `config.yml`, which includes sections like
|
|
||||||
'launch', 'queue', 'layout', 'model', and 'advanced'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (str): Path to the YAML configuration file. Default settings are used
|
|
||||||
if not provided.
|
|
||||||
**kwargs: Keyword arguments corresponding to the configuration sections. Each
|
|
||||||
argument should be a dictionary reflecting the structure of its
|
|
||||||
respective section in `config.yml`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Load and override configuration from the YAML file with kwargs
|
|
||||||
|
|
||||||
config = AppConfig.load_config(config, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
gv.MODEL_PROCESS = start_model_worker(gv.MODEL_PARAMS,
|
|
||||||
gv.REQUEST_QUEUE,
|
|
||||||
gv.LAST_ACTIVE_TIME,
|
|
||||||
gv.RESPONSE_QUEUE,
|
|
||||||
gv.LOADED_EVENT,
|
|
||||||
gv.RUNNING_EVENT)
|
|
||||||
|
|
||||||
# Set the timer thread to manage model loading and unloading
|
|
||||||
timer = Thread(target=timer_thread, args=(gv.REQUEST_QUEUE,
|
|
||||||
gv.LAST_ACTIVE_TIME,
|
|
||||||
gv.LOADED_EVENT,
|
|
||||||
gv.RUNNING_EVENT,
|
|
||||||
gv.TIMEOUT), daemon=True)
|
|
||||||
|
|
||||||
# Set the layout for the Gradio interface
|
|
||||||
layout = config.get_layout()
|
|
||||||
|
|
||||||
# start the timer thread
|
|
||||||
timer.start()
|
|
||||||
|
|
||||||
print("Starting Gradio Web Interface")
|
|
||||||
|
|
||||||
# Launch the Gradio interface
|
|
||||||
gradio_Interface(layout).queue(**config.queue).launch(**config.launch)
|
|
||||||
|
|
||||||
# Wait for the timer thread to finish
|
|
||||||
timer.join()
|
|
||||||
gv.MODEL_PROCESS.join()
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
"""Starts the Gradio interface for audio transcription with optional configuration.
|
|
||||||
|
|
||||||
This script, app_starter.py, initializes and runs a Gradio interface for audio
|
|
||||||
transcription tasks. It allows users to provide a configuration file for custom
|
|
||||||
settings. If no configuration file is specified, default settings are applied.
|
|
||||||
The script is designed to support multiprocessing for improved performance.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
args (argparse.Namespace): Parsed command line arguments.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
To run the script with custom server configuration and keyword arguments:
|
|
||||||
$ python app_starter.py --server-config path/to/config.yml --server-kwargs key1=val1 key2=val2
|
|
||||||
"""
|
|
||||||
|
|
||||||
import multiprocessing
|
|
||||||
from argparse import ArgumentParser, Action
|
|
||||||
|
|
||||||
class ParseKwargs(Action):
|
|
||||||
"""Custom action for argparse to parse keyword arguments for Gradio app configuration.
|
|
||||||
|
|
||||||
This action parses a series of keyword arguments and converts them into a
|
|
||||||
dictionary, which is then used to configure the Gradio application. It
|
|
||||||
supports dynamic types by attempting to evaluate the argument values.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
dest (str): The name of the attribute to be added to the object returned by parse_args().
|
|
||||||
"""
|
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
|
||||||
"""Parses keyword arguments and updates the namespace with these arguments as a dictionary.
|
|
||||||
|
|
||||||
For each value provided, this method splits the string on the '=' character
|
|
||||||
to separate keys and values, attempting to evaluate the values for Python
|
|
||||||
literals. If evaluation fails, the raw string is used as the value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
parser (ArgumentParser): The ArgumentParser object that called this method.
|
|
||||||
namespace (Namespace): An argparse.Namespace object that will be returned by parse_args().
|
|
||||||
values (list of str): List of strings, each representing a key-value pair in 'key=value' format.
|
|
||||||
option_string (Optional[str]): The option string that was used to invoke this action.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If any string in values does not contain the '=' character, indicating an invalid format.
|
|
||||||
"""
|
|
||||||
setattr(namespace, self.dest, dict())
|
|
||||||
for value in values:
|
|
||||||
key, value = value.split('=')
|
|
||||||
try:
|
|
||||||
value = eval(value)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
getattr(namespace, self.dest)[key] = value
|
|
||||||
|
|
||||||
parser = ArgumentParser()
|
|
||||||
|
|
||||||
parser.add_argument("--server-config", type=str, default= None,
|
|
||||||
help="Path to the configy.yml file.")
|
|
||||||
|
|
||||||
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
|
|
||||||
help='Keyword arguments for the Gradio app.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
multiprocessing.set_start_method('spawn')
|
|
||||||
|
|
||||||
from scraibe.app.app import app
|
|
||||||
|
|
||||||
app(config = args.server_config, **args.server_kwargs)
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
launch:
|
|
||||||
# The following are the default values for the launch configuration
|
|
||||||
# for more informations look at https://www.gradio.app/docs/interface
|
|
||||||
server_port: 7860
|
|
||||||
server_name: 0.0.0.0
|
|
||||||
inline: false
|
|
||||||
inbrowser: false
|
|
||||||
share : false
|
|
||||||
debug : false
|
|
||||||
max_threads: 40
|
|
||||||
quiet: false
|
|
||||||
auth:
|
|
||||||
auth_enabled: false
|
|
||||||
auth_username: admin
|
|
||||||
auth_password: admin
|
|
||||||
auth_message: null
|
|
||||||
prevent_thread_lock : false
|
|
||||||
show_error : false
|
|
||||||
show_tips : false
|
|
||||||
favicon_path : null
|
|
||||||
ssl_keyfile : null
|
|
||||||
ssl_certfile : null
|
|
||||||
ssl_keyfile_password : null
|
|
||||||
ssl_verify : false
|
|
||||||
show_api : false
|
|
||||||
allowed_paths : null
|
|
||||||
blocked_paths : null
|
|
||||||
root_path : ''
|
|
||||||
app_kwargs : null
|
|
||||||
|
|
||||||
queue:
|
|
||||||
# The following are the default values for the queue configuration
|
|
||||||
# for more informations look at hhttps://www.gradio.app/docs/interface
|
|
||||||
concurrency_count : 1
|
|
||||||
status_update_rate : 'auto'
|
|
||||||
api_open : null
|
|
||||||
max_size : null
|
|
||||||
|
|
||||||
layout:
|
|
||||||
header: scraibe/app/header.html
|
|
||||||
footer: null
|
|
||||||
logo: scraibe/app/logo.svg
|
|
||||||
model:
|
|
||||||
whisper_model : null
|
|
||||||
dia_model: null
|
|
||||||
advanced:
|
|
||||||
timeout: 300 #seconds e.g. 5 minutes
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""
|
|
||||||
global_var.py
|
|
||||||
|
|
||||||
This module stores global variables for the app.
|
|
||||||
|
|
||||||
Global variables:
|
|
||||||
REQUEST_QUEUE (multiprocessing.Queue): A queue to store audio file paths as strings.
|
|
||||||
RESPONSE_QUEUE (multiprocessing.Queue): A queue to store transcriptions as strings.
|
|
||||||
LAST_ACTIVE_TIME (multiprocessing.Value): A value to store the time of the last activity.
|
|
||||||
LOADED_EVENT (multiprocessing.Event): An event to indicate when the model is loaded.
|
|
||||||
RUNNING_EVENT (multiprocessing.Event): An event to indicate when the model is running.
|
|
||||||
MODEL_PARAMS (Optional[dict]): A dictionary to store the model parameters.
|
|
||||||
MODEL_PROCESS (Optional[multiprocessing.Process]): A process to handle the model globally.
|
|
||||||
LAST_USED (float): A float to track the time of the last user activity.
|
|
||||||
TIMEOUT (Optional[int]): An integer to store the timeout in seconds.
|
|
||||||
DEFAULT_APP_CONIFG_PATH (str): A string to store the default path to the app configuration file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import multiprocessing
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
REQUEST_QUEUE: multiprocessing.Queue = multiprocessing.Queue() # audio file path as string
|
|
||||||
RESPONSE_QUEUE: multiprocessing.Queue = multiprocessing.Queue() # transcription as string
|
|
||||||
LAST_ACTIVE_TIME: multiprocessing.Value = multiprocessing.Value('d', time.time()) # time of last activity
|
|
||||||
LOADED_EVENT: multiprocessing.Event = multiprocessing.Event() # model loaded event
|
|
||||||
RUNNING_EVENT: multiprocessing.Event = multiprocessing.Event() # model running event
|
|
||||||
|
|
||||||
MODEL_PARAMS: Optional[dict] = None # model parameters
|
|
||||||
MODEL_PROCESS: Optional[multiprocessing.Process] = None # model process to handle globally
|
|
||||||
|
|
||||||
# Global variable to track user activity
|
|
||||||
LAST_USED: float = time.time()
|
|
||||||
TIMEOUT: Optional[int] = None # seconds
|
|
||||||
|
|
||||||
DEFAULT_APP_CONIFG_PATH: str = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yml")
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
<!-- Importing Cormorant Garamond font from Google Fonts -->
|
|
||||||
<link href="https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@400;700&display=swap" rel="stylesheet">
|
|
||||||
|
|
||||||
<style>
|
|
||||||
.header-container {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: center;
|
|
||||||
position: relative;
|
|
||||||
padding-top: 30px;
|
|
||||||
}
|
|
||||||
.logo-container {
|
|
||||||
position: absolute;
|
|
||||||
top: 50%;
|
|
||||||
right: 20px;
|
|
||||||
transform: translateY(-50%);
|
|
||||||
width: 300px;
|
|
||||||
}
|
|
||||||
.logo {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
}
|
|
||||||
h1 {
|
|
||||||
font-family: 'Cormorant Garamond', serif;
|
|
||||||
font-size: 50px !important; /* Increased font size */
|
|
||||||
font-weight: bold;
|
|
||||||
color: #50AF31;
|
|
||||||
margin: 0;
|
|
||||||
position: relative;
|
|
||||||
padding: 0.5em 0;
|
|
||||||
}
|
|
||||||
h1::before, h1::after {
|
|
||||||
content: "";
|
|
||||||
position: absolute;
|
|
||||||
height: 2px;
|
|
||||||
width: 80%;
|
|
||||||
background-color: #50AF31;
|
|
||||||
left: 10%;
|
|
||||||
}
|
|
||||||
h1::before {
|
|
||||||
top: 0.5em;
|
|
||||||
}
|
|
||||||
h1::after {
|
|
||||||
bottom: 0.5em;
|
|
||||||
}
|
|
||||||
p, h2 {
|
|
||||||
font-size: 16px;
|
|
||||||
margin: 10px 0;
|
|
||||||
line-height: 1.4;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
|
|
||||||
<div class="header-container">
|
|
||||||
<h1>ScrAIbe</h1>
|
|
||||||
<div class="logo-container">
|
|
||||||
<a href="https://www.kida-bmel.de/"> <!-- Replace with your actual URL -->
|
|
||||||
<img src="/file=logo.svg" alt="KIDA Logo" class="logo">
|
|
||||||
</a>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div style="text-align: center; padding: 20px 10%;">
|
|
||||||
<p>
|
|
||||||
Upload, record, or provide a video with audio for transcription. Our toolkit is designed to transcribe content from multiple languages accurately. The integrated speaker diarisation feature identifies different speakers, ensuring a smooth transcription experience. For optimal results, indicate the number of speakers and the original language of the content.
|
|
||||||
</p>
|
|
||||||
<h2 style="font-weight: bold; color: #50AF31;">What would you like to do next?</h2>
|
|
||||||
</div>
|
|
||||||
@@ -1,153 +0,0 @@
|
|||||||
"""
|
|
||||||
This file contains ervery function that will be called when the user interacts with the
|
|
||||||
UI like pressing a button or uploading a file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
import scraibe.app.global_var as gv
|
|
||||||
from scraibe import Transcript
|
|
||||||
from .multi import start_model_worker
|
|
||||||
|
|
||||||
def select_task(choice):
|
|
||||||
# tell the app that it is still in use
|
|
||||||
if choice == 'Auto Transcribe':
|
|
||||||
|
|
||||||
return (gr.update(visible = True),
|
|
||||||
gr.update(visible = True),
|
|
||||||
gr.update(visible = True))
|
|
||||||
|
|
||||||
|
|
||||||
elif choice == 'Transcribe':
|
|
||||||
|
|
||||||
return (gr.update(visible = False),
|
|
||||||
gr.update(visible = True),
|
|
||||||
gr.update(visible = True))
|
|
||||||
|
|
||||||
|
|
||||||
elif choice == 'Diarisation':
|
|
||||||
|
|
||||||
return (gr.update(visible = True),
|
|
||||||
gr.update(visible = False),
|
|
||||||
gr.update(visible = False))
|
|
||||||
|
|
||||||
def select_origin(choice):
|
|
||||||
|
|
||||||
# tell the app that it is still in use
|
|
||||||
if choice == "Upload Audio":
|
|
||||||
|
|
||||||
return (gr.update(visible = True),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None))
|
|
||||||
|
|
||||||
elif choice == "Record Audio":
|
|
||||||
|
|
||||||
return (gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = True),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None))
|
|
||||||
|
|
||||||
elif choice == "Upload Video":
|
|
||||||
|
|
||||||
return (gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = True),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None))
|
|
||||||
|
|
||||||
elif choice == "Record Video":
|
|
||||||
|
|
||||||
return (gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = True),
|
|
||||||
gr.update(visible = False, value = None))
|
|
||||||
|
|
||||||
elif choice == "File or Files":
|
|
||||||
|
|
||||||
return (gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = False, value = None),
|
|
||||||
gr.update(visible = True))
|
|
||||||
|
|
||||||
def run_scraibe(task,
|
|
||||||
num_speakers,
|
|
||||||
translate,
|
|
||||||
language,
|
|
||||||
audio1,
|
|
||||||
audio2,
|
|
||||||
video1,
|
|
||||||
video2,
|
|
||||||
file_in,
|
|
||||||
progress = gr.Progress(track_tqdm=False)):
|
|
||||||
|
|
||||||
# get *args which are not None
|
|
||||||
if gv.MODEL_PROCESS is None or not gv.MODEL_PROCESS.is_alive():
|
|
||||||
#progress(0.0, desc='Loading model...')
|
|
||||||
gv.MODEL_PROCESS = start_model_worker(gv.MODEL_PARAMS,
|
|
||||||
gv.REQUEST_QUEUE,
|
|
||||||
gv.LAST_ACTIVE_TIME,
|
|
||||||
gv.RESPONSE_QUEUE,
|
|
||||||
gv.LOADED_EVENT,
|
|
||||||
gv.RUNNING_EVENT)
|
|
||||||
|
|
||||||
# progress(0.1, desc='Starting task...')
|
|
||||||
source = audio1 or audio2 or video1 or video2 or file_in
|
|
||||||
|
|
||||||
if isinstance(source, list):
|
|
||||||
source = [s.name for s in source]
|
|
||||||
if len(source) == 1:
|
|
||||||
source = source[0]
|
|
||||||
|
|
||||||
config = dict(source = source,
|
|
||||||
task = task,
|
|
||||||
num_speakers = num_speakers,
|
|
||||||
translate = translate,
|
|
||||||
language = language)
|
|
||||||
|
|
||||||
gv.REQUEST_QUEUE.put(config)
|
|
||||||
|
|
||||||
if task == 'Auto Transcribe':
|
|
||||||
|
|
||||||
out_str , out_json = gv.RESPONSE_QUEUE.get()
|
|
||||||
|
|
||||||
if isinstance(source, str):
|
|
||||||
return (gr.update(value = out_str, visible = True),
|
|
||||||
gr.update(value = out_json, visible = True),
|
|
||||||
gr.update(visible = True),
|
|
||||||
gr.update(visible = True))
|
|
||||||
else:
|
|
||||||
return (gr.update(value = out_str, visible = True),
|
|
||||||
gr.update(value = out_json, visible = True),
|
|
||||||
gr.update(visible = False),
|
|
||||||
gr.update(visible = False))
|
|
||||||
|
|
||||||
elif task == 'Transcribe':
|
|
||||||
|
|
||||||
out = gv.RESPONSE_QUEUE.get()
|
|
||||||
|
|
||||||
return (gr.update(value = out, visible = True),
|
|
||||||
gr.update(value = None, visible = False),
|
|
||||||
gr.update(visible = False),
|
|
||||||
gr.update(visible = False))
|
|
||||||
|
|
||||||
elif task == 'Diarisation':
|
|
||||||
|
|
||||||
out = gv.RESPONSE_QUEUE.get()
|
|
||||||
|
|
||||||
return (gr.update(value = None, visible = False),
|
|
||||||
gr.update(value = out, visible = True),
|
|
||||||
gr.update(visible = False),
|
|
||||||
gr.update(visible = False))
|
|
||||||
|
|
||||||
def annotate_output(annoation : str, out_json : dict):
|
|
||||||
# get *args which are not None
|
|
||||||
|
|
||||||
trans = Transcript.from_json(out_json)
|
|
||||||
trans = trans.annotate(*annoation.split(","))
|
|
||||||
|
|
||||||
return gr.update(value = str(trans)),gr.update(value = trans.get_json())
|
|
||||||
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
This module contains the gradio Interface which is used to interact with the user.
|
|
||||||
|
|
||||||
The interface is themed with a soft color scheme, with primary colors of green and orange, and a neutral color of gray.
|
|
||||||
|
|
||||||
A list of languages is also defined in this module, which may be used elsewhere in the application.
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
Soft: A class from the gradio library used to theme the interface.
|
|
||||||
|
|
||||||
Variables:
|
|
||||||
theme (gr.themes.Soft): The theme for the gradio interface.
|
|
||||||
LANGUAGES (list of str): A list of languages supported by the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
from .interactions import *
|
|
||||||
from .stg import *
|
|
||||||
|
|
||||||
theme = gr.themes.Soft(
|
|
||||||
primary_hue="green",
|
|
||||||
secondary_hue='orange',
|
|
||||||
neutral_hue="gray",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LANGUAGES = [
|
|
||||||
"Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian",
|
|
||||||
"Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian",
|
|
||||||
"Czech", "Danish", "Dutch", "English", "Estonian",
|
|
||||||
"Finnish", "French", "Galician", "German", "Greek",
|
|
||||||
"Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian",
|
|
||||||
"Italian", "Japanese", "Kannada", "Kazakh", "Korean",
|
|
||||||
"Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi",
|
|
||||||
"Maori", "Nepali", "Norwegian", "Persian", "Polish",
|
|
||||||
"Portuguese", "Romanian", "Russian", "Serbian", "Slovak",
|
|
||||||
"Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog",
|
|
||||||
"Tamil", "Thai", "Turkish", "Ukrainian", "Urdu",
|
|
||||||
"Vietnamese", "Welsh"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def gradio_Interface(layout = None,):
|
|
||||||
"""
|
|
||||||
Creates a gradio interface for audio transcription.
|
|
||||||
|
|
||||||
The interface includes options for the user to select the task, number of speakers, translation, language, and input type.
|
|
||||||
It also provides options for the user to upload or record audio/video, or upload files.
|
|
||||||
The output of the transcription is displayed in a textbox, and the JSON output in a JSON viewer.
|
|
||||||
The user can also annotate the output by naming the speakers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layout (dict, optional): A dictionary containing layout information. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
gr.Blocks: A gradio Blocks object representing the interface.
|
|
||||||
"""
|
|
||||||
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
|
|
||||||
|
|
||||||
# Define components
|
|
||||||
|
|
||||||
|
|
||||||
if layout.get('header') is not None:
|
|
||||||
gr.HTML(layout.get('header'), visible= True, show_label=False)
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
|
|
||||||
task = gr.Radio(["Auto Transcribe", "Transcribe", "Diarisation"], label="Task",
|
|
||||||
value= 'Auto Transcribe')
|
|
||||||
|
|
||||||
num_speakers = gr.Number(value=0, label= "Number of speakers (optional)",
|
|
||||||
info = "Number of speakers in the audio file. If you don't know,\
|
|
||||||
leave it at 0.", visible= True)
|
|
||||||
|
|
||||||
translate = gr.Checkbox(label="Translation", choices=[True, False], value = False,
|
|
||||||
info="Select 'Yes' to have the output translated into English.",
|
|
||||||
visible= True)
|
|
||||||
|
|
||||||
language = gr.Dropdown(LANGUAGES,
|
|
||||||
label="Language (optional)", value = "None",
|
|
||||||
info="Language of the audio file. If you don't know,\
|
|
||||||
leave it at None.", visible= True)
|
|
||||||
|
|
||||||
input = gr.Radio(["Upload Audio", "Record Audio", "Upload Video","Record Video"
|
|
||||||
,"File or Files"], label="Input Type", value="Upload Audio")
|
|
||||||
|
|
||||||
audio1 = gr.Audio(source="upload", type="filepath", label="Upload Audio",
|
|
||||||
interactive= True, visible= True)
|
|
||||||
audio2 = gr.Audio(source="microphone", label="Record Audio", type="filepath",
|
|
||||||
interactive= True, visible= False)
|
|
||||||
video1 = gr.Video(source="upload", type="filepath", label="Upload Video",
|
|
||||||
interactive= True, visible= False)
|
|
||||||
video2 = gr.Video(source="webcam", label="Record Video", type="filepath",include_audio= True,
|
|
||||||
interactive= True, visible= False)
|
|
||||||
file_in = gr.Files(label="Upload File or Files", interactive= True, visible= False)
|
|
||||||
|
|
||||||
submit = gr.Button()
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
|
|
||||||
out_txt = gr.Textbox(label="Output",
|
|
||||||
visible= True, show_copy_button=True)
|
|
||||||
|
|
||||||
out_json = gr.JSON(label="JSON Output",
|
|
||||||
visible= False, show_copy_button=True)
|
|
||||||
|
|
||||||
annoation = gr.Textbox(label="Name your speaker's",
|
|
||||||
info= "Please provide a list of the speakers arranged \
|
|
||||||
in the order in which they appear in the input. Use comma ',' \
|
|
||||||
as a seperator. Be aware that the first name is given \
|
|
||||||
to SPEAKER_00 the second to SPEAKER_01 and so on.",
|
|
||||||
visible= False, interactive= True)
|
|
||||||
|
|
||||||
annotate = gr.Button(value="Annotate", visible= False, interactive= True)
|
|
||||||
|
|
||||||
if layout.get('footer') is not None:
|
|
||||||
gr.HTML(layout.get('footer'), visible= True, show_label=False)
|
|
||||||
|
|
||||||
# Define usage of components
|
|
||||||
input.change(fn=select_origin, inputs=[input],
|
|
||||||
outputs=[audio1, audio2, video1, video2, file_in])
|
|
||||||
|
|
||||||
task.change(fn=select_task, inputs=[task],
|
|
||||||
outputs=[num_speakers, translate, language])
|
|
||||||
|
|
||||||
translate.change(fn= lambda x : gr.update(value = x),
|
|
||||||
inputs=[translate], outputs=[translate])
|
|
||||||
num_speakers.change(fn= lambda x : gr.update(value = x),
|
|
||||||
inputs=[num_speakers], outputs=[num_speakers])
|
|
||||||
language.change(fn= lambda x : gr.update(value = x),
|
|
||||||
inputs=[language], outputs=[language])
|
|
||||||
|
|
||||||
submit.click(fn = run_scraibe,
|
|
||||||
inputs=[task, num_speakers, translate, language, audio1,
|
|
||||||
audio2, video1, video2, file_in],
|
|
||||||
outputs=[out_txt, out_json, annoation, annotate])
|
|
||||||
|
|
||||||
annotate.click(fn = annotate_output, inputs=[annoation, out_json],
|
|
||||||
outputs=[out_txt, out_json])
|
|
||||||
|
|
||||||
return demo
|
|
||||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 29 KiB |
@@ -1,151 +0,0 @@
|
|||||||
"""
|
|
||||||
This module contains functions for managing and optimizing the resource usage of the application.
|
|
||||||
|
|
||||||
The functions in this module monitor the application's usage and make adjustments to improve efficiency.
|
|
||||||
This includes managing the loading and unloading of the model based on the application's activity.
|
|
||||||
This dynamic management of resources helps to ensure that the application uses only the resources it needs,
|
|
||||||
improving overall performance and reducing unnecessary resource consumption.
|
|
||||||
|
|
||||||
Functions:
|
|
||||||
clear_queue(queue): Clears all items from the queue.
|
|
||||||
model_worker(model_params, request_queue, last_active_time,
|
|
||||||
response_queue, loaded_event, running_event, *args, **kwargs): Manages the model worker process.
|
|
||||||
|
|
||||||
Modules:
|
|
||||||
time: Provides various time-related functions.
|
|
||||||
gc: Provides an interface to the garbage collector.
|
|
||||||
multiprocessing: Provides support for parallel execution of code.
|
|
||||||
torch: Provides tensor computation and deep learning functionality.
|
|
||||||
gradio: Provides a simple way to create interactive UIs for Python functions.
|
|
||||||
scraibe.autotranscript: Provides automatic transcription functionality.
|
|
||||||
.stg: Contains the GradioTranscriptionInterface class.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
|
||||||
import gc
|
|
||||||
from typing import Union, Any
|
|
||||||
import multiprocessing
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from gradio import Warning
|
|
||||||
from scraibe.autotranscript import Scraibe
|
|
||||||
from .stg import GradioTranscriptionInterface
|
|
||||||
|
|
||||||
def clear_queue(queue):
|
|
||||||
while not queue.empty():
|
|
||||||
try:
|
|
||||||
queue.get_nowait()
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
def model_worker(model_params : Union[Scraibe, dict],
|
|
||||||
request_queue: multiprocessing.Queue,
|
|
||||||
last_active_time: multiprocessing.Value,
|
|
||||||
response_queue: multiprocessing.Queue,
|
|
||||||
loaded_event: multiprocessing.Event,
|
|
||||||
running_event: multiprocessing.Event,
|
|
||||||
*args: Any, **kwargs: Any) -> None:
|
|
||||||
"""
|
|
||||||
Manages the model worker process.
|
|
||||||
|
|
||||||
The model worker process is responsible for running the model and returning the results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_params (Union[Scraibe, dict]): The parameters for the Scraibe model.
|
|
||||||
request_queue (multiprocessing.Queue): The queue for incoming requests.
|
|
||||||
last_active_time (multiprocessing.Value): The last time the model was active.
|
|
||||||
response_queue (multiprocessing.Queue): The queue for outgoing responses.
|
|
||||||
loaded_event (multiprocessing.Event): An event that signals when the model is loaded.
|
|
||||||
running_event (multiprocessing.Event): An event that signals when the model is running.
|
|
||||||
*args: Additional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
loaded_event.set()
|
|
||||||
|
|
||||||
if model_params is None:
|
|
||||||
_model = Scraibe()
|
|
||||||
elif type(model_params) is Scraibe:
|
|
||||||
_model = model_params
|
|
||||||
elif type(model_params) is dict:
|
|
||||||
_model = Scraibe(**model_params)
|
|
||||||
else:
|
|
||||||
raise TypeError("model must be of type Scraibe, or dict")
|
|
||||||
|
|
||||||
model = GradioTranscriptionInterface(_model)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
|
|
||||||
req = request_queue.get()
|
|
||||||
|
|
||||||
if req == "STOP":
|
|
||||||
|
|
||||||
break
|
|
||||||
elif type(req) is dict:
|
|
||||||
runner = model.get_task_from_str(req.pop("task"))
|
|
||||||
running_event.set()
|
|
||||||
transcription = runner(**req)
|
|
||||||
running_event.clear()
|
|
||||||
response_queue.put(transcription)
|
|
||||||
last_active_time.value = time.time()
|
|
||||||
else:
|
|
||||||
raise TypeError("request must be of type dict")
|
|
||||||
|
|
||||||
del model
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
clear_queue(request_queue)
|
|
||||||
clear_queue(response_queue)
|
|
||||||
loaded_event.clear()
|
|
||||||
|
|
||||||
def start_model_worker(model_params: Union[Scraibe, dict],
|
|
||||||
request_queue: multiprocessing.Queue,
|
|
||||||
last_active_time: multiprocessing.Value,
|
|
||||||
response_queue: multiprocessing.Queue,
|
|
||||||
loaded_event: multiprocessing.Event,
|
|
||||||
running_event: multiprocessing.Event,
|
|
||||||
*args: Any, **kwargs: Any) -> multiprocessing.Process:
|
|
||||||
"""
|
|
||||||
Starts the model worker process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_params (Union[Scraibe, dict]): The parameters for the Scraibe model.
|
|
||||||
request_queue (multiprocessing.Queue): The queue for incoming requests.
|
|
||||||
last_active_time (multiprocessing.Value): The last time the model was active.
|
|
||||||
response_queue (multiprocessing.Queue): The queue for outgoing responses.
|
|
||||||
loaded_event (multiprocessing.Event): An event that signals when the model is loaded.
|
|
||||||
running_event (multiprocessing.Event): An event that signals when the model is running.
|
|
||||||
*args: Additional arguments.
|
|
||||||
**kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
multiprocessing.Process: The model worker process.
|
|
||||||
"""
|
|
||||||
context = multiprocessing.get_context('spawn')
|
|
||||||
model_process = context.Process(target=model_worker, args=(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args), kwargs=kwargs)
|
|
||||||
model_process.start()
|
|
||||||
return model_process
|
|
||||||
|
|
||||||
def timer_thread(request_queue: multiprocessing.Queue,
|
|
||||||
last_active_time: multiprocessing.Value,
|
|
||||||
loaded_event: multiprocessing.Event,
|
|
||||||
running_event: multiprocessing.Event,
|
|
||||||
timeout: int) -> None:
|
|
||||||
"""
|
|
||||||
Monitors the model worker process and stops it after a period of inactivity.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_queue (multiprocessing.Queue): The queue for incoming requests.
|
|
||||||
last_active_time (multiprocessing.Value): The last time the model was active.
|
|
||||||
loaded_event (multiprocessing.Event): An event that signals when the model is loaded.
|
|
||||||
running_event (multiprocessing.Event): An event that signals when the model is running.
|
|
||||||
timeout (int): The period of inactivity after which the model worker process is stopped.
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
time.sleep(timeout)
|
|
||||||
|
|
||||||
if time.time() - last_active_time.value > timeout and loaded_event.is_set() and not running_event.is_set():
|
|
||||||
print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True)
|
|
||||||
request_queue.put("STOP")
|
|
||||||
Warning("Model worker stopped due to inactivity.")
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
"""
|
|
||||||
stg - Scraibe to Gradio Interface
|
|
||||||
|
|
||||||
This module provides an interface between the Scraibe transcription system and the Gradio user interface.
|
|
||||||
It defines a class, GradioTranscriptionInterface, that wraps the Scraibe model and provides methods for performing transcription tasks through the Gradio UI.
|
|
||||||
|
|
||||||
Modules:
|
|
||||||
json: Used for encoding and decoding JSON data.
|
|
||||||
gradio as gr: Used for creating the Gradio UI.
|
|
||||||
tqdm: Used for displaying progress bars.
|
|
||||||
scraibe.app.global_var as gv: Contains global variables for the Scraibe app.
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
import gradio as gr
|
|
||||||
from tqdm import tqdm
|
|
||||||
from typing import Any, Dict, Union, Tuple, List
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GradioTranscriptionInterface:
|
|
||||||
"""
|
|
||||||
A class that provides an interface between the Gradio UI and the Scraibe transcription system.
|
|
||||||
|
|
||||||
This class wraps the Scraibe model and provides methods for performing transcription tasks through the Gradio UI.
|
|
||||||
These tasks include auto transcription, transcription, and diarisation.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model (Scraibe): The Scraibe model for performing transcription tasks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model) -> None:
|
|
||||||
"""
|
|
||||||
Initializes the GradioTranscriptionInterface with a Scraibe model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (Scraibe): The Scraibe model for performing transcription tasks.
|
|
||||||
*args (Any): Additional positional arguments.
|
|
||||||
**kwargs (Dict[str, Any]): Additional keyword arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
def autotranscribe(self, source: Union[str, List[str]],
|
|
||||||
num_speakers: int,
|
|
||||||
translate: bool,
|
|
||||||
language: str,
|
|
||||||
*args: Any, **kwargs: Dict[str, Any]) -> Tuple[str, Union[str, dict]]:
|
|
||||||
"""
|
|
||||||
Performs auto transcription on the given source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source (Union[str, List[str]]): The source to transcribe. This can be a string representing a single source,
|
|
||||||
or a list of strings representing multiple sources.
|
|
||||||
num_speakers (int): The number of speakers in the source.
|
|
||||||
translate (bool): Whether to translate the transcription.
|
|
||||||
language (str): The language of the source.
|
|
||||||
*args (Any): Additional positional arguments.
|
|
||||||
**kwargs (Dict[str, Any]): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, Union[str, dict]]: A tuple containing the transcribed text (str) and the JSON output (str or dict).
|
|
||||||
"""
|
|
||||||
|
|
||||||
_kwargs = {
|
|
||||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
|
||||||
"language": language if language != "None" else None,
|
|
||||||
"task": 'translate' if translate else None
|
|
||||||
}
|
|
||||||
if isinstance(source, str):
|
|
||||||
try:
|
|
||||||
result = self.model.autotranscribe(source, **_kwargs)
|
|
||||||
except ValueError:
|
|
||||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
|
||||||
Please try again!")
|
|
||||||
|
|
||||||
return str(result), result.get_json()
|
|
||||||
|
|
||||||
elif isinstance(source, list):
|
|
||||||
source_names = [s.split("/")[-1] for s in source]
|
|
||||||
result = []
|
|
||||||
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
|
||||||
try:
|
|
||||||
res = self.model.autotranscribe(s, **_kwargs)
|
|
||||||
except ValueError:
|
|
||||||
_name = s.split("/")[-1]
|
|
||||||
res = f"NO TRANSCRIPT FOUND FOR {_name}"
|
|
||||||
gr.Warning(f"Couldn't detect any speech in {_name} will skip this file.")
|
|
||||||
result.append(res)
|
|
||||||
|
|
||||||
out = ''
|
|
||||||
out_dict = {}
|
|
||||||
for i, r in enumerate(result):
|
|
||||||
out += f"TRANSCRIPT FOR {source_names[i]}:\n\n"
|
|
||||||
out += str(r)
|
|
||||||
out += "\n\n"
|
|
||||||
|
|
||||||
if isinstance(r, str):
|
|
||||||
out_dict[source_names[i]] = r
|
|
||||||
else:
|
|
||||||
out_dict[source_names[i]] = r.get_dict()
|
|
||||||
|
|
||||||
return out, json.dumps(out_dict, indent=4)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise gr.Error("Please provide a valid audio file.")
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe(self, source: Union[str, List[str]],
|
|
||||||
translate: bool,
|
|
||||||
language: str,
|
|
||||||
*args: Any, **kwargs: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
Performs transcription on the given source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source (Union[str, List[str]]): The source to transcribe.
|
|
||||||
This can be a string representing a single source, or a list of strings representing multiple sources.
|
|
||||||
translate (bool): Whether to translate the transcription.
|
|
||||||
language (str): The language of the source.
|
|
||||||
*args (Any): Additional positional arguments.
|
|
||||||
**kwargs (Dict[str, Any]): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The transcribed text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_kwargs = {
|
|
||||||
"language": language if language != "None" else None,
|
|
||||||
"task": 'translate' if translate == "Yes" else None
|
|
||||||
}
|
|
||||||
|
|
||||||
if isinstance(source, str):
|
|
||||||
result = self.model.transcribe(source, **_kwargs)
|
|
||||||
|
|
||||||
return str(result)
|
|
||||||
|
|
||||||
elif isinstance(source, list):
|
|
||||||
source_names = [s.split("/")[-1] for s in source]
|
|
||||||
result = []
|
|
||||||
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
|
||||||
res = self.model.transcribe(s, **_kwargs)
|
|
||||||
result.append(res)
|
|
||||||
|
|
||||||
out = ''
|
|
||||||
for i, res in enumerate(result):
|
|
||||||
out += f"TRANSCRIPT FOR {source_names[i]}:\n\n"
|
|
||||||
out += str(res)
|
|
||||||
out += "\n\n"
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise gr.Error("Please provide a valid audio file.")
|
|
||||||
|
|
||||||
def diarisation(self, source: Union[str, List[str]],
|
|
||||||
num_speakers: int,
|
|
||||||
*args: Any, **kwargs: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
Performs diarisation on the given source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source (Union[str, List[str]]): The source to perform diarisation on.
|
|
||||||
This can be a string representing a single source,
|
|
||||||
or a list of strings representing multiple sources.
|
|
||||||
num_speakers (int): The number of speakers in the source.
|
|
||||||
*args (Any): Additional positional arguments.
|
|
||||||
**kwargs (Dict[str, Any]): Additional keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The JSON output of the diarisation result.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
_kwargs = {
|
|
||||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
if isinstance(source, str):
|
|
||||||
try:
|
|
||||||
result = self.model.diarization(source, **_kwargs)
|
|
||||||
except ValueError:
|
|
||||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
|
||||||
Please try again!")
|
|
||||||
|
|
||||||
return json.dumps(result, indent=2)
|
|
||||||
elif isinstance(source, list):
|
|
||||||
source_names = [s.split("/")[-1] for s in source]
|
|
||||||
result = []
|
|
||||||
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
|
|
||||||
try:
|
|
||||||
res = self.model.diarization(s, **_kwargs)
|
|
||||||
except ValueError:
|
|
||||||
|
|
||||||
res = f"NO DIARISATION FOUND FOR {s}"
|
|
||||||
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
|
|
||||||
result.append(res)
|
|
||||||
|
|
||||||
out = {}
|
|
||||||
|
|
||||||
for i, res in enumerate(result):
|
|
||||||
out[source_names[i]] = res
|
|
||||||
|
|
||||||
return json.dumps(out, indent=4)
|
|
||||||
|
|
||||||
else:
|
|
||||||
gr.Error("Please provide a valid audio file.")
|
|
||||||
|
|
||||||
def get_task_from_str(self, task: str) -> callable:
|
|
||||||
"""
|
|
||||||
Returns the corresponding task function based on the given task string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task (str): The task string. This can be one of the following: 'Auto Transcribe', 'Transcribe', 'Diarisation'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
callable: The corresponding task function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if task == 'Auto Transcribe':
|
|
||||||
return self.autotranscribe
|
|
||||||
elif task == 'Transcribe':
|
|
||||||
return self.transcribe
|
|
||||||
elif task == 'Diarisation':
|
|
||||||
return self.diarisation
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid task string.")
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,298 +0,0 @@
|
|||||||
"""
|
|
||||||
utils.py
|
|
||||||
|
|
||||||
This module contains two classes, ConfigLoader and AppConfig, which are used to manage application-specific configuration settings.
|
|
||||||
|
|
||||||
The ConfigLoader class provides methods for loading a configuration file, applying overrides, and restoring default values for specified keys. It also includes methods for recursively updating nested keys and getting the default configuration.
|
|
||||||
|
|
||||||
The AppConfig class extends ConfigLoader and provides additional methods for setting global variables, launch options, and layout options from the configuration. It also includes methods for checking and setting file paths, and getting layout options.
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
ConfigLoader: Manages application-specific configuration settings.
|
|
||||||
AppConfig: Extends ConfigLoader to provide additional methods for managing application-specific configuration settings.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
import yaml
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import scraibe.app.global_var as gv
|
|
||||||
|
|
||||||
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
|
|
||||||
class ConfigLoader:
|
|
||||||
"""A class that extends ConfigLoader to manage application-specific configuration settings.
|
|
||||||
|
|
||||||
This class provides methods for setting global variables, launch options, and layout options from the configuration.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
config (Dict[str, Any]): The current configuration settings.
|
|
||||||
launch (Dict[str, Any]): The launch configuration settings.
|
|
||||||
model (Dict[str, Any]): The model configuration settings.
|
|
||||||
advanced (Dict[str, Any]): The advanced configuration settings.
|
|
||||||
queue (Dict[str, Any]): The queue configuration settings.
|
|
||||||
layout (Dict[str, Any]): The layout configuration settings.
|
|
||||||
"""
|
|
||||||
def __init__(self, config: Dict[str, Any]):
|
|
||||||
"""Initializes a new instance of the ConfigLoader class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): The configuration dictionary.
|
|
||||||
"""
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def restore_defaults_for_keys(self, *args: str):
|
|
||||||
"""Restores specified keys to their default values, including nested keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args (str): A list of keys or paths to keys (for nested dictionaries) to restore to default values.
|
|
||||||
Each key or path should be a list of keys leading to the desired key.
|
|
||||||
"""
|
|
||||||
default_config = self.get_default_config()
|
|
||||||
|
|
||||||
for key in args:
|
|
||||||
self.apply_overrides(self.config, default_config, key)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_config(cls, yaml_path: Optional[str] = None, **kwargs: Any) -> 'ConfigLoader':
|
|
||||||
"""Load the configuration file and apply overrides.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
yaml_path (str, optional): Path to the YAML file containing overrides.
|
|
||||||
**kwargs: Additional overrides as keyword arguments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConfigLoader: A ConfigLoader object with the loaded configuration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Load the original configuration
|
|
||||||
config = cls.get_default_config()
|
|
||||||
|
|
||||||
# Override with another YAML file if provided
|
|
||||||
|
|
||||||
if yaml_path:
|
|
||||||
with open(yaml_path, 'r') as file:
|
|
||||||
override_config = yaml.safe_load(file)
|
|
||||||
cls.apply_overrides(config, override_config)
|
|
||||||
|
|
||||||
# Apply overrides from kwargs
|
|
||||||
cls.apply_overrides(config, kwargs)
|
|
||||||
return cls(config)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def apply_overrides(orig_dict: Dict[str, Any], override_dict: Dict[str, Any], specific: Optional[str] = None):
|
|
||||||
"""Recursively apply overrides to the configuration, only for specific keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
orig_dict (Dict[str, Any]): The original dictionary.
|
|
||||||
override_dict (Dict[str, Any]): The override dictionary.
|
|
||||||
specific (str, optional): The specific key to override.
|
|
||||||
"""
|
|
||||||
for key, value in override_dict.items():
|
|
||||||
|
|
||||||
if isinstance(value, dict):
|
|
||||||
# If the value is a dict, apply recursively
|
|
||||||
sub_dict = orig_dict.get(key, {})
|
|
||||||
ConfigLoader.apply_overrides(sub_dict, value, specific)
|
|
||||||
orig_dict[key] = sub_dict
|
|
||||||
else:
|
|
||||||
# Apply override for this key
|
|
||||||
if specific is None:
|
|
||||||
# If no specific keys are provided, update the key
|
|
||||||
# If the value is not a dict, search for the key and update
|
|
||||||
if ConfigLoader.update_nested_key(orig_dict, key, value):
|
|
||||||
continue # Key was found and updated
|
|
||||||
orig_dict[key] = value # Key not found, update at this level
|
|
||||||
|
|
||||||
elif key in specific:
|
|
||||||
# If specific keys are provided, only update if the key is in the list
|
|
||||||
if ConfigLoader.update_nested_key(orig_dict, specific, value):
|
|
||||||
continue # Key was found and updated
|
|
||||||
orig_dict[specific] = value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def update_nested_key(d, key, value):
|
|
||||||
"""Recursively search and update the key in nested dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
d (Dict[str, Any]): The dictionary.
|
|
||||||
key (str): The key to update.
|
|
||||||
value (Any): The new value.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the key was found and updated, False otherwise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if key in d:
|
|
||||||
d[key] = value
|
|
||||||
return True
|
|
||||||
for k, v in d.items():
|
|
||||||
if isinstance(v, dict) and ConfigLoader.update_nested_key(v, key, value):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_default_config():
|
|
||||||
"""Return the default configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, Any]: The default configuration.
|
|
||||||
"""
|
|
||||||
with open(gv.DEFAULT_APP_CONIFG_PATH , 'r') as file:
|
|
||||||
config = yaml.safe_load(file)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
class AppConfig(ConfigLoader):
|
|
||||||
"""A class that extends ConfigLoader to manage application-specific configuration settings.
|
|
||||||
|
|
||||||
This class provides methods for setting global variables, launch options, and layout options from the configuration.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
config (dict): The current configuration settings.
|
|
||||||
launch (dict): The launch configuration settings.
|
|
||||||
model (dict): The model configuration settings.
|
|
||||||
advanced (dict): The advanced configuration settings.
|
|
||||||
queue (dict): The queue configuration settings.
|
|
||||||
layout (dict): The layout configuration settings.
|
|
||||||
"""
|
|
||||||
def __init__(self, config : Dict[str, Any]):
|
|
||||||
"""Initializes a new instance of the AppConfig class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): The configuration dictionary.
|
|
||||||
"""
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self.set_global_vars_from_config()
|
|
||||||
self.set_launch_options()
|
|
||||||
self.set_layout_options()
|
|
||||||
|
|
||||||
self.launch = self.config.get("launch")
|
|
||||||
self.model = self.config.get("model")
|
|
||||||
self.advanced = self.config.get("advanced")
|
|
||||||
self.queue = self.config.get("queue")
|
|
||||||
self.layout = self.config.get("layout")
|
|
||||||
|
|
||||||
def set_global_vars_from_config(self) -> None:
|
|
||||||
"""Sets the global variables from a configuration dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): A dictionary containing the parameters for the model. Modify the default parameters in the config.yml file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
gv.MODEL_PARAMS = self.config.get('model')
|
|
||||||
gv.TIMEOUT = self.config.get("advanced").get('timeout')
|
|
||||||
|
|
||||||
def set_launch_options(self) -> None:
|
|
||||||
"""Sets the launch options from a configuration dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
None
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
launch_options = self.config.get("launch")
|
|
||||||
|
|
||||||
if launch_options.get('auth').pop('auth_enabled'):
|
|
||||||
self.config['launch']['auth'] = (launch_options.get('auth').pop('auth_username'),
|
|
||||||
launch_options.get('auth').pop('auth_password'))
|
|
||||||
else:
|
|
||||||
self.config['launch']['auth'] = None
|
|
||||||
|
|
||||||
def set_layout_options(self) -> None:
|
|
||||||
"""Sets the layout options from a configuration dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
None
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
self.config['layout']['header'] = self.check_and_set_path(self.config['layout'], 'header')
|
|
||||||
self.config['layout']['footer'] = self.check_and_set_path(self.config['layout'], 'footer')
|
|
||||||
self.config['layout']['logo'] = self.check_and_set_path(self.config['layout'], 'logo')
|
|
||||||
|
|
||||||
def get_layout(self) -> Dict[str, str]:
|
|
||||||
"""Gets the layout options from a configuration dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
None
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing the header and footer layout options.
|
|
||||||
"""
|
|
||||||
if not os.path.exists(self.config['layout']['header']) and \
|
|
||||||
self.config['layout']['header'] == "scraibe/app/header.html":
|
|
||||||
|
|
||||||
hname = os.path.join(CURRENT_PATH, "header.html")
|
|
||||||
|
|
||||||
header = open(hname).read()
|
|
||||||
|
|
||||||
elif not os.path.exists(self.config['layout']['header']) and self.config['layout']['header'] != "scraibe/app/header.html":
|
|
||||||
warnings.warn(f"Header file not found: {self.config['layout']['header']} \n" \
|
|
||||||
"fall back to default.")
|
|
||||||
|
|
||||||
hname = os.path.join(CURRENT_PATH, "header.html")
|
|
||||||
|
|
||||||
header = open(hname).read()
|
|
||||||
elif os.path.exists(self.config['layout']['header']):
|
|
||||||
header = open(self.config['layout']['header']).read()
|
|
||||||
else:
|
|
||||||
warnings.warn(f"Header file not found: {self.config['layout']['header']}")
|
|
||||||
header = None
|
|
||||||
|
|
||||||
|
|
||||||
if header != None:
|
|
||||||
if self.config['layout']['logo'] == "scraibe/app/logo.svg":
|
|
||||||
header = header.replace("/file=logo.svg", f"/file={os.path.join(CURRENT_PATH, 'logo.svg')}")
|
|
||||||
elif self.config['layout']['logo'] != "scraibe/app/logo.svg":
|
|
||||||
header = header.replace("/file=logo.svg", f"/file={self.config['layout']['logo']}")
|
|
||||||
else:
|
|
||||||
warnings.warn(f"Logo file not found: {self.config['layout']['logo']}")
|
|
||||||
|
|
||||||
|
|
||||||
if self.config['layout']['footer'] != None:
|
|
||||||
if os.path.exists(self.config['layout']['footer']):
|
|
||||||
footer = open(self.config['layout']['footer']).read()
|
|
||||||
elif self.config['layout']['footer'] == None:
|
|
||||||
footer = None
|
|
||||||
else:
|
|
||||||
warnings.warn(f"Footer file not found: {self.config['layout']['footer']}")
|
|
||||||
else:
|
|
||||||
footer = None
|
|
||||||
return {'header' : header ,
|
|
||||||
'footer' : footer}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_and_set_path(config_item: dict, key: str) -> Optional[str]:
|
|
||||||
"""Check if the file exists at the given path. If not, try with CURRENT_PATH.
|
|
||||||
Raise FileNotFoundError if the file still doesn't exist.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_item (dict): The configuration item.
|
|
||||||
key (str): The key to check in the configuration item.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The path to the file if it exists, None otherwise.
|
|
||||||
"""
|
|
||||||
_current_path = os.path.dirname(os.path.realpath(__file__)) # Define your CURRENT_PATH
|
|
||||||
|
|
||||||
file_path = config_item.get(key)
|
|
||||||
if file_path is None:
|
|
||||||
return None
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
new_path = os.path.join(_current_path, file_path)
|
|
||||||
if not os.path.exists(new_path):
|
|
||||||
warnings.warn(f"{key.capitalize()} file not found: {config_item[key]} \n" \
|
|
||||||
"fall back to default.")
|
|
||||||
else:
|
|
||||||
config_item[key] = new_path
|
|
||||||
|
|
||||||
return config_item[key]
|
|
||||||
Reference in New Issue
Block a user