diff --git a/scraibe/app/config.yml b/scraibe/app/config.yml new file mode 100644 index 0000000..16d296c --- /dev/null +++ b/scraibe/app/config.yml @@ -0,0 +1,48 @@ +launch: + # The following are the default values for the launch configuration + # for more informations look at https://www.gradio.app/docs/interface + server_port: 8080 + server_name: 0.0.0.0 + inbrowser: true + inline: false + max-threads: 40 + quiet: false + auth: + enabled: false + username: admin + password: admin + auth_message: "Please enter your credentials" + show_error : false + favicon_path : null + ssl_keyfile : null + ssl_certfile : null + ssl_keyfile_password : null + ssl_verify : false + quiet : false + show_api : false + allowed_paths : null + blocked_paths : null + root_path : null + app_kwargs : null + state_session_capacity : 1000 + share_server_address : null + share_server_protocol : null + share : false + debug : false +queue: + # The following are the default values for the queue configuration + # for more informations look at hhttps://www.gradio.app/docs/interface + status_update_rate : 'auto' + api_open : null + max_size : null + concurrency_count : null + default_concurrency_limit : 'not_set' +layout: + header: scraibe/app/header.html + footer: null + logo: scraibe/app/logo.svg +model: + whisper_model : null + dia_model: null +advanced: + timeout: 300 #seconds e.g. 5 minutes diff --git a/scraibe/app/global_var.py b/scraibe/app/global_var.py index 6d8f3cf..99f6eea 100644 --- a/scraibe/app/global_var.py +++ b/scraibe/app/global_var.py @@ -3,16 +3,22 @@ Stores global variables for the app. """ # Global variable to store the model -from threading import Event - +import multiprocessing +import os import time +import yaml +REQUEST_QUEUE = multiprocessing.Queue() # audio file path as string +RESPONSE_QUEUE = multiprocessing.Queue() # transcription as string +LAST_ACTIVE_TIME = multiprocessing.Value('d', time.time()) # time of last activity +LOADED_EVENT = multiprocessing.Event() # model loaded event +RUNNING_EVENT = multiprocessing.Event() # model running event -MODEL = None -MODEL_THREAD_PARAMS = None -MODEL_THREAD = None +MODEL_PARAMS = None # model parameters +MODEL_PROCESS = None # model process to handle globally # Global variable to track user activity LAST_USED = time.time() -TIMEOUT = 30 #seconds -TRANSCRIBE_ACTIVE = Event() \ No newline at end of file +TIMEOUT = None #seconds + +DEFAULT_APP_CONIFG_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yml") diff --git a/scraibe/app/interactions.py b/scraibe/app/interactions.py index 6151d64..1719388 100644 --- a/scraibe/app/interactions.py +++ b/scraibe/app/interactions.py @@ -3,12 +3,12 @@ This file contains ervery function that will be called when the user interacts w UI like pressing a button or uploading a file. """ +from re import M import time import gradio as gr import scraibe.app.global_var as gv from scraibe import Transcript -from scraibe.app.stg import GradioTranscriptionInterface -import threading +from .multi import start_model_worker def select_task(choice): # tell the app that it is still in use @@ -84,33 +84,37 @@ def run_scraibe(task, video1, video2, file_in, - progress = gr.Progress(track_tqdm= True)): + progress = gr.Progress(track_tqdm=False)): # get *args which are not None + if gv.MODEL_PROCESS is None or not gv.MODEL_PROCESS.is_alive(): + #progress(0.0, desc='Loading model...') + gv.MODEL_PROCESS = start_model_worker(gv.MODEL_PARAMS, + gv.REQUEST_QUEUE, + gv.LAST_ACTIVE_TIME, + gv.RESPONSE_QUEUE, + gv.LOADED_EVENT, + gv.RUNNING_EVENT) - if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None: - progress(0, desc='Model was not loaded to conserve resources. Loading model...') - time.sleep(1) - gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS) - gv.MODEL_THREAD.start() - gv.MODEL_THREAD.join() - - pipe = GradioTranscriptionInterface() - - progress(0.1, desc='Starting task...') + # progress(0.1, desc='Starting task...') source = audio1 or audio2 or video1 or video2 or file_in if isinstance(source, list): source = [s.name for s in source] if len(source) == 1: source = source[0] - + + config = dict(source = source, + task = task, + num_speakers = num_speakers, + translate = translate, + language = language) + + gv.REQUEST_QUEUE.put(config) + if task == 'Auto Transcribe': - - out_str , out_json = pipe.auto_transcribe(source = source, - num_speakers = num_speakers, - translation = translate, - language = language) + + out_str , out_json = gv.RESPONSE_QUEUE.get() if isinstance(source, str): return (gr.update(value = out_str, visible = True), @@ -125,9 +129,7 @@ def run_scraibe(task, elif task == 'Transcribe': - out = pipe.transcribe(source = source, - translation = translate, - language = language) + out = gv.RESPONSE_QUEUE.get() return (gr.update(value = out, visible = True), gr.update(value = None, visible = False), @@ -136,8 +138,7 @@ def run_scraibe(task, elif task == 'Diarisation': - out = pipe.perform_diarisation(source = source, - num_speakers = num_speakers) + out = gv.RESPONSE_QUEUE.get() return (gr.update(value = None, visible = False), gr.update(value = out, visible = True), diff --git a/scraibe/app/multi.py b/scraibe/app/multi.py index 4aa0c09..17fd1bb 100644 --- a/scraibe/app/multi.py +++ b/scraibe/app/multi.py @@ -4,41 +4,86 @@ Therefore, the app is to be more efficient in the usage of the resources. By for example, unloading or reloading the model. """ + + import time import gc from typing import Union +import multiprocessing import torch +import signal -import scraibe.app.global_var as gv -from scraibe.autotranscript import Scraibe +from gradio import Warning +from scraibe.autotranscript import Scraibe +from .stg import GradioTranscriptionInterface + +def init_worker(): + signal.signal(signal.SIGINT, signal.SIG_IGN) -def load_model_thread(model : Union[Scraibe, dict] = None): - if model is None: - gv.MODEL = Scraibe() - elif type(model) is Scraibe: - gv.MODEL = model - elif type(model) is dict: - gv.MODEL = Scraibe(**model) +def clear_queue(queue): + while not queue.empty(): + try: + queue.get_nowait() + except queue.Empty: + continue + +def model_worker(model_params : Union[Scraibe, dict], + request_queue, + last_active_time, + response_queue, + loaded_event, + running_event, + *args, **kwargs): + + loaded_event.set() + + if model_params is None: + _model = Scraibe() + elif type(model_params) is Scraibe: + _model = model_params + elif type(model_params) is dict: + _model = Scraibe(**model_params) else: raise TypeError("model must be of type Scraibe, or dict") - gv.LAST_USED = time.time() - -# Create a thread to monitor user activity -def delete_unused_model(): + model = GradioTranscriptionInterface(_model) + while True: - _unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None) + req = request_queue.get() - if _unload_porperty: + if req == "STOP": - del gv.MODEL - gv.MODEL = None - - gc.collect() - torch.cuda.empty_cache() + break + elif type(req) is dict: + runner = model.get_task_from_str(req.pop("task")) + running_event.set() + transcription = runner(**req) + running_event.clear() + response_queue.put(transcription) + last_active_time.value = time.time() + else: + raise TypeError("request must be of type dict") - gv.MODEL_THREAD.join() - - time.sleep(int(gv.TIMEOUT/5)) + del model + torch.cuda.empty_cache() + gc.collect() + clear_queue(request_queue) + clear_queue(response_queue) + loaded_event.clear() + +def start_model_worker(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args, **kwargs): + context = multiprocessing.get_context('spawn') + model_process = context.Process(target=model_worker, args=(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args), kwargs=kwargs) + model_process.start() + return model_process + +def timer_thread(request_queue, last_active_time,loaded_event, running_event, timeout=30): + while True: + time.sleep(timeout) + + if time.time() - last_active_time.value > timeout and loaded_event.is_set() and not running_event.is_set(): + print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True) + request_queue.put("STOP") + Warning("Model worker stopped due to inactivity.") \ No newline at end of file diff --git a/scraibe/app/stg.py b/scraibe/app/stg.py index 0215903..1b9caf7 100644 --- a/scraibe/app/stg.py +++ b/scraibe/app/stg.py @@ -18,19 +18,19 @@ class GradioTranscriptionInterface: Interface handling the interaction between Gradio UI and the Audio Transcription system. """ - def __init__(self): + def __init__(self, model): """ Initializes the GradioTranscriptionInterface with a transcription model. Args: model (Scraibe): Model responsible for audio transcription tasks. """ - self.model = gv.MODEL + self.model = model - def auto_transcribe(self, source, + def autotranscribe(self, source, num_speakers : int, - translation : bool, - language : str): + translate : bool, + language : str,*args ,**kwargs): """ Shortcut method for the Scraibe task. @@ -38,22 +38,18 @@ class GradioTranscriptionInterface: tuple: Transcribed text (str), JSON output (dict) """ - gv.TRANSCRIBE_ACTIVE.set() - - kwargs = { + _kwargs = { "num_speakers": num_speakers if num_speakers != 0 else None, "language": language if language != "None" else None, - "task": 'translate' if translation else None + "task": 'translate' if translate else None } if isinstance(source, str): try: - result = self.model.autotranscribe(source, **kwargs) + result = self.model.autotranscribe(source, **_kwargs) except ValueError: - gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Couldn't detect any speech in the provided audio. \ Please try again!") - - gv.TRANSCRIBE_ACTIVE.clear() + return str(result), result.get_json() elif isinstance(source, list): @@ -61,7 +57,7 @@ class GradioTranscriptionInterface: result = [] for s in tqdm(source, total=len(source),desc = "Transcribing audio files"): try: - res = self.model.autotranscribe(s, **kwargs) + res = self.model.autotranscribe(s, **_kwargs) except ValueError: _name = s.split("/")[-1] res = f"NO TRANSCRIPT FOUND FOR {_name}" @@ -79,42 +75,36 @@ class GradioTranscriptionInterface: out_dict[source_names[i]] = r else: out_dict[source_names[i]] = r.get_dict() - - - gv.TRANSCRIBE_ACTIVE.clear() return out, json.dumps(out_dict, indent=4) else: - gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Please provide a valid audio file.") - def transcribe(self, source, translation, language): + def transcribe(self, source, translate, language,*args ,**kwargs): """ Shortcut method for the Transcribe task. Returns: str: Transcribed text. """ - - gv.TRANSCRIBE_ACTIVE.set() - - kwargs = { + + _kwargs = { "language": language if language != "None" else None, - "task": 'translate' if translation == "Yes" else None + "task": 'translate' if translate == "Yes" else None } if isinstance(source, str): - result = self.model.transcribe(source, **kwargs) - gv.TRANSCRIBE_ACTIVE.clear() + result = self.model.transcribe(source, **_kwargs) + return str(result) elif isinstance(source, list): source_names = [s.split("/")[-1] for s in source] result = [] for s in tqdm(source, total=len(source),desc = "Transcribing audio files"): - res = self.model.transcribe(s, **kwargs) + res = self.model.transcribe(s, **_kwargs) result.append(res) out = '' @@ -123,15 +113,12 @@ class GradioTranscriptionInterface: out += str(res) out += "\n\n" - gv.TRANSCRIBE_ACTIVE.clear() - return out else: - gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Please provide a valid audio file.") - def perform_diarisation(self, source, num_speakers): + def diarisation(self, source, num_speakers, *args ,**kwargs): """ Shortcut method for the Diarisation task. @@ -139,27 +126,24 @@ class GradioTranscriptionInterface: str: JSON output of diarisation result. """ - gv.TRANSCRIBE_ACTIVE.set() - - kwargs = { + _kwargs = { "num_speakers": num_speakers if num_speakers != 0 else None, } if isinstance(source, str): try: - result = self.model.diarization(source, **kwargs) + result = self.model.diarization(source, **_kwargs) except ValueError: - gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Couldn't detect any speech in the provided audio. \ Please try again!") - gv.TRANSCRIBE_ACTIVE.clear() + return json.dumps(result, indent=2) elif isinstance(source, list): source_names = [s.split("/")[-1] for s in source] result = [] for s in tqdm(source, total=len(source),desc = "Performing diarisation"): try: - res = self.model.diarization(s, **kwargs) + res = self.model.diarization(s, **_kwargs) except ValueError: res = f"NO DIARISATION FOUND FOR {s}" @@ -171,10 +155,29 @@ class GradioTranscriptionInterface: for i, res in enumerate(result): out[source_names[i]] = res - gv.TRANSCRIBE_ACTIVE.clear() - return json.dumps(out, indent=4) else: - gv.TRANSCRIBE_ACTIVE.clear() - gr.Error("Please provide a valid audio file.") + gr.Error("Please provide a valid audio file.") + + def get_task_from_str(self, task): + """ + Returns the coresponing task function based on the task string. + + params: + task (str): Task string. Can be one of the following: + - 'Auto Transcribe' + - 'Transcribe' + - 'Diarisation' + """ + + if task == 'Auto Transcribe': + return self.autotranscribe + elif task == 'Transcribe': + return self.transcribe + elif task == 'Diarisation': + return self.diarisation + else: + raise ValueError("Invalid task string.") + + diff --git a/scraibe/app/utils.py b/scraibe/app/utils.py new file mode 100644 index 0000000..b41a88f --- /dev/null +++ b/scraibe/app/utils.py @@ -0,0 +1,42 @@ +import scraibe.app.global_var as gv +import yaml + +def load_config(original_config_path = gv.DEFAULT_APP_CONIFG_PATH, override_yaml_path=None, **kwargs): + + + # Load the original configuration + with open(original_config_path, 'r') as file: + config = yaml.safe_load(file) + + # Override with another YAML file if provided + if override_yaml_path: + with open(override_yaml_path, 'r') as file: + override_config = yaml.safe_load(file) + apply_overrides(config, override_config) + + # Apply overrides from kwargs + apply_overrides(config, kwargs) + + return config + +def apply_overrides(orig_dict, override_dict): + """ Recursively apply overrides to the configuration. """ + for key, value in override_dict.items(): + if isinstance(value, dict): + # If the value is a dict, apply recursively + apply_overrides(orig_dict.get(key, {}), value) + else: + # If the value is not a dict, search for the key and update + if update_nested_key(orig_dict, key, value): + continue # Key was found and updated + orig_dict[key] = value # Key not found, update at this level + +def update_nested_key(d, key, value): + """ Recursively search and update the key in nested dictionary. """ + if key in d: + d[key] = value + return True + for k, v in d.items(): + if isinstance(v, dict) and update_nested_key(v, key, value): + return True + return False \ No newline at end of file