Make everything work in processes and add a config to customize the instance

This commit is contained in:
Jaikinator
2023-12-07 16:22:52 +01:00
parent 32b27442e6
commit 9eb9f5af8d
6 changed files with 241 additions and 96 deletions
+45 -42
View File
@@ -18,19 +18,19 @@ class GradioTranscriptionInterface:
Interface handling the interaction between Gradio UI and the Audio Transcription system.
"""
def __init__(self):
def __init__(self, model):
"""
Initializes the GradioTranscriptionInterface with a transcription model.
Args:
model (Scraibe): Model responsible for audio transcription tasks.
"""
self.model = gv.MODEL
self.model = model
def auto_transcribe(self, source,
def autotranscribe(self, source,
num_speakers : int,
translation : bool,
language : str):
translate : bool,
language : str,*args ,**kwargs):
"""
Shortcut method for the Scraibe task.
@@ -38,22 +38,18 @@ class GradioTranscriptionInterface:
tuple: Transcribed text (str), JSON output (dict)
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
_kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None,
"task": 'translate' if translation else None
"task": 'translate' if translate else None
}
if isinstance(source, str):
try:
result = self.model.autotranscribe(source, **kwargs)
result = self.model.autotranscribe(source, **_kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return str(result), result.get_json()
elif isinstance(source, list):
@@ -61,7 +57,7 @@ class GradioTranscriptionInterface:
result = []
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
try:
res = self.model.autotranscribe(s, **kwargs)
res = self.model.autotranscribe(s, **_kwargs)
except ValueError:
_name = s.split("/")[-1]
res = f"NO TRANSCRIPT FOUND FOR {_name}"
@@ -79,42 +75,36 @@ class GradioTranscriptionInterface:
out_dict[source_names[i]] = r
else:
out_dict[source_names[i]] = r.get_dict()
gv.TRANSCRIBE_ACTIVE.clear()
return out, json.dumps(out_dict, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
def transcribe(self, source, translation, language):
def transcribe(self, source, translate, language,*args ,**kwargs):
"""
Shortcut method for the Transcribe task.
Returns:
str: Transcribed text.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
_kwargs = {
"language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None
"task": 'translate' if translate == "Yes" else None
}
if isinstance(source, str):
result = self.model.transcribe(source, **kwargs)
gv.TRANSCRIBE_ACTIVE.clear()
result = self.model.transcribe(source, **_kwargs)
return str(result)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Transcribing audio files"):
res = self.model.transcribe(s, **kwargs)
res = self.model.transcribe(s, **_kwargs)
result.append(res)
out = ''
@@ -123,15 +113,12 @@ class GradioTranscriptionInterface:
out += str(res)
out += "\n\n"
gv.TRANSCRIBE_ACTIVE.clear()
return out
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
def perform_diarisation(self, source, num_speakers):
def diarisation(self, source, num_speakers, *args ,**kwargs):
"""
Shortcut method for the Diarisation task.
@@ -139,27 +126,24 @@ class GradioTranscriptionInterface:
str: JSON output of diarisation result.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
_kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
}
if isinstance(source, str):
try:
result = self.model.diarization(source, **kwargs)
result = self.model.diarization(source, **_kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(result, indent=2)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
result = []
for s in tqdm(source, total=len(source),desc = "Performing diarisation"):
try:
res = self.model.diarization(s, **kwargs)
res = self.model.diarization(s, **_kwargs)
except ValueError:
res = f"NO DIARISATION FOUND FOR {s}"
@@ -171,10 +155,29 @@ class GradioTranscriptionInterface:
for i, res in enumerate(result):
out[source_names[i]] = res
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(out, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
gr.Error("Please provide a valid audio file.")
gr.Error("Please provide a valid audio file.")
def get_task_from_str(self, task):
    """
    Returns the corresponding task method based on the task string.

    Args:
        task (str): Task string. Can be one of the following:
            - 'Auto Transcribe'
            - 'Transcribe'
            - 'Diarisation'

    Returns:
        The attribute of this instance that handles the task:
        autotranscribe, transcribe or diarisation.

    Raises:
        ValueError: If the task string is not one of the labels above.
    """
    # Pairs are compared with ``==`` (as an if/elif chain would) so that an
    # unhashable or otherwise odd ``task`` still ends in ValueError.
    task_methods = (
        ('Auto Transcribe', 'autotranscribe'),
        ('Transcribe', 'transcribe'),
        ('Diarisation', 'diarisation'),
    )
    for label, method_name in task_methods:
        if task == label:
            # Resolve the attribute only for the label that matched.
            return getattr(self, method_name)

    valid_labels = ", ".join(repr(label) for label, _ in task_methods)
    raise ValueError(
        f"Invalid task string: {task!r}. Expected one of: {valid_labels}."
    )