Make Gradio work with threads

This commit is contained in:
Jaikinator
2023-11-25 15:17:12 +01:00
parent bbb2c848e3
commit 93e5ce15f9
6 changed files with 101 additions and 21 deletions
+1 -1
View File
@@ -1,5 +1,5 @@
from .qtfaststart import *
from .activity_tracker import *
from .multi import *
from .interface import *
from .stg import *
from .interactions import *
+10 -1
View File
@@ -3,7 +3,16 @@ Stores global variables for the app.
"""
# Global variable to store the model
from threading import Event
import time
# The currently loaded model, or None while no model is loaded
# (the activity tracker resets it to None after the idle timeout).
MODEL = None
# Keyword arguments that are expanded into threading.Thread(**MODEL_THREAD_PARAMS)
# when the model has to be (re)loaded; None means no loader is configured.
MODEL_THREAD_PARAMS = None
# The thread object created from MODEL_THREAD_PARAMS to load the model.
MODEL_THREAD = None
# Global variable to track user activity
# NOTE(review): not read by any of the code visible in this change - confirm usage.
USER_ACTIVE = False
# Timestamp (time.time()) that the idle check compares against TIMEOUT;
# it is refreshed whenever a model is loaded.
LAST_USED = time.time()
# Idle time after which a loaded model is unloaded.
TIMEOUT = 30 #seconds
# Set while a transcription/diarisation call is running, cleared afterwards;
# the model is never unloaded while this event is set.
TRANSCRIBE_ACTIVE = Event()
+14 -5
View File
@@ -3,10 +3,12 @@ This file contains ervery function that will be called when the user interacts w
UI like pressing a button or uploading a file.
"""
from math import pi
import time
import gradio as gr
import scraibe.app.global_var as gv
from scraibe import Transcript
from scraibe.app.stg import GradioTranscriptionInterface
import threading
def select_task(choice):
# tell the app that it is still in use
@@ -84,11 +86,18 @@ def run_scraibe(task,
file_in,
progress = gr.Progress(track_tqdm= True)):
# get *args which are not None
# get *args which are not None
pipe = gv.MODEL
progress(0, desc='Starting task...')
if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None:
progress(0, desc='Model was not loaded to conserve resources. Loading model...')
time.sleep(1)
gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS)
gv.MODEL_THREAD.start()
gv.MODEL_THREAD.join()
pipe = GradioTranscriptionInterface()
progress(0.1, desc='Starting task...')
source = audio1 or audio2 or video1 or video2 or file_in
if isinstance(source, list):
+1 -6
View File
@@ -9,8 +9,6 @@ import scraibe.app.global_var as gv
from .interactions import *
from .stg import *
from scraibe import Scraibe
theme = gr.themes.Soft(
primary_hue="green",
secondary_hue='orange',
@@ -36,10 +34,7 @@ LANGUAGES = [
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
def gradio_Interface(pipe : Scraibe = None):
if pipe is not None:
gv.MODEL = GradioTranscriptionInterface(pipe)
def gradio_Interface():
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
+44
View File
@@ -0,0 +1,44 @@
"""
This file contains the functions related to monitoring the actual usage of the app.
Their purpose is to make the app use its resources more efficiently,
for example by unloading the model while it is idle and reloading it on demand.
"""
import time
import gc
from typing import Union
import torch
import scraibe.app.global_var as gv
from scraibe.autotranscript import Scraibe
def load_model_thread(model : Union[Scraibe, dict] = None):
    """
    Load a Scraibe model into the global ``gv.MODEL`` and refresh ``gv.LAST_USED``.

    Args:
        model (Union[Scraibe, dict], optional): Selects how the model is obtained.
            ``None`` builds ``Scraibe()`` with its default arguments, a
            ``Scraibe`` instance is stored as it is, and a dict is expanded
            into keyword arguments for ``Scraibe``. Defaults to None.

    Raises:
        TypeError: If ``model`` is neither ``None``, a ``Scraibe`` instance
            nor a dict.
    """
    if model is None:
        gv.MODEL = Scraibe()
    elif isinstance(model, Scraibe):
        # isinstance instead of an exact type comparison, so that
        # subclasses of Scraibe are accepted as well.
        gv.MODEL = model
    elif isinstance(model, dict):
        gv.MODEL = Scraibe(**model)
    else:
        raise TypeError("model must be of type Scraibe, or dict")

    # Restart the idle timer: the unload check compares the time elapsed
    # since LAST_USED against TIMEOUT.
    gv.LAST_USED = time.time()
def delete_unused_model():
    """
    Poll forever and unload the global model once it has been idle for too long.

    The model in ``gv.MODEL`` is released when all of the following hold:
    a model is loaded, no transcription is running (``gv.TRANSCRIBE_ACTIVE``
    is not set) and more than ``gv.TIMEOUT`` seconds have passed since
    ``gv.LAST_USED``. After releasing it, garbage collection is triggered and
    the CUDA cache is emptied.

    This function never returns; presumably it is meant to be used as the
    target of a background monitoring thread.
    """
    while True:
        idle_seconds = time.time() - gv.LAST_USED
        should_unload = (
            gv.MODEL is not None
            and not gv.TRANSCRIBE_ACTIVE.is_set()
            and idle_seconds > gv.TIMEOUT
        )

        if should_unload:
            # Rebind to None instead of `del gv.MODEL` followed by an assignment:
            # deleting the attribute first opens a window in which other threads
            # reading gv.MODEL would fail with AttributeError.
            gv.MODEL = None
            gc.collect()
            torch.cuda.empty_cache()

            # The loader thread may never have been created (MODEL_THREAD is
            # initialised to None); joining None would kill this monitor loop.
            loader_thread = gv.MODEL_THREAD
            if loader_thread is not None:
                loader_thread.join()

        # Poll five times per timeout period, but sleep at least one second so
        # that a TIMEOUT below 5 does not degrade into a busy loop (sleep(0)).
        time.sleep(max(1, int(gv.TIMEOUT / 5)))
+31 -8
View File
@@ -9,7 +9,8 @@ It makes adds gradio interactions to the scraibe class in the back.
import json
import gradio as gr
from tqdm import tqdm
from scraibe import Scraibe
import scraibe.app.global_var as gv
class GradioTranscriptionInterface:
@@ -17,14 +18,14 @@ class GradioTranscriptionInterface:
Interface handling the interaction between Gradio UI and the Audio Transcription system.
"""
def __init__(self, model: Scraibe):
def __init__(self):
"""
Initializes the GradioTranscriptionInterface with a transcription model.
Args:
model (Scraibe): Model responsible for audio transcription tasks.
"""
self.model = model
self.model = gv.MODEL
def auto_transcribe(self, source,
num_speakers : int,
@@ -37,6 +38,8 @@ class GradioTranscriptionInterface:
tuple: Transcribed text (str), JSON output (dict)
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None,
@@ -46,9 +49,11 @@ class GradioTranscriptionInterface:
try:
result = self.model.autotranscribe(source, **kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return str(result), result.get_json()
elif isinstance(source, list):
@@ -74,10 +79,14 @@ class GradioTranscriptionInterface:
out_dict[source_names[i]] = r
else:
out_dict[source_names[i]] = r.get_dict()
gv.TRANSCRIBE_ACTIVE.clear()
return out, json.dumps(out_dict, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
@@ -88,14 +97,17 @@ class GradioTranscriptionInterface:
Returns:
str: Transcribed text.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
"language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None
}
if isinstance(source, str):
result = self.model.transcribe(source, **kwargs)
gv.TRANSCRIBE_ACTIVE.clear()
return str(result)
elif isinstance(source, list):
@@ -111,9 +123,12 @@ class GradioTranscriptionInterface:
out += str(res)
out += "\n\n"
gv.TRANSCRIBE_ACTIVE.clear()
return out
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
def perform_diarisation(self, source, num_speakers):
@@ -123,6 +138,9 @@ class GradioTranscriptionInterface:
Returns:
str: JSON output of diarisation result.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
}
@@ -131,9 +149,10 @@ class GradioTranscriptionInterface:
try:
result = self.model.diarization(source, **kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(result, indent=2)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
@@ -142,6 +161,7 @@ class GradioTranscriptionInterface:
try:
res = self.model.diarization(s, **kwargs)
except ValueError:
res = f"NO DIARISATION FOUND FOR {s}"
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
result.append(res)
@@ -150,8 +170,11 @@ class GradioTranscriptionInterface:
for i, res in enumerate(result):
out[source_names[i]] = res
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(out, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
gr.Error("Please provide a valid audio file.")