Make everything work in processes and adding config to customize instance
This commit is contained in:
@@ -0,0 +1,48 @@
|
|||||||
|
launch:
|
||||||
|
# The following are the default values for the launch configuration
|
||||||
|
# for more informations look at https://www.gradio.app/docs/interface
|
||||||
|
server_port: 8080
|
||||||
|
server_name: 0.0.0.0
|
||||||
|
inbrowser: true
|
||||||
|
inline: false
|
||||||
|
max-threads: 40
|
||||||
|
quiet: false
|
||||||
|
auth:
|
||||||
|
enabled: false
|
||||||
|
username: admin
|
||||||
|
password: admin
|
||||||
|
auth_message: "Please enter your credentials"
|
||||||
|
show_error : false
|
||||||
|
favicon_path : null
|
||||||
|
ssl_keyfile : null
|
||||||
|
ssl_certfile : null
|
||||||
|
ssl_keyfile_password : null
|
||||||
|
ssl_verify : false
|
||||||
|
quiet : false
|
||||||
|
show_api : false
|
||||||
|
allowed_paths : null
|
||||||
|
blocked_paths : null
|
||||||
|
root_path : null
|
||||||
|
app_kwargs : null
|
||||||
|
state_session_capacity : 1000
|
||||||
|
share_server_address : null
|
||||||
|
share_server_protocol : null
|
||||||
|
share : false
|
||||||
|
debug : false
|
||||||
|
queue:
|
||||||
|
# The following are the default values for the queue configuration
|
||||||
|
# for more informations look at hhttps://www.gradio.app/docs/interface
|
||||||
|
status_update_rate : 'auto'
|
||||||
|
api_open : null
|
||||||
|
max_size : null
|
||||||
|
concurrency_count : null
|
||||||
|
default_concurrency_limit : 'not_set'
|
||||||
|
layout:
|
||||||
|
header: scraibe/app/header.html
|
||||||
|
footer: null
|
||||||
|
logo: scraibe/app/logo.svg
|
||||||
|
model:
|
||||||
|
whisper_model : null
|
||||||
|
dia_model: null
|
||||||
|
advanced:
|
||||||
|
timeout: 300 #seconds e.g. 5 minutes
|
||||||
@@ -3,16 +3,22 @@ Stores global variables for the app.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Global variable to store the model
|
# Global variable to store the model
|
||||||
from threading import Event
|
import multiprocessing
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
REQUEST_QUEUE = multiprocessing.Queue() # audio file path as string
|
||||||
|
RESPONSE_QUEUE = multiprocessing.Queue() # transcription as string
|
||||||
|
LAST_ACTIVE_TIME = multiprocessing.Value('d', time.time()) # time of last activity
|
||||||
|
LOADED_EVENT = multiprocessing.Event() # model loaded event
|
||||||
|
RUNNING_EVENT = multiprocessing.Event() # model running event
|
||||||
|
|
||||||
MODEL = None
|
MODEL_PARAMS = None # model parameters
|
||||||
MODEL_THREAD_PARAMS = None
|
MODEL_PROCESS = None # model process to handle globally
|
||||||
MODEL_THREAD = None
|
|
||||||
|
|
||||||
# Global variable to track user activity
|
# Global variable to track user activity
|
||||||
LAST_USED = time.time()
|
LAST_USED = time.time()
|
||||||
TIMEOUT = 30 #seconds
|
TIMEOUT = None #seconds
|
||||||
TRANSCRIBE_ACTIVE = Event()
|
|
||||||
|
DEFAULT_APP_CONIFG_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yml")
|
||||||
|
|||||||
+25
-24
@@ -3,12 +3,12 @@ This file contains ervery function that will be called when the user interacts w
|
|||||||
UI like pressing a button or uploading a file.
|
UI like pressing a button or uploading a file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from re import M
|
||||||
import time
|
import time
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import scraibe.app.global_var as gv
|
import scraibe.app.global_var as gv
|
||||||
from scraibe import Transcript
|
from scraibe import Transcript
|
||||||
from scraibe.app.stg import GradioTranscriptionInterface
|
from .multi import start_model_worker
|
||||||
import threading
|
|
||||||
|
|
||||||
def select_task(choice):
|
def select_task(choice):
|
||||||
# tell the app that it is still in use
|
# tell the app that it is still in use
|
||||||
@@ -84,33 +84,37 @@ def run_scraibe(task,
|
|||||||
video1,
|
video1,
|
||||||
video2,
|
video2,
|
||||||
file_in,
|
file_in,
|
||||||
progress = gr.Progress(track_tqdm= True)):
|
progress = gr.Progress(track_tqdm=False)):
|
||||||
|
|
||||||
# get *args which are not None
|
# get *args which are not None
|
||||||
|
if gv.MODEL_PROCESS is None or not gv.MODEL_PROCESS.is_alive():
|
||||||
|
#progress(0.0, desc='Loading model...')
|
||||||
|
gv.MODEL_PROCESS = start_model_worker(gv.MODEL_PARAMS,
|
||||||
|
gv.REQUEST_QUEUE,
|
||||||
|
gv.LAST_ACTIVE_TIME,
|
||||||
|
gv.RESPONSE_QUEUE,
|
||||||
|
gv.LOADED_EVENT,
|
||||||
|
gv.RUNNING_EVENT)
|
||||||
|
|
||||||
if gv.MODEL is None and gv.MODEL_THREAD_PARAMS is not None:
|
# progress(0.1, desc='Starting task...')
|
||||||
progress(0, desc='Model was not loaded to conserve resources. Loading model...')
|
|
||||||
time.sleep(1)
|
|
||||||
gv.MODEL_THREAD = threading.Thread(**gv.MODEL_THREAD_PARAMS)
|
|
||||||
gv.MODEL_THREAD.start()
|
|
||||||
gv.MODEL_THREAD.join()
|
|
||||||
|
|
||||||
pipe = GradioTranscriptionInterface()
|
|
||||||
|
|
||||||
progress(0.1, desc='Starting task...')
|
|
||||||
source = audio1 or audio2 or video1 or video2 or file_in
|
source = audio1 or audio2 or video1 or video2 or file_in
|
||||||
|
|
||||||
if isinstance(source, list):
|
if isinstance(source, list):
|
||||||
source = [s.name for s in source]
|
source = [s.name for s in source]
|
||||||
if len(source) == 1:
|
if len(source) == 1:
|
||||||
source = source[0]
|
source = source[0]
|
||||||
|
|
||||||
|
config = dict(source = source,
|
||||||
|
task = task,
|
||||||
|
num_speakers = num_speakers,
|
||||||
|
translate = translate,
|
||||||
|
language = language)
|
||||||
|
|
||||||
|
gv.REQUEST_QUEUE.put(config)
|
||||||
|
|
||||||
if task == 'Auto Transcribe':
|
if task == 'Auto Transcribe':
|
||||||
|
|
||||||
out_str , out_json = pipe.auto_transcribe(source = source,
|
out_str , out_json = gv.RESPONSE_QUEUE.get()
|
||||||
num_speakers = num_speakers,
|
|
||||||
translation = translate,
|
|
||||||
language = language)
|
|
||||||
|
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
return (gr.update(value = out_str, visible = True),
|
return (gr.update(value = out_str, visible = True),
|
||||||
@@ -125,9 +129,7 @@ def run_scraibe(task,
|
|||||||
|
|
||||||
elif task == 'Transcribe':
|
elif task == 'Transcribe':
|
||||||
|
|
||||||
out = pipe.transcribe(source = source,
|
out = gv.RESPONSE_QUEUE.get()
|
||||||
translation = translate,
|
|
||||||
language = language)
|
|
||||||
|
|
||||||
return (gr.update(value = out, visible = True),
|
return (gr.update(value = out, visible = True),
|
||||||
gr.update(value = None, visible = False),
|
gr.update(value = None, visible = False),
|
||||||
@@ -136,8 +138,7 @@ def run_scraibe(task,
|
|||||||
|
|
||||||
elif task == 'Diarisation':
|
elif task == 'Diarisation':
|
||||||
|
|
||||||
out = pipe.perform_diarisation(source = source,
|
out = gv.RESPONSE_QUEUE.get()
|
||||||
num_speakers = num_speakers)
|
|
||||||
|
|
||||||
return (gr.update(value = None, visible = False),
|
return (gr.update(value = None, visible = False),
|
||||||
gr.update(value = out, visible = True),
|
gr.update(value = out, visible = True),
|
||||||
|
|||||||
+68
-23
@@ -4,41 +4,86 @@ Therefore, the app is to be more efficient in the usage of the resources.
|
|||||||
By for example, unloading or reloading the model.
|
By for example, unloading or reloading the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import gc
|
import gc
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
import multiprocessing
|
||||||
import torch
|
import torch
|
||||||
|
import signal
|
||||||
|
|
||||||
import scraibe.app.global_var as gv
|
from gradio import Warning
|
||||||
from scraibe.autotranscript import Scraibe
|
from scraibe.autotranscript import Scraibe
|
||||||
|
from .stg import GradioTranscriptionInterface
|
||||||
|
|
||||||
|
def init_worker():
|
||||||
|
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||||
|
|
||||||
|
|
||||||
def load_model_thread(model : Union[Scraibe, dict] = None):
|
def clear_queue(queue):
|
||||||
if model is None:
|
while not queue.empty():
|
||||||
gv.MODEL = Scraibe()
|
try:
|
||||||
elif type(model) is Scraibe:
|
queue.get_nowait()
|
||||||
gv.MODEL = model
|
except queue.Empty:
|
||||||
elif type(model) is dict:
|
continue
|
||||||
gv.MODEL = Scraibe(**model)
|
|
||||||
|
def model_worker(model_params : Union[Scraibe, dict],
|
||||||
|
request_queue,
|
||||||
|
last_active_time,
|
||||||
|
response_queue,
|
||||||
|
loaded_event,
|
||||||
|
running_event,
|
||||||
|
*args, **kwargs):
|
||||||
|
|
||||||
|
loaded_event.set()
|
||||||
|
|
||||||
|
if model_params is None:
|
||||||
|
_model = Scraibe()
|
||||||
|
elif type(model_params) is Scraibe:
|
||||||
|
_model = model_params
|
||||||
|
elif type(model_params) is dict:
|
||||||
|
_model = Scraibe(**model_params)
|
||||||
else:
|
else:
|
||||||
raise TypeError("model must be of type Scraibe, or dict")
|
raise TypeError("model must be of type Scraibe, or dict")
|
||||||
|
|
||||||
gv.LAST_USED = time.time()
|
model = GradioTranscriptionInterface(_model)
|
||||||
|
|
||||||
# Create a thread to monitor user activity
|
|
||||||
def delete_unused_model():
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
_unload_porperty = (not gv.TRANSCRIBE_ACTIVE.is_set() and (time.time() - gv.LAST_USED > gv.TIMEOUT) and gv.MODEL is not None)
|
req = request_queue.get()
|
||||||
|
|
||||||
if _unload_porperty:
|
if req == "STOP":
|
||||||
|
|
||||||
del gv.MODEL
|
break
|
||||||
gv.MODEL = None
|
elif type(req) is dict:
|
||||||
|
runner = model.get_task_from_str(req.pop("task"))
|
||||||
gc.collect()
|
running_event.set()
|
||||||
torch.cuda.empty_cache()
|
transcription = runner(**req)
|
||||||
|
running_event.clear()
|
||||||
|
response_queue.put(transcription)
|
||||||
|
last_active_time.value = time.time()
|
||||||
|
else:
|
||||||
|
raise TypeError("request must be of type dict")
|
||||||
|
|
||||||
gv.MODEL_THREAD.join()
|
del model
|
||||||
|
torch.cuda.empty_cache()
|
||||||
time.sleep(int(gv.TIMEOUT/5))
|
gc.collect()
|
||||||
|
clear_queue(request_queue)
|
||||||
|
clear_queue(response_queue)
|
||||||
|
loaded_event.clear()
|
||||||
|
|
||||||
|
def start_model_worker(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args, **kwargs):
|
||||||
|
context = multiprocessing.get_context('spawn')
|
||||||
|
model_process = context.Process(target=model_worker, args=(model_params, request_queue, last_active_time, response_queue,loaded_event, running_event, *args), kwargs=kwargs)
|
||||||
|
model_process.start()
|
||||||
|
return model_process
|
||||||
|
|
||||||
|
def timer_thread(request_queue, last_active_time,loaded_event, running_event, timeout=30):
|
||||||
|
while True:
|
||||||
|
time.sleep(timeout)
|
||||||
|
|
||||||
|
if time.time() - last_active_time.value > timeout and loaded_event.is_set() and not running_event.is_set():
|
||||||
|
print(f"No activity for the last {timeout} seconds. Stopping the model worker.", flush=True)
|
||||||
|
request_queue.put("STOP")
|
||||||
|
Warning("Model worker stopped due to inactivity.")
|
||||||
+45
-42
@@ -18,19 +18,19 @@ class GradioTranscriptionInterface:
|
|||||||
Interface handling the interaction between Gradio UI and the Audio Transcription system.
|
Interface handling the interaction between Gradio UI and the Audio Transcription system.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, model):
|
||||||
"""
|
"""
|
||||||
Initializes the GradioTranscriptionInterface with a transcription model.
|
Initializes the GradioTranscriptionInterface with a transcription model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (Scraibe): Model responsible for audio transcription tasks.
|
model (Scraibe): Model responsible for audio transcription tasks.
|
||||||
"""
|
"""
|
||||||
self.model = gv.MODEL
|
self.model = model
|
||||||
|
|
||||||
def auto_transcribe(self, source,
|
def autotranscribe(self, source,
|
||||||
num_speakers : int,
|
num_speakers : int,
|
||||||
translation : bool,
|
translate : bool,
|
||||||
language : str):
|
language : str,*args ,**kwargs):
|
||||||
"""
|
"""
|
||||||
Shortcut method for the Scraibe task.
|
Shortcut method for the Scraibe task.
|
||||||
|
|
||||||
@@ -38,22 +38,18 @@ class GradioTranscriptionInterface:
|
|||||||
tuple: Transcribed text (str), JSON output (dict)
|
tuple: Transcribed text (str), JSON output (dict)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.set()
|
_kwargs = {
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||||
"language": language if language != "None" else None,
|
"language": language if language != "None" else None,
|
||||||
"task": 'translate' if translation else None
|
"task": 'translate' if translate else None
|
||||||
}
|
}
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
try:
|
try:
|
||||||
result = self.model.autotranscribe(source, **kwargs)
|
result = self.model.autotranscribe(source, **_kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||||
Please try again!")
|
Please try again!")
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
return str(result), result.get_json()
|
return str(result), result.get_json()
|
||||||
|
|
||||||
elif isinstance(source, list):
|
elif isinstance(source, list):
|
||||||
@@ -61,7 +57,7 @@ class GradioTranscriptionInterface:
|
|||||||
result = []
|
result = []
|
||||||
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
||||||
try:
|
try:
|
||||||
res = self.model.autotranscribe(s, **kwargs)
|
res = self.model.autotranscribe(s, **_kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
_name = s.split("/")[-1]
|
_name = s.split("/")[-1]
|
||||||
res = f"NO TRANSCRIPT FOUND FOR {_name}"
|
res = f"NO TRANSCRIPT FOUND FOR {_name}"
|
||||||
@@ -79,42 +75,36 @@ class GradioTranscriptionInterface:
|
|||||||
out_dict[source_names[i]] = r
|
out_dict[source_names[i]] = r
|
||||||
else:
|
else:
|
||||||
out_dict[source_names[i]] = r.get_dict()
|
out_dict[source_names[i]] = r.get_dict()
|
||||||
|
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
|
|
||||||
return out, json.dumps(out_dict, indent=4)
|
return out, json.dumps(out_dict, indent=4)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
raise gr.Error("Please provide a valid audio file.")
|
raise gr.Error("Please provide a valid audio file.")
|
||||||
|
|
||||||
|
|
||||||
def transcribe(self, source, translation, language):
|
def transcribe(self, source, translate, language,*args ,**kwargs):
|
||||||
"""
|
"""
|
||||||
Shortcut method for the Transcribe task.
|
Shortcut method for the Transcribe task.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Transcribed text.
|
str: Transcribed text.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.set()
|
_kwargs = {
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"language": language if language != "None" else None,
|
"language": language if language != "None" else None,
|
||||||
"task": 'translate' if translation == "Yes" else None
|
"task": 'translate' if translate == "Yes" else None
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
result = self.model.transcribe(source, **kwargs)
|
result = self.model.transcribe(source, **_kwargs)
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
return str(result)
|
return str(result)
|
||||||
|
|
||||||
elif isinstance(source, list):
|
elif isinstance(source, list):
|
||||||
source_names = [s.split("/")[-1] for s in source]
|
source_names = [s.split("/")[-1] for s in source]
|
||||||
result = []
|
result = []
|
||||||
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
|
||||||
res = self.model.transcribe(s, **kwargs)
|
res = self.model.transcribe(s, **_kwargs)
|
||||||
result.append(res)
|
result.append(res)
|
||||||
|
|
||||||
out = ''
|
out = ''
|
||||||
@@ -123,15 +113,12 @@ class GradioTranscriptionInterface:
|
|||||||
out += str(res)
|
out += str(res)
|
||||||
out += "\n\n"
|
out += "\n\n"
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
else:
|
else:
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
raise gr.Error("Please provide a valid audio file.")
|
raise gr.Error("Please provide a valid audio file.")
|
||||||
|
|
||||||
def perform_diarisation(self, source, num_speakers):
|
def diarisation(self, source, num_speakers, *args ,**kwargs):
|
||||||
"""
|
"""
|
||||||
Shortcut method for the Diarisation task.
|
Shortcut method for the Diarisation task.
|
||||||
|
|
||||||
@@ -139,27 +126,24 @@ class GradioTranscriptionInterface:
|
|||||||
str: JSON output of diarisation result.
|
str: JSON output of diarisation result.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.set()
|
_kwargs = {
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"num_speakers": num_speakers if num_speakers != 0 else None,
|
"num_speakers": num_speakers if num_speakers != 0 else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
try:
|
try:
|
||||||
result = self.model.diarization(source, **kwargs)
|
result = self.model.diarization(source, **_kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
raise gr.Error("Couldn't detect any speech in the provided audio. \
|
||||||
Please try again!")
|
Please try again!")
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
return json.dumps(result, indent=2)
|
return json.dumps(result, indent=2)
|
||||||
elif isinstance(source, list):
|
elif isinstance(source, list):
|
||||||
source_names = [s.split("/")[-1] for s in source]
|
source_names = [s.split("/")[-1] for s in source]
|
||||||
result = []
|
result = []
|
||||||
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
|
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
|
||||||
try:
|
try:
|
||||||
res = self.model.diarization(s, **kwargs)
|
res = self.model.diarization(s, **_kwargs)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|
||||||
res = f"NO DIARISATION FOUND FOR {s}"
|
res = f"NO DIARISATION FOUND FOR {s}"
|
||||||
@@ -171,10 +155,29 @@ class GradioTranscriptionInterface:
|
|||||||
for i, res in enumerate(result):
|
for i, res in enumerate(result):
|
||||||
out[source_names[i]] = res
|
out[source_names[i]] = res
|
||||||
|
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
|
||||||
|
|
||||||
return json.dumps(out, indent=4)
|
return json.dumps(out, indent=4)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
gv.TRANSCRIBE_ACTIVE.clear()
|
gr.Error("Please provide a valid audio file.")
|
||||||
gr.Error("Please provide a valid audio file.")
|
|
||||||
|
def get_task_from_str(self, task):
|
||||||
|
"""
|
||||||
|
Returns the coresponing task function based on the task string.
|
||||||
|
|
||||||
|
params:
|
||||||
|
task (str): Task string. Can be one of the following:
|
||||||
|
- 'Auto Transcribe'
|
||||||
|
- 'Transcribe'
|
||||||
|
- 'Diarisation'
|
||||||
|
"""
|
||||||
|
|
||||||
|
if task == 'Auto Transcribe':
|
||||||
|
return self.autotranscribe
|
||||||
|
elif task == 'Transcribe':
|
||||||
|
return self.transcribe
|
||||||
|
elif task == 'Diarisation':
|
||||||
|
return self.diarisation
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid task string.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
import scraibe.app.global_var as gv
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
def load_config(original_config_path = gv.DEFAULT_APP_CONIFG_PATH, override_yaml_path=None, **kwargs):
|
||||||
|
|
||||||
|
|
||||||
|
# Load the original configuration
|
||||||
|
with open(original_config_path, 'r') as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
# Override with another YAML file if provided
|
||||||
|
if override_yaml_path:
|
||||||
|
with open(override_yaml_path, 'r') as file:
|
||||||
|
override_config = yaml.safe_load(file)
|
||||||
|
apply_overrides(config, override_config)
|
||||||
|
|
||||||
|
# Apply overrides from kwargs
|
||||||
|
apply_overrides(config, kwargs)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def apply_overrides(orig_dict, override_dict):
|
||||||
|
""" Recursively apply overrides to the configuration. """
|
||||||
|
for key, value in override_dict.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# If the value is a dict, apply recursively
|
||||||
|
apply_overrides(orig_dict.get(key, {}), value)
|
||||||
|
else:
|
||||||
|
# If the value is not a dict, search for the key and update
|
||||||
|
if update_nested_key(orig_dict, key, value):
|
||||||
|
continue # Key was found and updated
|
||||||
|
orig_dict[key] = value # Key not found, update at this level
|
||||||
|
|
||||||
|
def update_nested_key(d, key, value):
|
||||||
|
""" Recursively search and update the key in nested dictionary. """
|
||||||
|
if key in d:
|
||||||
|
d[key] = value
|
||||||
|
return True
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, dict) and update_nested_key(v, key, value):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
Reference in New Issue
Block a user