diff --git a/scraibe/app/gradio_app.py b/scraibe/app/gradio_app.py index 086db17..6913643 100644 --- a/scraibe/app/gradio_app.py +++ b/scraibe/app/gradio_app.py @@ -32,13 +32,15 @@ Usage: """ + import json +from math import pi import os -import re import gradio as gr import threading from tqdm import tqdm + import time from scraibe import Scraibe, Transcript @@ -226,11 +228,23 @@ class GradioTranscriptionInterface: # Gradio Interface #### -def gradio_Interface(model : Scraibe = None): +def gradio_Interface(model : Scraibe = None, timeout = 1): + """ + Gradio Web interface for audio transcription. + + :param model: Scraibe model, defaults to None + :type model: Scraibe, optional + :param timeout: Time until model is unloaded, defaults to 600 seconds + :type timeout: int, optional + :return: Gradio Interface + :rtype: gradio.Interface + """ if model is None: model = Scraibe() - + + save_model_params = model.params + pipe = GradioTranscriptionInterface(model) def select_task(choice): @@ -314,6 +328,10 @@ def gradio_Interface(model : Scraibe = None): progress = gr.Progress(track_tqdm= True)): # get *args which are not None + if not "model" in locals(): + gr.Warning("Model unloaded due to inactivity. Reloading the model, please wait.") + model = Scraibe(**save_model_params) + pipe = GradioTranscriptionInterface(model) # # tell the app that it is still in use reset_user_activity() @@ -373,21 +391,23 @@ def gradio_Interface(model : Scraibe = None): return gr.update(value = str(trans)),gr.update(value = trans.get_json()) # Create a thread to monitor user activity - def monitor_activity(): + def monitor_activity(model, pipe, timeout=timeout): global USER_ACTIVE while True: - time.sleep(60) # Check user activity every second + time.sleep(timeout) # Check user activity every second with user_active_lock: if not USER_ACTIVE: del model + del pipe print("Model deleted empty memory") + gr.Warning("Model unloaded due to inactivity. Please reload the model to continue.") break - USER_ACTIVE = False + USER_ACTIVE = False # Start the monitoring thread - activity_thread = threading.Thread(target=monitor_activity) + activity_thread = threading.Thread(target=monitor_activity, args=(model, pipe)) activity_thread.daemon = True activity_thread.start() @@ -401,7 +421,7 @@ def gradio_Interface(model : Scraibe = None): header = header.replace("/file=logo.svg", f"/file={CURRENT_PATH}/logo.svg" ) gr.HTML(header, visible= True, show_label=False) - + with gr.Row(): with gr.Column(): @@ -476,8 +496,6 @@ def gradio_Interface(model : Scraibe = None): annotate.click(fn = annotate_output, inputs=[annoation, out_json], outputs=[out_txt, out_json]) - - return demo