Make everything work in processes and adding config to customize instance

This commit is contained in:
Jaikinator
2023-12-07 16:22:52 +01:00
parent 32b27442e6
commit 9eb9f5af8d
6 changed files with 241 additions and 96 deletions
+48
View File
@@ -0,0 +1,48 @@
launch:
# The following are the default values for the launch configuration
# for more informations look at https://www.gradio.app/docs/interface
server_port: 8080
server_name: 0.0.0.0
inbrowser: true
inline: false
max-threads: 40
quiet: false
auth:
enabled: false
username: admin
password: admin
auth_message: "Please enter your credentials"
show_error : false
favicon_path : null
ssl_keyfile : null
ssl_certfile : null
ssl_keyfile_password : null
ssl_verify : false
quiet : false
show_api : false
allowed_paths : null
blocked_paths : null
root_path : null
app_kwargs : null
state_session_capacity : 1000
share_server_address : null
share_server_protocol : null
share : false
debug : false
queue:
# The following are the default values for the queue configuration
# for more informations look at hhttps://www.gradio.app/docs/interface
status_update_rate : 'auto'
api_open : null
max_size : null
concurrency_count : null
default_concurrency_limit : 'not_set'
layout:
header: scraibe/app/header.html
footer: null
logo: scraibe/app/logo.svg
model:
whisper_model : null
dia_model: null
advanced:
timeout: 300 #seconds e.g. 5 minutes
+13 -7
View File
@@ -3,16 +3,22 @@ Stores global variables for the app.
"""
# Global variable to store the model
from threading import Event
import multiprocessing
import os
import time
import yaml
REQUEST_QUEUE = multiprocessing.Queue() # audio file path as string
RESPONSE_QUEUE = multiprocessing.Queue() # transcription as string
LAST_ACTIVE_TIME = multiprocessing.Value('d', time.time()) # time of last activity
LOADED_EVENT = multiprocessing.Event() # model loaded event
RUNNING_EVENT = multiprocessing.Event() # model running event
MODEL = None
MODEL_THREAD_PARAMS = None
MODEL_THREAD = None
MODEL_PARAMS = None # model parameters
MODEL_PROCESS = None # model process to handle globally
# Global variable to track user activity
LAST_USED = time.time()
TIMEOUT = 30 #seconds
TRANSCRIBE_ACTIVE = Event()
TIMEOUT = None #seconds
DEFAULT_APP_CONIFG_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yml")
+25 -24
View File
@@ -3,12 +3,12 @@ This file contains ervery function that will be called when the user interacts w
UI like pressing a button or uploading a file.
"""
from re import M
import time
import gradio as gr
import scraibe.app.global_var as gv
from scraibe import Transcript
from scraibe.app.stg import GradioTranscriptionInterface
import threading
from .multi import start_model_worker
def select_task(choice):
# tell the app that it is still in use
@@ -84,33 +84,37 @@ def run_scraibe(task,
video1,
video2,
file_in,
progress = gr.Progress(track_tqdm= True)):
progress = gr.Progress(track_tqdm=False)):
# get *args which are not None
if gv.MODEL_PROCESS is None or not gv.MODEL_PROCESS.is_alive():
#progress(0.0, desc='Loading model...')
gv.MODEL_PROCESS = start_model_worker(gv.MODEL_PARAMS,
gv.REQUEST_QUEUE,
gv.LAST_ACTIVE_TIME,
gv.RESPONSE_QUEUE,
gv.LOADED_EVENT,
gv.RUNNING_EVENT)
if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None:
progress(0, desc='Model was not loaded to conserve resources. Loading model...')
time.sleep(1)
gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS)
gv.MODEL_THREAD.start()
gv.MODEL_THREAD.join()
pipe = GradioTranscriptionInterface()
progress(0.1, desc='Starting task...')
# progress(0.1, desc='Starting task...')
source = audio1 or audio2 or video1 or video2 or file_in
if isinstance(source, list):
source = [s.name for s in source]
if len(source) == 1:
source = source[0]
config = dict(source = source,
task = task,
num_speakers = num_speakers,
translate = translate,
language = language)
gv.REQUEST_QUEUE.put(config)
if task == 'Auto Transcribe':
out_str , out_json = pipe.auto_transcribe(source = source,
num_speakers = num_speakers,
translation = translate,
language = language)
out_str , out_json = gv.RESPONSE_QUEUE.get()
if isinstance(source, str):
return (gr.update(value = out_str, visible = True),
@@ -125,9 +129,7 @@ def run_scraibe(task,
elif task == 'Transcribe':
out = pipe.transcribe(source = source,
translation = translate,
language = language)
out = gv.RESPONSE_QUEUE.get()
return (gr.update(value = out, visible = True),
gr.update(value = None, visible = False),
@@ -136,8 +138,7 @@ def run_scraibe(task,
elif task == 'Diarisation':
out = pipe.perform_diarisation(source = source,
num_speakers = num_speakers)
out = gv.RESPONSE_QUEUE.get()
return (gr.update(value = None, visible = False),
gr.update(value = out, visible = True),
+68 -23
View File
@@ -4,41 +4,86 @@ Therefore, the app is to be more efficient in the usage of the resources.
By for example, unloading or reloading the model.
"""
import time
import gc
from typing import Union
import multiprocessing
import torch
import signal
import scraibe.app.global_var as gv
from scraibe.autotranscript import Scraibe
from gradio import Warning
from scraibe.autotranscript import Scraibe
from .stg import GradioTranscriptionInterface
def init_worker():
signal.signal(signal.SIGINT, signal.SIG_IGN)
def load_model_thread(model : Union[Scraibe, dict] = None):
if model is None:
gv.MODEL = Scraibe()
elif type(model) is Scraibe:
gv.MODEL = model
elif type(model) is dict:
gv.MODEL = Scraibe(**model)
def clear_queue(queue):
while not queue.empty():
try:
queue.get_nowait()
except queue.Empty:
continue
def model_worker(model_params : Union[Scraibe, dict],
request_queue,
last_active_time,
response_queue,
loaded_event,
running_event,
*args, **kwargs):
loaded_event.set()
if model_params is None:
_model = Scraibe()
elif type(model_params) is Scraibe:
_model = model_params
elif type(model_params) is dict:
_model = Scraibe(**model_params)
else:
raise TypeError("model must be of type Scraibe, or dict")
gv.LAST_USED = time.time()
# Create a thread to monitor user activity
def delete_unused_model():
model = GradioTranscriptionInterface(_model)
while True:
_unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None)
req = request_queue.get()
if _unload_porperty:
if req == "STOP":
del gv.MODEL
gv.MODEL = None
gc.collect()
torch.cuda.empty_cache()
break
elif type(req) is dict:
runner = model.get_task_from_str(req.pop("task"))
running_event.set()
transcription = runner(**req)
running_event.clear()
response_queue.put(transcription)
last_active_time.value = time.time()
else:
raise TypeError("request must be of type dict")
gv.MODEL_THREAD.join()
time.sleep(int(gv.TIMEOUT/5))
del model
torch.cuda.empty_cache()
gc.collect()
clear_queue(request_queue)
clear_queue(response_queue)
loaded_event.clear()
def start_model_worker(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args, **kwargs):
context = multiprocessing.get_context('spawn')
model_process = context.Process(target=model_worker, args=(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args), kwargs=kwargs)
model_process.start()
return model_process
def timer_thread(request_queue, last_active_time,loaded_event, running_event, timeout=30):
while True:
time.sleep(timeout)
if time.time() - last_active_time.value > timeout and loaded_event.is_set() and not running_event.is_set():
print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True)
request_queue.put("STOP")
Warning("Model worker stopped due to inactivity.")
+45 -42
View File
@@ -18,19 +18,19 @@ class GradioTranscriptionInterface:
Interface handling the interaction between Gradio UI and the Audio Transcription system.
"""
def __init__(self):
def __init__(self, model):
"""
Initializes the GradioTranscriptionInterface with a transcription model.
Args:
model (Scraibe): Model responsible for audio transcription tasks.
"""
self.model = gv.MODEL
self.model = model
def auto_transcribe(self, source,
def autotranscribe(self, source,
num_speakers : int,
translation : bool,
language : str):
translate : bool,
language : str,*args ,**kwargs):
"""
Shortcut method for the Scraibe task.
@@ -38,22 +38,18 @@ class GradioTranscriptionInterface:
tuple: Transcribed text (str), JSON output (dict)
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
_kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None,
"task": 'translate' if translation else None
"task": 'translate' if translate else None
}
if isinstance(source, str):
try:
result = self.model.autotranscribe(source, **kwargs)
result = self.model.autotranscribe(source, **_kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return str(result), result.get_json()
elif isinstance(source, list):
@@ -61,7 +57,7 @@ class GradioTranscriptionInterface:
result = []
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
try:
res = self.model.autotranscribe(s, **kwargs)
res = self.model.autotranscribe(s, **_kwargs)
except ValueError:
_name = s.split("/")[-1]
res = f"NO TRANSCRIPT FOUND FOR {_name}"
@@ -79,42 +75,36 @@ class GradioTranscriptionInterface:
out_dict[source_names[i]] = r
else:
out_dict[source_names[i]] = r.get_dict()
gv.TRANSCRIBE_ACTIVE.clear()
return out, json.dumps(out_dict, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
def transcribe(self, source, translation, language):
def transcribe(self, source, translate, language,*args ,**kwargs):
"""
Shortcut method for the Transcribe task.
Returns:
str: Transcribed text.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
_kwargs = {
"language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None
"task": 'translate' if translate == "Yes" else None
}
if isinstance(source, str):
result = self.model.transcribe(source, **kwargs)
gv.TRANSCRIBE_ACTIVE.clear()
result = self.model.transcribe(source, **_kwargs)
return str(result)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
res = self.model.transcribe(s, **kwargs)
res = self.model.transcribe(s, **_kwargs)
result.append(res)
out = ''
@@ -123,15 +113,12 @@ class GradioTranscriptionInterface:
out += str(res)
out += "\n\n"
gv.TRANSCRIBE_ACTIVE.clear()
return out
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
def perform_diarisation(self, source, num_speakers):
def diarisation(self, source, num_speakers, *args ,**kwargs):
"""
Shortcut method for the Diarisation task.
@@ -139,27 +126,24 @@ class GradioTranscriptionInterface:
str: JSON output of diarisation result.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
_kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
}
if isinstance(source, str):
try:
result = self.model.diarization(source, **kwargs)
result = self.model.diarization(source, **_kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(result, indent=2)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
try:
res = self.model.diarization(s, **kwargs)
res = self.model.diarization(s, **_kwargs)
except ValueError:
res = f"NO DIARISATION FOUND FOR {s}"
@@ -171,10 +155,29 @@ class GradioTranscriptionInterface:
for i, res in enumerate(result):
out[source_names[i]] = res
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(out, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
gr.Error("Please provide a valid audio file.")
gr.Error("Please provide a valid audio file.")
def get_task_from_str(self, task):
"""
Returns the coresponing task function based on the task string.
params:
task (str): Task string. Can be one of the following:
- 'Auto Transcribe'
- 'Transcribe'
- 'Diarisation'
"""
if task == 'Auto Transcribe':
return self.autotranscribe
elif task == 'Transcribe':
return self.transcribe
elif task == 'Diarisation':
return self.diarisation
else:
raise ValueError("Invalid task string.")
+42
View File
@@ -0,0 +1,42 @@
import scraibe.app.global_var as gv
import yaml
def load_config(original_config_path = gv.DEFAULT_APP_CONIFG_PATH, override_yaml_path=None, **kwargs):
# Load the original configuration
with open(original_config_path, 'r') as file:
config = yaml.safe_load(file)
# Override with another YAML file if provided
if override_yaml_path:
with open(override_yaml_path, 'r') as file:
override_config = yaml.safe_load(file)
apply_overrides(config, override_config)
# Apply overrides from kwargs
apply_overrides(config, kwargs)
return config
def apply_overrides(orig_dict, override_dict):
""" Recursively apply overrides to the configuration. """
for key, value in override_dict.items():
if isinstance(value, dict):
# If the value is a dict, apply recursively
apply_overrides(orig_dict.get(key, {}), value)
else:
# If the value is not a dict, search for the key and update
if update_nested_key(orig_dict, key, value):
continue # Key was found and updated
orig_dict[key] = value # Key not found, update at this level
def update_nested_key(d, key, value):
""" Recursively search and update the key in nested dictionary. """
if key in d:
d[key] = value
return True
for k, v in d.items():
if isinstance(v, dict) and update_nested_key(v, key, value):
return True
return False