Make everything work in processes and adding config to customize instance
This commit is contained in:
+68
-23
@@ -4,41 +4,86 @@ Therefore, the app is to be more efficient in the usage of the resources.
|
||||
By for example, unloading or reloading the model.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
import time
|
||||
import gc
|
||||
from typing import Union
|
||||
import multiprocessing
|
||||
import torch
|
||||
import signal
|
||||
|
||||
import scraibe.app.global_var as gv
|
||||
from scraibe.autotranscript import Scraibe
|
||||
from gradio import Warning
|
||||
from scraibe.autotranscript import Scraibe
|
||||
from .stg import GradioTranscriptionInterface
|
||||
|
||||
def init_worker():
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
|
||||
def load_model_thread(model : Union[Scraibe, dict] = None):
|
||||
if model is None:
|
||||
gv.MODEL = Scraibe()
|
||||
elif type(model) is Scraibe:
|
||||
gv.MODEL = model
|
||||
elif type(model) is dict:
|
||||
gv.MODEL = Scraibe(**model)
|
||||
def clear_queue(queue):
|
||||
while not queue.empty():
|
||||
try:
|
||||
queue.get_nowait()
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
def model_worker(model_params : Union[Scraibe, dict],
|
||||
request_queue,
|
||||
last_active_time,
|
||||
response_queue,
|
||||
loaded_event,
|
||||
running_event,
|
||||
*args, **kwargs):
|
||||
|
||||
loaded_event.set()
|
||||
|
||||
if model_params is None:
|
||||
_model = Scraibe()
|
||||
elif type(model_params) is Scraibe:
|
||||
_model = model_params
|
||||
elif type(model_params) is dict:
|
||||
_model = Scraibe(**model_params)
|
||||
else:
|
||||
raise TypeError("model must be of type Scraibe, or dict")
|
||||
|
||||
gv.LAST_USED = time.time()
|
||||
|
||||
# Create a thread to monitor user activity
|
||||
def delete_unused_model():
|
||||
model = GradioTranscriptionInterface(_model)
|
||||
|
||||
while True:
|
||||
|
||||
_unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None)
|
||||
req = request_queue.get()
|
||||
|
||||
if _unload_porperty:
|
||||
if req == "STOP":
|
||||
|
||||
del gv.MODEL
|
||||
gv.MODEL = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
break
|
||||
elif type(req) is dict:
|
||||
runner = model.get_task_from_str(req.pop("task"))
|
||||
running_event.set()
|
||||
transcription = runner(**req)
|
||||
running_event.clear()
|
||||
response_queue.put(transcription)
|
||||
last_active_time.value = time.time()
|
||||
else:
|
||||
raise TypeError("request must be of type dict")
|
||||
|
||||
gv.MODEL_THREAD.join()
|
||||
|
||||
time.sleep(int(gv.TIMEOUT/5))
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
clear_queue(request_queue)
|
||||
clear_queue(response_queue)
|
||||
loaded_event.clear()
|
||||
|
||||
def start_model_worker(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args, **kwargs):
|
||||
context = multiprocessing.get_context('spawn')
|
||||
model_process = context.Process(target=model_worker, args=(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args), kwargs=kwargs)
|
||||
model_process.start()
|
||||
return model_process
|
||||
|
||||
def timer_thread(request_queue, last_active_time,loaded_event, running_event, timeout=30):
|
||||
while True:
|
||||
time.sleep(timeout)
|
||||
|
||||
if time.time() - last_active_time.value > timeout and loaded_event.is_set() and not running_event.is_set():
|
||||
print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True)
|
||||
request_queue.put("STOP")
|
||||
Warning("Model worker stopped due to inactivity.")
|
||||
Reference in New Issue
Block a user