make gradio working with treads
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from .qtfaststart import *
|
from .qtfaststart import *
|
||||||
from .activity_tracker import *
|
from .multi import *
|
||||||
from .interface import *
|
from .interface import *
|
||||||
from .stg import *
|
from .stg import *
|
||||||
from .interactions import *
|
from .interactions import *
|
||||||
|
|||||||
@@ -3,7 +3,16 @@ Stores global variables for the app.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Global variable to store the model
|
# Global variable to store the model
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
MODEL = None
|
MODEL = None
|
||||||
|
MODEL_THREAD_PARAMS = None
|
||||||
|
MODEL_THREAD = None
|
||||||
|
|
||||||
# Global variable to track user activity
|
# Global variable to track user activity
|
||||||
USER_ACTIVE = False
|
LAST_USED = time.time()
|
||||||
|
TIMEOUT = 30 #seconds
|
||||||
|
TRANSCRIBE_ACTIVE = Event()
|
||||||
@@ -3,10 +3,12 @@ This file contains ervery function that will be called when the user interacts w
|
|||||||
UI like pressing a button or uploading a file.
|
UI like pressing a button or uploading a file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from math import pi
|
import time
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import scraibe.app.global_var as gv
|
import scraibe.app.global_var as gv
|
||||||
from scraibe import Transcript
|
from scraibe import Transcript
|
||||||
|
from scraibe.app.stg import GradioTranscriptionInterface
|
||||||
|
import threading
|
||||||
|
|
||||||
def select_task(choice):
|
def select_task(choice):
|
||||||
# tell the app that it is still in use
|
# tell the app that it is still in use
|
||||||
@@ -86,9 +88,16 @@ def run_scraibe(task,
|
|||||||
|
|
||||||
# get *args which are not None
|
# get *args which are not None
|
||||||
|
|
||||||
pipe = gv.MODEL
|
if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None:
|
||||||
|
progress(0, desc='Model was not loaded to conserve resources. Loading model...')
|
||||||
|
time.sleep(1)
|
||||||
|
gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS)
|
||||||
|
gv.MODEL_THREAD.start()
|
||||||
|
gv.MODEL_THREAD.join()
|
||||||
|
|
||||||
progress(0, desc='Starting task...')
|
pipe = GradioTranscriptionInterface()
|
||||||
|
|
||||||
|
progress(0.1, desc='Starting task...')
|
||||||
source = audio1 or audio2 or video1 or video2 or file_in
|
source = audio1 or audio2 or video1 or video2 or file_in
|
||||||
|
|
||||||
if isinstance(source, list):
|
if isinstance(source, list):
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import scraibe.app.global_var as gv
|
|||||||
from .interactions import *
|
from .interactions import *
|
||||||
from .stg import *
|
from .stg import *
|
||||||
|
|
||||||
from scraibe import Scraibe
|
|
||||||
|
|
||||||
theme = gr.themes.Soft(
|
theme = gr.themes.Soft(
|
||||||
primary_hue="green",
|
primary_hue="green",
|
||||||
secondary_hue='orange',
|
secondary_hue='orange',
|
||||||
@@ -36,10 +34,7 @@ LANGUAGES = [
|
|||||||
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
|
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
|
||||||
def gradio_Interface(pipe : Scraibe = None):
|
def gradio_Interface():
|
||||||
|
|
||||||
if pipe is not None:
|
|
||||||
gv.MODEL = GradioTranscriptionInterface(pipe)
|
|
||||||
|
|
||||||
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
|
with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo:
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
"""
|
||||||
|
This file contains the functions which are related to monitoring the actual app usage.
|
||||||
|
Therefore, the app is to be more efficient in the usage of the resources.
|
||||||
|
By for example, unloading or reloading the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import gc
|
||||||
|
from typing import Union
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import scraibe.app.global_var as gv
|
||||||
|
from scraibe.autotranscript import Scraibe
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_thread(model : Union[Scraibe, dict] = None):
|
||||||
|
if model is None:
|
||||||
|
gv.MODEL = Scraibe()
|
||||||
|
elif type(model) is Scraibe:
|
||||||
|
gv.MODEL = model
|
||||||
|
elif type(model) is dict:
|
||||||
|
gv.MODEL = Scraibe(**model)
|
||||||
|
else:
|
||||||
|
raise TypeError("model must be of type Scraibe, or dict")
|
||||||
|
|
||||||
|
gv.LAST_USED = time.time()
|
||||||
|
|
||||||
|
# Create a thread to monitor user activity
|
||||||
|
def delete_unused_model():
|
||||||
|
while True:
|
||||||
|
|
||||||
|
_unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None)
|
||||||
|
|
||||||
|
if _unload_porperty:
|
||||||
|
|
||||||
|
del gv.MODEL
|
||||||
|
gv.MODEL = None
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
gv.MODEL_THREAD.join()
|
||||||
|
|
||||||
|
time.sleep(int(gv.TIMEOUT/5))
|
||||||
+28
-5
@@ -9,7 +9,8 @@ It makes adds gradio interactions to the scraibe class in the back.
|
|||||||
import json
|
import json
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from scraibe import Scraibe
|
|
||||||
|
import scraibe.app.global_var as gv
|
||||||
|
|
||||||
|
|
||||||
class GradioTranscriptionInterface:
|
class GradioTranscriptionInterface:
|
||||||
@@ -17,14 +18,14 @@ class GradioTranscriptionInterface:
|
|||||||
Interface handling the interaction between Gradio UI and the Audio Transcription system.
|
Interface handling the interaction between Gradio UI and the Audio Transcription system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: Scraibe):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
Initializes the GradioTranscriptionInterface with a transcription model.
|
Initializes the GradioTranscriptionInterface with a transcription model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (Scraibe): Model responsible for audio transcription tasks.
|
model (Scraibe): Model responsible for audio transcription tasks.
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = gv.MODEL
|
||||||
|
|
||||||
def auto_transcribe(self, source,
|
def auto_transcribe(self, source,
|
||||||
num_speakers : int,
|
num_speakers : int,
|
||||||
@@ -37,6 +38,8 @@ class GradioTranscriptionInterface:
|
|||||||
tuple: Transcribed text (str), JSON output (dict)
|
tuple: Transcribed text (str), JSON output (dict)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.set()
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||||
"language": language if language != "None" else None,
|
"language": language if language != "None" else None,
|
||||||
@@ -46,9 +49,11 @@ class GradioTranscriptionInterface:
|
|||||||
try:
|
try:
|
||||||
result = self.model.autotranscribe(source, **kwargs)
|
result = self.model.autotranscribe(source, **kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||||
Please try again!")
|
Please try again!")
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
return str(result), result.get_json()
|
return str(result), result.get_json()
|
||||||
|
|
||||||
elif isinstance(source, list):
|
elif isinstance(source, list):
|
||||||
@@ -75,9 +80,13 @@ class GradioTranscriptionInterface:
|
|||||||
else:
|
else:
|
||||||
out_dict[source_names[i]] = r.get_dict()
|
out_dict[source_names[i]] = r.get_dict()
|
||||||
|
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
|
|
||||||
return out, json.dumps(out_dict, indent=4)
|
return out, json.dumps(out_dict, indent=4)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
raise gr.Error("Please provide a valid audio file.")
|
raise gr.Error("Please provide a valid audio file.")
|
||||||
|
|
||||||
|
|
||||||
@@ -88,6 +97,9 @@ class GradioTranscriptionInterface:
|
|||||||
Returns:
|
Returns:
|
||||||
str: Transcribed text.
|
str: Transcribed text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.set()
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"language": language if language != "None" else None,
|
"language": language if language != "None" else None,
|
||||||
"task": 'translate' if translation == "Yes" else None
|
"task": 'translate' if translation == "Yes" else None
|
||||||
@@ -95,7 +107,7 @@ class GradioTranscriptionInterface:
|
|||||||
|
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
result = self.model.transcribe(source, **kwargs)
|
result = self.model.transcribe(source, **kwargs)
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
return str(result)
|
return str(result)
|
||||||
|
|
||||||
elif isinstance(source, list):
|
elif isinstance(source, list):
|
||||||
@@ -111,9 +123,12 @@ class GradioTranscriptionInterface:
|
|||||||
out += str(res)
|
out += str(res)
|
||||||
out += "\n\n"
|
out += "\n\n"
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
raise gr.Error("Please provide a valid audio file.")
|
raise gr.Error("Please provide a valid audio file.")
|
||||||
|
|
||||||
def perform_diarisation(self, source, num_speakers):
|
def perform_diarisation(self, source, num_speakers):
|
||||||
@@ -123,6 +138,9 @@ class GradioTranscriptionInterface:
|
|||||||
Returns:
|
Returns:
|
||||||
str: JSON output of diarisation result.
|
str: JSON output of diarisation result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.set()
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||||
}
|
}
|
||||||
@@ -131,9 +149,10 @@ class GradioTranscriptionInterface:
|
|||||||
try:
|
try:
|
||||||
result = self.model.diarization(source, **kwargs)
|
result = self.model.diarization(source, **kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||||
Please try again!")
|
Please try again!")
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
return json.dumps(result, indent=2)
|
return json.dumps(result, indent=2)
|
||||||
elif isinstance(source, list):
|
elif isinstance(source, list):
|
||||||
source_names = [s.split("/")[-1] for s in source]
|
source_names = [s.split("/")[-1] for s in source]
|
||||||
@@ -142,6 +161,7 @@ class GradioTranscriptionInterface:
|
|||||||
try:
|
try:
|
||||||
res = self.model.diarization(s, **kwargs)
|
res = self.model.diarization(s, **kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|
||||||
res = f"NO DIARISATION FOUND FOR {s}"
|
res = f"NO DIARISATION FOUND FOR {s}"
|
||||||
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
|
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
|
||||||
result.append(res)
|
result.append(res)
|
||||||
@@ -151,7 +171,10 @@ class GradioTranscriptionInterface:
|
|||||||
for i, res in enumerate(result):
|
for i, res in enumerate(result):
|
||||||
out[source_names[i]] = res
|
out[source_names[i]] = res
|
||||||
|
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
|
|
||||||
return json.dumps(out, indent=4)
|
return json.dumps(out, indent=4)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
gv.TRANSCRIBE_ACTIVE.clear()
|
||||||
gr.Error("Please provide a valid audio file.")
|
gr.Error("Please provide a valid audio file.")
|
||||||
|
|||||||
Reference in New Issue
Block a user