Make everything work in processes and adding config to customize instance
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
launch:
|
||||
# The following are the default values for the launch configuration
|
||||
# for more information, see https://www.gradio.app/docs/interface
|
||||
server_port: 8080
|
||||
server_name: 0.0.0.0
|
||||
inbrowser: true
|
||||
inline: false
|
||||
max_threads: 40  # underscore spelling, consistent with the other launch keys (gradio's launch() keyword is max_threads)
|
||||
quiet: false
|
||||
auth:
|
||||
enabled: false
|
||||
username: admin
|
||||
password: admin
|
||||
auth_message: "Please enter your credentials"
|
||||
show_error : false
|
||||
favicon_path : null
|
||||
ssl_keyfile : null
|
||||
ssl_certfile : null
|
||||
ssl_keyfile_password : null
|
||||
ssl_verify : false
|
||||
quiet : false
|
||||
show_api : false
|
||||
allowed_paths : null
|
||||
blocked_paths : null
|
||||
root_path : null
|
||||
app_kwargs : null
|
||||
state_session_capacity : 1000
|
||||
share_server_address : null
|
||||
share_server_protocol : null
|
||||
share : false
|
||||
debug : false
|
||||
queue:
|
||||
# The following are the default values for the queue configuration
|
||||
# for more information, see https://www.gradio.app/docs/interface
|
||||
status_update_rate : 'auto'
|
||||
api_open : null
|
||||
max_size : null
|
||||
concurrency_count : null
|
||||
default_concurrency_limit : 'not_set'
|
||||
layout:
|
||||
header: scraibe/app/header.html
|
||||
footer: null
|
||||
logo: scraibe/app/logo.svg
|
||||
model:
|
||||
whisper_model : null
|
||||
dia_model: null
|
||||
advanced:
|
||||
timeout: 300 # seconds, i.e. 5 minutes
|
||||
@@ -3,16 +3,22 @@ Stores global variables for the app.
|
||||
"""
|
||||
|
||||
# Global variable to store the model
|
||||
from threading import Event
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import yaml
|
||||
|
||||
REQUEST_QUEUE = multiprocessing.Queue() # audio file path as string
|
||||
RESPONSE_QUEUE = multiprocessing.Queue() # transcription as string
|
||||
LAST_ACTIVE_TIME = multiprocessing.Value('d', time.time()) # time of last activity
|
||||
LOADED_EVENT = multiprocessing.Event() # model loaded event
|
||||
RUNNING_EVENT = multiprocessing.Event() # model running event
|
||||
|
||||
MODEL = None
|
||||
MODEL_THREAD_PARAMS = None
|
||||
MODEL_THREAD = None
|
||||
MODEL_PARAMS = None # model parameters
|
||||
MODEL_PROCESS = None # model process to handle globally
|
||||
|
||||
# Global variable to track user activity
|
||||
LAST_USED = time.time()
|
||||
TIMEOUT = 30 #seconds
|
||||
TRANSCRIBE_ACTIVE = Event()
|
||||
TIMEOUT = None #seconds
|
||||
|
||||
DEFAULT_APP_CONIFG_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yml")
|
||||
|
||||
+23
-22
@@ -3,12 +3,12 @@ This file contains every function that will be called when the user interacts w
|
||||
UI like pressing a button or uploading a file.
|
||||
"""
|
||||
|
||||
from re import M
|
||||
import time
|
||||
import gradio as gr
|
||||
import scraibe.app.global_var as gv
|
||||
from scraibe import Transcript
|
||||
from scraibe.app.stg import GradioTranscriptionInterface
|
||||
import threading
|
||||
from .multi import start_model_worker
|
||||
|
||||
def select_task(choice):
|
||||
# tell the app that it is still in use
|
||||
@@ -84,20 +84,19 @@ def run_scraibe(task,
|
||||
video1,
|
||||
video2,
|
||||
file_in,
|
||||
progress = gr.Progress(track_tqdm= True)):
|
||||
progress = gr.Progress(track_tqdm=False)):
|
||||
|
||||
# get *args which are not None
|
||||
if gv.MODEL_PROCESS is None or not gv.MODEL_PROCESS.is_alive():
|
||||
#progress(0.0, desc='Loading model...')
|
||||
gv.MODEL_PROCESS = start_model_worker(gv.MODEL_PARAMS,
|
||||
gv.REQUEST_QUEUE,
|
||||
gv.LAST_ACTIVE_TIME,
|
||||
gv.RESPONSE_QUEUE,
|
||||
gv.LOADED_EVENT,
|
||||
gv.RUNNING_EVENT)
|
||||
|
||||
if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None:
|
||||
progress(0, desc='Model was not loaded to conserve resources. Loading model...')
|
||||
time.sleep(1)
|
||||
gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS)
|
||||
gv.MODEL_THREAD.start()
|
||||
gv.MODEL_THREAD.join()
|
||||
|
||||
pipe = GradioTranscriptionInterface()
|
||||
|
||||
progress(0.1, desc='Starting task...')
|
||||
# progress(0.1, desc='Starting task...')
|
||||
source = audio1 or audio2 or video1 or video2 or file_in
|
||||
|
||||
if isinstance(source, list):
|
||||
@@ -105,12 +104,17 @@ def run_scraibe(task,
|
||||
if len(source) == 1:
|
||||
source = source[0]
|
||||
|
||||
config = dict(source = source,
|
||||
task = task,
|
||||
num_speakers = num_speakers,
|
||||
translate = translate,
|
||||
language = language)
|
||||
|
||||
gv.REQUEST_QUEUE.put(config)
|
||||
|
||||
if task == 'Auto Transcribe':
|
||||
|
||||
out_str , out_json = pipe.auto_transcribe(source = source,
|
||||
num_speakers = num_speakers,
|
||||
translation = translate,
|
||||
language = language)
|
||||
out_str , out_json = gv.RESPONSE_QUEUE.get()
|
||||
|
||||
if isinstance(source, str):
|
||||
return (gr.update(value = out_str, visible = True),
|
||||
@@ -125,9 +129,7 @@ def run_scraibe(task,
|
||||
|
||||
elif task == 'Transcribe':
|
||||
|
||||
out = pipe.transcribe(source = source,
|
||||
translation = translate,
|
||||
language = language)
|
||||
out = gv.RESPONSE_QUEUE.get()
|
||||
|
||||
return (gr.update(value = out, visible = True),
|
||||
gr.update(value = None, visible = False),
|
||||
@@ -136,8 +138,7 @@ def run_scraibe(task,
|
||||
|
||||
elif task == 'Diarisation':
|
||||
|
||||
out = pipe.perform_diarisation(source = source,
|
||||
num_speakers = num_speakers)
|
||||
out = gv.RESPONSE_QUEUE.get()
|
||||
|
||||
return (gr.update(value = None, visible = False),
|
||||
gr.update(value = out, visible = True),
|
||||
|
||||
+64
-19
@@ -4,41 +4,86 @@ Therefore, the app is to be more efficient in the usage of the resources.
|
||||
By for example, unloading or reloading the model.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
import time
|
||||
import gc
|
||||
from typing import Union
|
||||
import multiprocessing
|
||||
import torch
|
||||
import signal
|
||||
|
||||
import scraibe.app.global_var as gv
|
||||
from gradio import Warning
|
||||
from scraibe.autotranscript import Scraibe
|
||||
from .stg import GradioTranscriptionInterface
|
||||
|
||||
def init_worker():
    """
    Install the "ignore" handler for SIGINT in the calling process.

    After this call a keyboard interrupt (Ctrl+C) delivered to the process
    is discarded instead of raising KeyboardInterrupt.
    """
    ignore_handler = signal.SIG_IGN
    signal.signal(signal.SIGINT, ignore_handler)
|
||||
|
||||
|
||||
def load_model_thread(model : Union[Scraibe, dict] = None):
    """
    Store a Scraibe model in the global ``gv.MODEL`` slot.

    Args:
        model: ``None`` to build a ``Scraibe`` with its defaults, an existing
            ``Scraibe`` instance to store as-is, or a ``dict`` of keyword
            arguments forwarded to ``Scraibe(**model)``. Any other type
            leaves ``gv.MODEL`` untouched.
    """
    if model is None:
        gv.MODEL = Scraibe()
        return

    # Exact type matches only (subclasses are deliberately not matched).
    model_type = type(model)
    if model_type is Scraibe:
        gv.MODEL = model
        return

    if model_type is dict:
        gv.MODEL = Scraibe(**model)
|
||||
def clear_queue(queue):
    """
    Discard every item that can currently be fetched from ``queue``.

    Args:
        queue: Queue-like object providing ``empty()`` and ``get_nowait()``,
            where ``get_nowait()`` raises the standard ``queue.Empty`` when
            nothing can be fetched at that moment.
    """
    # The parameter shadows the stdlib ``queue`` module, so ``queue.Empty``
    # would be an attribute lookup on the queue *object* and fail with
    # AttributeError. Import the exception class under its own name instead.
    from queue import Empty

    while not queue.empty():
        try:
            queue.get_nowait()
        except Empty:
            # ``empty()`` and ``get_nowait()`` can disagree briefly; re-check.
            continue
|
||||
|
||||
def model_worker(model_params : Union[Scraibe, dict],
|
||||
request_queue,
|
||||
last_active_time,
|
||||
response_queue,
|
||||
loaded_event,
|
||||
running_event,
|
||||
*args, **kwargs):
|
||||
|
||||
loaded_event.set()
|
||||
|
||||
if model_params is None:
|
||||
_model = Scraibe()
|
||||
elif type(model_params) is Scraibe:
|
||||
_model = model_params
|
||||
elif type(model_params) is dict:
|
||||
_model = Scraibe(**model_params)
|
||||
else:
|
||||
raise TypeError("model must be of type Scraibe, or dict")
|
||||
|
||||
gv.LAST_USED = time.time()
|
||||
model = GradioTranscriptionInterface(_model)
|
||||
|
||||
# Create a thread to monitor user activity
|
||||
def delete_unused_model():
|
||||
while True:
|
||||
|
||||
_unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None)
|
||||
req = request_queue.get()
|
||||
|
||||
if _unload_porperty:
|
||||
if req == "STOP":
|
||||
|
||||
del gv.MODEL
|
||||
gv.MODEL = None
|
||||
break
|
||||
elif type(req) is dict:
|
||||
runner = model.get_task_from_str(req.pop("task"))
|
||||
running_event.set()
|
||||
transcription = runner(**req)
|
||||
running_event.clear()
|
||||
response_queue.put(transcription)
|
||||
last_active_time.value = time.time()
|
||||
else:
|
||||
raise TypeError("request must be of type dict")
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
clear_queue(request_queue)
|
||||
clear_queue(response_queue)
|
||||
loaded_event.clear()
|
||||
|
||||
gv.MODEL_THREAD.join()
|
||||
def start_model_worker(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args, **kwargs):
    """
    Start ``model_worker`` in a new process created with the 'spawn' context.

    Args:
        model_params: Forwarded to ``model_worker`` as its first argument.
        request_queue: Forwarded to ``model_worker``.
        last_active_time: Forwarded to ``model_worker``.
        response_queue: Forwarded to ``model_worker``.
        loaded_event: Forwarded to ``model_worker``.
        running_event: Forwarded to ``model_worker``.
        *args: Extra positional arguments appended after the six above.
        **kwargs: Extra keyword arguments passed to ``model_worker``.

    Returns:
        The already started process object.
    """
    worker_args = (
        model_params,
        request_queue,
        last_active_time,
        response_queue,
        loaded_event,
        running_event,
    ) + args

    spawn_context = multiprocessing.get_context('spawn')
    worker = spawn_context.Process(target=model_worker, args=worker_args, kwargs=kwargs)
    worker.start()
    return worker
|
||||
|
||||
time.sleep(int(gv.TIMEOUT/5))
|
||||
def timer_thread(request_queue, last_active_time,loaded_event, running_event, timeout=30):
    """
    Loop forever, asking the model worker to stop once it has been idle.

    Every ``timeout`` seconds the loop checks three conditions: more than
    ``timeout`` seconds have passed since ``last_active_time.value``,
    ``loaded_event`` is set, and ``running_event`` is not set. When all hold,
    it prints a notice, puts the string "STOP" on ``request_queue`` and calls
    ``Warning(...)``. The function never returns.

    Args:
        request_queue: Queue that receives the "STOP" request.
        last_active_time: Object whose ``.value`` holds a ``time.time()`` stamp.
        loaded_event: Event-like object queried with ``is_set()``.
        running_event: Event-like object queried with ``is_set()``.
        timeout: Sleep interval and idle threshold, in seconds.
    """
    while True:
        time.sleep(timeout)

        idle_seconds = time.time() - last_active_time.value
        if not (idle_seconds > timeout):
            continue
        if not loaded_event.is_set():
            continue
        if running_event.is_set():
            continue

        print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True)
        request_queue.put("STOP")
        Warning("Model worker stopped due to inactivity.")
|
||||
+42
-39
@@ -18,19 +18,19 @@ class GradioTranscriptionInterface:
|
||||
Interface handling the interaction between Gradio UI and the Audio Transcription system.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, model):
|
||||
"""
|
||||
Initializes the GradioTranscriptionInterface with a transcription model.
|
||||
|
||||
Args:
|
||||
model (Scraibe): Model responsible for audio transcription tasks.
|
||||
"""
|
||||
self.model = gv.MODEL
|
||||
self.model = model
|
||||
|
||||
def auto_transcribe(self, source,
|
||||
def autotranscribe(self, source,
|
||||
num_speakers : int,
|
||||
translation : bool,
|
||||
language : str):
|
||||
translate : bool,
|
||||
language : str,*args ,**kwargs):
|
||||
"""
|
||||
Shortcut method for the Scraibe task.
|
||||
|
||||
@@ -38,22 +38,18 @@ class GradioTranscriptionInterface:
|
||||
tuple: Transcribed text (str), JSON output (dict)
|
||||
"""
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.set()
|
||||
|
||||
kwargs = {
|
||||
_kwargs = {
|
||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||
"language": language if language != "None" else None,
|
||||
"task": 'translate' if translation else None
|
||||
"task": 'translate' if translate else None
|
||||
}
|
||||
if isinstance(source, str):
|
||||
try:
|
||||
result = self.model.autotranscribe(source, **kwargs)
|
||||
result = self.model.autotranscribe(source, **_kwargs)
|
||||
except ValueError:
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||
Please try again!")
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
return str(result), result.get_json()
|
||||
|
||||
elif isinstance(source, list):
|
||||
@@ -61,7 +57,7 @@ class GradioTranscriptionInterface:
|
||||
result = []
|
||||
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
||||
try:
|
||||
res = self.model.autotranscribe(s, **kwargs)
|
||||
res = self.model.autotranscribe(s, **_kwargs)
|
||||
except ValueError:
|
||||
_name = s.split("/")[-1]
|
||||
res = f"NO TRANSCRIPT FOUND FOR {_name}"
|
||||
@@ -80,17 +76,13 @@ class GradioTranscriptionInterface:
|
||||
else:
|
||||
out_dict[source_names[i]] = r.get_dict()
|
||||
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
|
||||
return out, json.dumps(out_dict, indent=4)
|
||||
|
||||
else:
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
raise gr.Error("Please provide a valid audio file.")
|
||||
|
||||
|
||||
def transcribe(self, source, translation, language):
|
||||
def transcribe(self, source, translate, language,*args ,**kwargs):
|
||||
"""
|
||||
Shortcut method for the Transcribe task.
|
||||
|
||||
@@ -98,23 +90,21 @@ class GradioTranscriptionInterface:
|
||||
str: Transcribed text.
|
||||
"""
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.set()
|
||||
|
||||
kwargs = {
|
||||
_kwargs = {
|
||||
"language": language if language != "None" else None,
|
||||
"task": 'translate' if translation == "Yes" else None
|
||||
"task": 'translate' if translate == "Yes" else None
|
||||
}
|
||||
|
||||
if isinstance(source, str):
|
||||
result = self.model.transcribe(source, **kwargs)
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
result = self.model.transcribe(source, **_kwargs)
|
||||
|
||||
return str(result)
|
||||
|
||||
elif isinstance(source, list):
|
||||
source_names = [s.split("/")[-1] for s in source]
|
||||
result = []
|
||||
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
||||
res = self.model.transcribe(s, **kwargs)
|
||||
res = self.model.transcribe(s, **_kwargs)
|
||||
result.append(res)
|
||||
|
||||
out = ''
|
||||
@@ -123,15 +113,12 @@ class GradioTranscriptionInterface:
|
||||
out += str(res)
|
||||
out += "\n\n"
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
|
||||
return out
|
||||
|
||||
else:
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
raise gr.Error("Please provide a valid audio file.")
|
||||
|
||||
def perform_diarisation(self, source, num_speakers):
|
||||
def diarisation(self, source, num_speakers, *args ,**kwargs):
|
||||
"""
|
||||
Shortcut method for the Diarisation task.
|
||||
|
||||
@@ -139,27 +126,24 @@ class GradioTranscriptionInterface:
|
||||
str: JSON output of diarisation result.
|
||||
"""
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.set()
|
||||
|
||||
kwargs = {
|
||||
_kwargs = {
|
||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||
}
|
||||
|
||||
if isinstance(source, str):
|
||||
try:
|
||||
result = self.model.diarization(source, **kwargs)
|
||||
result = self.model.diarization(source, **_kwargs)
|
||||
except ValueError:
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||
Please try again!")
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
elif isinstance(source, list):
|
||||
source_names = [s.split("/")[-1] for s in source]
|
||||
result = []
|
||||
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
|
||||
try:
|
||||
res = self.model.diarization(s, **kwargs)
|
||||
res = self.model.diarization(s, **_kwargs)
|
||||
except ValueError:
|
||||
|
||||
res = f"NO DIARISATION FOUND FOR {s}"
|
||||
@@ -171,10 +155,29 @@ class GradioTranscriptionInterface:
|
||||
for i, res in enumerate(result):
|
||||
out[source_names[i]] = res
|
||||
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
|
||||
return json.dumps(out, indent=4)
|
||||
|
||||
else:
|
||||
gv.TRANSCRIBE_ACTIVE.clear()
|
||||
gr.Error("Please provide a valid audio file.")
|
||||
|
||||
def get_task_from_str(self, task):
    """
    Returns the method of this interface that handles the given task label.

    params:
        task (str): Task label. Can be one of the following:
            - 'Auto Transcribe'  -> self.autotranscribe
            - 'Transcribe'       -> self.transcribe
            - 'Diarisation'      -> self.diarisation

    raises:
        ValueError: If the label is none of the above.
    """
    # Label -> attribute name; compared with == in the listed order.
    label_to_method = (
        ('Auto Transcribe', 'autotranscribe'),
        ('Transcribe', 'transcribe'),
        ('Diarisation', 'diarisation'),
    )

    for label, method_name in label_to_method:
        if task == label:
            return getattr(self, method_name)

    raise ValueError("Invalid task string.")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import scraibe.app.global_var as gv
|
||||
import yaml
|
||||
|
||||
def load_config(original_config_path = gv.DEFAULT_APP_CONIFG_PATH, override_yaml_path=None, **kwargs):
    """
    Load the base YAML configuration and layer optional overrides on top.

    Overrides are merged with ``apply_overrides`` in this order: first the
    YAML file at ``override_yaml_path`` (when given), then ``kwargs``.

    Args:
        original_config_path: Path of the base YAML configuration file.
        override_yaml_path: Optional path of a second YAML file whose content
            overrides the base configuration.
        **kwargs: Further overrides, applied last.

    Returns:
        The merged configuration as loaded by ``yaml.safe_load``.
    """
    # Load the original configuration. safe_load returns None for an empty
    # document; fall back to an empty dict so the merges below still work.
    with open(original_config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file) or {}

    # Override with another YAML file if provided
    if override_yaml_path:
        with open(override_yaml_path, 'r', encoding='utf-8') as file:
            override_config = yaml.safe_load(file) or {}
        apply_overrides(config, override_config)

    # Apply overrides from kwargs
    apply_overrides(config, kwargs)

    return config
|
||||
|
||||
def apply_overrides(orig_dict, override_dict):
    """
    Recursively apply overrides to the configuration, in place.

    For every ``key, value`` pair in ``override_dict``:

    * dict value: merged recursively into ``orig_dict[key]``. If that entry is
      missing or is not a dict (for example ``None``), it is replaced by a new
      dict first, so the override is never silently dropped.
    * any other value: the first occurrence of ``key`` found by
      ``update_nested_key`` (this level first, then nested dicts) is updated;
      if the key exists nowhere, it is added at this level.

    Args:
        orig_dict (dict): Configuration to modify.
        override_dict (dict | None): Overrides; ``None`` or empty is a no-op.
    """
    if not override_dict:
        return

    for key, value in override_dict.items():
        if isinstance(value, dict):
            section = orig_dict.get(key)
            if not isinstance(section, dict):
                # Attach a real dict so the nested overrides are kept.
                section = {}
                orig_dict[key] = section
            apply_overrides(section, value)
        elif not update_nested_key(orig_dict, key, value):
            # Key not found anywhere below: create it at this level.
            orig_dict[key] = value


def update_nested_key(d, key, value):
    """
    Recursively search and update the key in a nested dictionary.

    Only the first match is updated: ``d`` itself is checked before its
    nested dicts, which are visited depth-first in iteration order.

    Args:
        d (dict): Dictionary to search.
        key: Key to look for.
        value: New value for the key.

    Returns:
        bool: True if the key was found and updated, False otherwise.
    """
    if key in d:
        d[key] = value
        return True

    # any() stops at the first nested dict that reports a successful update.
    return any(
        isinstance(child, dict) and update_nested_key(child, key, value)
        for child in d.values()
    )
|
||||
Reference in New Issue
Block a user