Make everything work in processes and adding config to customize instance

This commit is contained in:
Jaikinator
2023-12-07 16:22:52 +01:00
parent 32b27442e6
commit 9eb9f5af8d
6 changed files with 241 additions and 96 deletions
+68 -23
View File
@@ -4,41 +4,86 @@ Therefore, the app is to be more efficient in the usage of the resources.
By for example, unloading or reloading the model.
"""
import time
import gc
from typing import Union
import multiprocessing
import torch
import signal
import scraibe.app.global_var as gv
from scraibe.autotranscript import Scraibe
from gradio import Warning
from scraibe.autotranscript import Scraibe
from .stg import GradioTranscriptionInterface
def init_worker():
signal.signal(signal.SIGINT, signal.SIG_IGN)
def load_model_thread(model : Union[Scraibe, dict] = None):
if model is None:
gv.MODEL = Scraibe()
elif type(model) is Scraibe:
gv.MODEL = model
elif type(model) is dict:
gv.MODEL = Scraibe(**model)
def clear_queue(queue):
while not queue.empty():
try:
queue.get_nowait()
except queue.Empty:
continue
def model_worker(model_params : Union[Scraibe, dict],
request_queue,
last_active_time,
response_queue,
loaded_event,
running_event,
*args, **kwargs):
loaded_event.set()
if model_params is None:
_model = Scraibe()
elif type(model_params) is Scraibe:
_model = model_params
elif type(model_params) is dict:
_model = Scraibe(**model_params)
else:
raise TypeError("model must be of type Scraibe, or dict")
gv.LAST_USED = time.time()
# Create a thread to monitor user activity
def delete_unused_model():
model = GradioTranscriptionInterface(_model)
while True:
_unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None)
req = request_queue.get()
if _unload_porperty:
if req == "STOP":
del gv.MODEL
gv.MODEL = None
gc.collect()
torch.cuda.empty_cache()
break
elif type(req) is dict:
runner = model.get_task_from_str(req.pop("task"))
running_event.set()
transcription = runner(**req)
running_event.clear()
response_queue.put(transcription)
last_active_time.value = time.time()
else:
raise TypeError("request must be of type dict")
gv.MODEL_THREAD.join()
time.sleep(int(gv.TIMEOUT/5))
del model
torch.cuda.empty_cache()
gc.collect()
clear_queue(request_queue)
clear_queue(response_queue)
loaded_event.clear()
def start_model_worker(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args, **kwargs):
context = multiprocessing.get_context('spawn')
model_process = context.Process(target=model_worker, args=(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args), kwargs=kwargs)
model_process.start()
return model_process
def timer_thread(request_queue, last_active_time,loaded_event, running_event, timeout=30):
while True:
time.sleep(timeout)
if time.time() - last_active_time.value > timeout and loaded_event.is_set() and not running_event.is_set():
print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True)
request_queue.put("STOP")
Warning("Model worker stopped due to inactivity.")