From 93e5ce15f95dabe126501d982c88c5b928949cfe Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Sat, 25 Nov 2023 15:17:12 +0100 Subject: [PATCH] make gradio working with treads --- scraibe/app/__init__.py | 2 +- scraibe/app/global_var.py | 11 +++++++++- scraibe/app/interactions.py | 19 +++++++++++----- scraibe/app/interface.py | 7 +----- scraibe/app/multi.py | 44 +++++++++++++++++++++++++++++++++++++ scraibe/app/stg.py | 39 +++++++++++++++++++++++++------- 6 files changed, 101 insertions(+), 21 deletions(-) create mode 100644 scraibe/app/multi.py diff --git a/scraibe/app/__init__.py b/scraibe/app/__init__.py index 9e04a48..fa8f8f7 100644 --- a/scraibe/app/__init__.py +++ b/scraibe/app/__init__.py @@ -1,5 +1,5 @@ from .qtfaststart import * -from .activity_tracker import * +from .multi import * from .interface import * from .stg import * from .interactions import * diff --git a/scraibe/app/global_var.py b/scraibe/app/global_var.py index 191e3e6..6d8f3cf 100644 --- a/scraibe/app/global_var.py +++ b/scraibe/app/global_var.py @@ -3,7 +3,16 @@ Stores global variables for the app. """ # Global variable to store the model +from threading import Event + +import time + + MODEL = None +MODEL_THREAD_PARAMS = None +MODEL_THREAD = None # Global variable to track user activity -USER_ACTIVE = False \ No newline at end of file +LAST_USED = time.time() +TIMEOUT = 30 #seconds +TRANSCRIBE_ACTIVE = Event() \ No newline at end of file diff --git a/scraibe/app/interactions.py b/scraibe/app/interactions.py index 10659c0..6151d64 100644 --- a/scraibe/app/interactions.py +++ b/scraibe/app/interactions.py @@ -3,10 +3,12 @@ This file contains ervery function that will be called when the user interacts w UI like pressing a button or uploading a file. 
""" -from math import pi +import time import gradio as gr import scraibe.app.global_var as gv from scraibe import Transcript +from scraibe.app.stg import GradioTranscriptionInterface +import threading def select_task(choice): # tell the app that it is still in use @@ -84,11 +86,18 @@ def run_scraibe(task, file_in, progress = gr.Progress(track_tqdm= True)): - # get *args which are not None + # get *args which are not None - pipe = gv.MODEL - - progress(0, desc='Starting task...') + if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None: + progress(0, desc='Model was not loaded to conserve resources. Loading model...') + time.sleep(1) + gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS) + gv.MODEL_THREAD.start() + gv.MODEL_THREAD.join() + + pipe = GradioTranscriptionInterface() + + progress(0.1, desc='Starting task...') source = audio1 or audio2 or video1 or video2 or file_in if isinstance(source, list): diff --git a/scraibe/app/interface.py b/scraibe/app/interface.py index ef9d818..ddf10ee 100644 --- a/scraibe/app/interface.py +++ b/scraibe/app/interface.py @@ -9,8 +9,6 @@ import scraibe.app.global_var as gv from .interactions import * from .stg import * -from scraibe import Scraibe - theme = gr.themes.Soft( primary_hue="green", secondary_hue='orange', @@ -36,10 +34,7 @@ LANGUAGES = [ CURRENT_PATH = os.path.dirname(os.path.realpath(__file__)) -def gradio_Interface(pipe : Scraibe = None): - - if pipe is not None: - gv.MODEL = GradioTranscriptionInterface(pipe) +def gradio_Interface(): with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo: diff --git a/scraibe/app/multi.py b/scraibe/app/multi.py new file mode 100644 index 0000000..4aa0c09 --- /dev/null +++ b/scraibe/app/multi.py @@ -0,0 +1,44 @@ +""" +This file contains the functions which are related to monitoring the actual app usage. +Therefore, the app is to be more efficient in the usage of the resources. +By for example, unloading or reloading the model. 
+"""
+
+import time
+import gc
+from typing import Union
+import torch
+
+import scraibe.app.global_var as gv
+from scraibe.autotranscript import Scraibe
+
+
+def load_model_thread(model: Union[Scraibe, dict] = None):
+    """Store a Scraibe model in gv.MODEL and refresh gv.LAST_USED.
+
+    Accepts a Scraibe, a dict of Scraibe kwargs, or None (defaults);
+    raises TypeError for anything else.
+    """
+    if model is None:
+        gv.MODEL = Scraibe()
+    elif isinstance(model, Scraibe):
+        gv.MODEL = model
+    elif isinstance(model, dict):
+        gv.MODEL = Scraibe(**model)
+    else:
+        raise TypeError("model must be of type Scraibe, or dict")
+    gv.LAST_USED = time.time()
+
+
+def delete_unused_model():
+    """Poll forever; unload gv.MODEL once idle past gv.TIMEOUT with no transcription active."""
+    while True:
+        idle = time.time() - gv.LAST_USED > gv.TIMEOUT
+        if gv.MODEL is not None and idle and not gv.TRANSCRIBE_ACTIVE.is_set():
+            # Release the global reference, then ask Python and CUDA to free memory.
+            gv.MODEL = None
+            gc.collect()
+            torch.cuda.empty_cache()
+            if gv.MODEL_THREAD is not None:  # defaults to None; joining None would kill this loop
+                gv.MODEL_THREAD.join()
+        time.sleep(max(1, int(gv.TIMEOUT / 5)))  # never sleep(0): avoids a busy loop if TIMEOUT < 5
diff --git a/scraibe/app/stg.py b/scraibe/app/stg.py index 9b227a1..0215903 100644 --- a/scraibe/app/stg.py +++ b/scraibe/app/stg.py @@ -9,7 +9,8 @@ It makes adds gradio interactions to the scraibe class in the back. import json import gradio as gr from tqdm import tqdm -from scraibe import Scraibe + +import scraibe.app.global_var as gv class GradioTranscriptionInterface: @@ -17,14 +18,14 @@ class GradioTranscriptionInterface: Interface handling the interaction between Gradio UI and the Audio Transcription system. """ - def __init__(self, model: Scraibe): + def __init__(self): """ Initializes the GradioTranscriptionInterface with a transcription model. Args: model (Scraibe): Model responsible for audio transcription tasks. 
""" - self.model = model + self.model = gv.MODEL def auto_transcribe(self, source, num_speakers : int, @@ -37,6 +38,8 @@ class GradioTranscriptionInterface: tuple: Transcribed text (str), JSON output (dict) """ + gv.TRANSCRIBE_ACTIVE.set() + kwargs = { "num_speakers": num_speakers if num_speakers != 0 else None, "language": language if language != "None" else None, @@ -46,9 +49,11 @@ class GradioTranscriptionInterface: try: result = self.model.autotranscribe(source, **kwargs) except ValueError: + gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Couldn't detect any speech in the provided audio. \ Please try again!") - + + gv.TRANSCRIBE_ACTIVE.clear() return str(result), result.get_json() elif isinstance(source, list): @@ -74,10 +79,14 @@ class GradioTranscriptionInterface: out_dict[source_names[i]] = r else: out_dict[source_names[i]] = r.get_dict() + + + gv.TRANSCRIBE_ACTIVE.clear() return out, json.dumps(out_dict, indent=4) else: + gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Please provide a valid audio file.") @@ -88,14 +97,17 @@ class GradioTranscriptionInterface: Returns: str: Transcribed text. """ + + gv.TRANSCRIBE_ACTIVE.set() + kwargs = { "language": language if language != "None" else None, "task": 'translate' if translation == "Yes" else None } - + if isinstance(source, str): result = self.model.transcribe(source, **kwargs) - + gv.TRANSCRIBE_ACTIVE.clear() return str(result) elif isinstance(source, list): @@ -111,9 +123,12 @@ class GradioTranscriptionInterface: out += str(res) out += "\n\n" + gv.TRANSCRIBE_ACTIVE.clear() + return out else: + gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Please provide a valid audio file.") def perform_diarisation(self, source, num_speakers): @@ -123,6 +138,9 @@ class GradioTranscriptionInterface: Returns: str: JSON output of diarisation result. 
""" + + gv.TRANSCRIBE_ACTIVE.set() + kwargs = { "num_speakers": num_speakers if num_speakers != 0 else None, } @@ -131,9 +149,10 @@ class GradioTranscriptionInterface: try: result = self.model.diarization(source, **kwargs) except ValueError: + gv.TRANSCRIBE_ACTIVE.clear() raise gr.Error("Couldn't detect any speech in the provided audio. \ Please try again!") - + gv.TRANSCRIBE_ACTIVE.clear() return json.dumps(result, indent=2) elif isinstance(source, list): source_names = [s.split("/")[-1] for s in source] @@ -142,6 +161,7 @@ class GradioTranscriptionInterface: try: res = self.model.diarization(s, **kwargs) except ValueError: + res = f"NO DIARISATION FOUND FOR {s}" gr.Warning(f"Couldn't detect any speech in {s} will skip this file.") result.append(res) @@ -150,8 +170,11 @@ class GradioTranscriptionInterface: for i, res in enumerate(result): out[source_names[i]] = res - + + gv.TRANSCRIBE_ACTIVE.clear() + return json.dumps(out, indent=4) else: + gv.TRANSCRIBE_ACTIVE.clear() gr.Error("Please provide a valid audio file.")