Make Gradio work with threads

This commit is contained in:
Jaikinator
2023-11-25 15:17:12 +01:00
parent bbb2c848e3
commit 93e5ce15f9
6 changed files with 101 additions and 21 deletions
+1 -1
View File
@@ -1,5 +1,5 @@
from .qtfaststart import * from .qtfaststart import *
from .activity_tracker import * from .multi import *
from .interface import * from .interface import *
from .stg import * from .stg import *
from .interactions import * from .interactions import *
+10 -1
View File
@@ -3,7 +3,16 @@ Stores global variables for the app.
""" """
# Global variable to store the model # Global variable to store the model
from threading import Event
import time
MODEL = None MODEL = None
MODEL_THREAD_PARAMS = None
MODEL_THREAD = None
# Global variable to track user activity # Global variable to track user activity
USER_ACTIVE = False LAST_USED = time.time()
TIMEOUT = 30 #seconds
TRANSCRIBE_ACTIVE = Event()
+14 -5
View File
@@ -3,10 +3,12 @@ This file contains ervery function that will be called when the user interacts w
UI like pressing a button or uploading a file. UI like pressing a button or uploading a file.
""" """
from math import pi import time
import gradio as gr import gradio as gr
import scraibe.app.global_var as gv import scraibe.app.global_var as gv
from scraibe import Transcript from scraibe import Transcript
from scraibe.app.stg import GradioTranscriptionInterface
import threading
def select_task(choice): def select_task(choice):
# tell the app that it is still in use # tell the app that it is still in use
@@ -84,11 +86,18 @@ def run_scraibe(task,
file_in, file_in,
progress = gr.Progress(track_tqdm= True)): progress = gr.Progress(track_tqdm= True)):
# get *args which are not None # get *args which are not None
pipe = gv.MODEL if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None:
progress(0, desc='Model was not loaded to conserve resources. Loading model...')
progress(0, desc='Starting task...') time.sleep(1)
gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS)
gv.MODEL_THREAD.start()
gv.MODEL_THREAD.join()
pipe = GradioTranscriptionInterface()
progress(0.1, desc='Starting task...')
source = audio1 or audio2 or video1 or video2 or file_in source = audio1 or audio2 or video1 or video2 or file_in
if isinstance(source, list): if isinstance(source, list):
+1 -6
View File
@@ -9,8 +9,6 @@ import scraibe.app.global_var as gv
from .interactions import * from .interactions import *
from .stg import * from .stg import *
from scraibe import Scraibe
theme = gr.themes.Soft( theme = gr.themes.Soft(
primary_hue="green", primary_hue="green",
secondary_hue='orange', secondary_hue='orange',
@@ -36,10 +34,7 @@ LANGUAGES = [
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__)) CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
def gradio_Interface(pipe : Scraibe = None): def gradio_Interface():
if pipe is not None:
gv.MODEL = GradioTranscriptionInterface(pipe)
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo: with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
+44
View File
@@ -0,0 +1,44 @@
"""
This file contains the functions related to monitoring the actual app usage.
This allows the app to use its resources more efficiently,
for example by unloading or reloading the model.
"""
import time
import gc
from typing import Union
import torch
import scraibe.app.global_var as gv
from scraibe.autotranscript import Scraibe
def load_model_thread(model: Union[Scraibe, dict] = None):
    """
    Load the transcription model into the global ``gv.MODEL`` slot.

    Args:
        model (Union[Scraibe, dict], optional): A ready ``Scraibe`` instance
            (stored as is), a dict of keyword arguments forwarded to
            ``Scraibe(**model)``, or None to build ``Scraibe()`` with its
            default arguments.

    Raises:
        TypeError: If ``model`` is neither None, a ``Scraibe`` nor a dict.
    """
    if model is None:
        gv.MODEL = Scraibe()
    elif isinstance(model, Scraibe):
        # isinstance (instead of an exact type comparison) also accepts
        # subclasses of Scraibe.
        gv.MODEL = model
    elif isinstance(model, dict):
        gv.MODEL = Scraibe(**model)
    else:
        raise TypeError("model must be of type Scraibe, or dict")

    # Stamp the load time so the idle check in delete_unused_model measures
    # inactivity from this moment.
    gv.LAST_USED = time.time()
# Create a thread to monitor user activity
def delete_unused_model():
    """
    Poll forever and unload the global model once it has been idle too long.

    The model held in ``gv.MODEL`` is dropped when no transcription is flagged
    as active (``gv.TRANSCRIBE_ACTIVE`` is not set) and more than
    ``gv.TIMEOUT`` seconds have passed since ``gv.LAST_USED``. After dropping
    it, garbage collection is triggered and the CUDA cache is emptied.

    This function never returns.
    """
    while True:
        idle_for = time.time() - gv.LAST_USED
        should_unload = (
            gv.MODEL is not None
            and not gv.TRANSCRIBE_ACTIVE.is_set()
            and idle_for > gv.TIMEOUT
        )

        if should_unload:
            # Rebind instead of `del` + assign: the attribute never vanishes,
            # so a concurrent reader of gv.MODEL cannot hit an AttributeError.
            gv.MODEL = None
            gc.collect()
            torch.cuda.empty_cache()

            # MODEL_THREAD is initialised to None in global_var, so only join
            # when a thread object has actually been stored there.
            loader = gv.MODEL_THREAD
            if loader is not None:
                loader.join()

        # Poll five times per timeout window, but never more often than once
        # per second so a TIMEOUT below 5 cannot turn this into a busy loop.
        time.sleep(max(1, int(gv.TIMEOUT / 5)))
+31 -8
View File
@@ -9,7 +9,8 @@ It makes adds gradio interactions to the scraibe class in the back.
import json import json
import gradio as gr import gradio as gr
from tqdm import tqdm from tqdm import tqdm
from scraibe import Scraibe
import scraibe.app.global_var as gv
class GradioTranscriptionInterface: class GradioTranscriptionInterface:
@@ -17,14 +18,14 @@ class GradioTranscriptionInterface:
Interface handling the interaction between Gradio UI and the Audio Transcription system. Interface handling the interaction between Gradio UI and the Audio Transcription system.
""" """
def __init__(self, model: Scraibe): def __init__(self):
""" """
Initializes the GradioTranscriptionInterface with a transcription model. Initializes the GradioTranscriptionInterface with a transcription model.
Args: Args:
model (Scraibe): Model responsible for audio transcription tasks. model (Scraibe): Model responsible for audio transcription tasks.
""" """
self.model = model self.model = gv.MODEL
def auto_transcribe(self, source, def auto_transcribe(self, source,
num_speakers : int, num_speakers : int,
@@ -37,6 +38,8 @@ class GradioTranscriptionInterface:
tuple: Transcribed text (str), JSON output (dict) tuple: Transcribed text (str), JSON output (dict)
""" """
gv.TRANSCRIBE_ACTIVE.set()
kwargs = { kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None, "num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None, "language": language if language != "None" else None,
@@ -46,9 +49,11 @@ class GradioTranscriptionInterface:
try: try:
result = self.model.autotranscribe(source, **kwargs) result = self.model.autotranscribe(source, **kwargs)
except ValueError: except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \ raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!") Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return str(result), result.get_json() return str(result), result.get_json()
elif isinstance(source, list): elif isinstance(source, list):
@@ -74,10 +79,14 @@ class GradioTranscriptionInterface:
out_dict[source_names[i]] = r out_dict[source_names[i]] = r
else: else:
out_dict[source_names[i]] = r.get_dict() out_dict[source_names[i]] = r.get_dict()
gv.TRANSCRIBE_ACTIVE.clear()
return out, json.dumps(out_dict, indent=4) return out, json.dumps(out_dict, indent=4)
else: else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.") raise gr.Error("Please provide a valid audio file.")
@@ -88,14 +97,17 @@ class GradioTranscriptionInterface:
Returns: Returns:
str: Transcribed text. str: Transcribed text.
""" """
gv.TRANSCRIBE_ACTIVE.set()
kwargs = { kwargs = {
"language": language if language != "None" else None, "language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None "task": 'translate' if translation == "Yes" else None
} }
if isinstance(source, str): if isinstance(source, str):
result = self.model.transcribe(source, **kwargs) result = self.model.transcribe(source, **kwargs)
gv.TRANSCRIBE_ACTIVE.clear()
return str(result) return str(result)
elif isinstance(source, list): elif isinstance(source, list):
@@ -111,9 +123,12 @@ class GradioTranscriptionInterface:
out += str(res) out += str(res)
out += "\n\n" out += "\n\n"
gv.TRANSCRIBE_ACTIVE.clear()
return out return out
else: else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.") raise gr.Error("Please provide a valid audio file.")
def perform_diarisation(self, source, num_speakers): def perform_diarisation(self, source, num_speakers):
@@ -123,6 +138,9 @@ class GradioTranscriptionInterface:
Returns: Returns:
str: JSON output of diarisation result. str: JSON output of diarisation result.
""" """
gv.TRANSCRIBE_ACTIVE.set()
kwargs = { kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None, "num_speakers": num_speakers if num_speakers != 0 else None,
} }
@@ -131,9 +149,10 @@ class GradioTranscriptionInterface:
try: try:
result = self.model.diarization(source, **kwargs) result = self.model.diarization(source, **kwargs)
except ValueError: except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \ raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!") Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(result, indent=2) return json.dumps(result, indent=2)
elif isinstance(source, list): elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source] source_names = [s.split("/")[-1] for s in source]
@@ -142,6 +161,7 @@ class GradioTranscriptionInterface:
try: try:
res = self.model.diarization(s, **kwargs) res = self.model.diarization(s, **kwargs)
except ValueError: except ValueError:
res = f"NO DIARISATION FOUND FOR {s}" res = f"NO DIARISATION FOUND FOR {s}"
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.") gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
result.append(res) result.append(res)
@@ -150,8 +170,11 @@ class GradioTranscriptionInterface:
for i, res in enumerate(result): for i, res in enumerate(result):
out[source_names[i]] = res out[source_names[i]] = res
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(out, indent=4) return json.dumps(out, indent=4)
else: else:
gv.TRANSCRIBE_ACTIVE.clear()
gr.Error("Please provide a valid audio file.") gr.Error("Please provide a valid audio file.")