make gradio working with treads

This commit is contained in:
Jaikinator
2023-11-25 15:17:12 +01:00
parent bbb2c848e3
commit 93e5ce15f9
6 changed files with 101 additions and 21 deletions
+31 -8
View File
@@ -9,7 +9,8 @@ It makes adds gradio interactions to the scraibe class in the back.
import json
import gradio as gr
from tqdm import tqdm
from scraibe import Scraibe
import scraibe.app.global_var as gv
class GradioTranscriptionInterface:
@@ -17,14 +18,14 @@ class GradioTranscriptionInterface:
Interface handling the interaction between Gradio UI and the Audio Transcription system.
"""
def __init__(self, model: Scraibe):
def __init__(self):
"""
Initializes the GradioTranscriptionInterface with a transcription model.
Args:
model (Scraibe): Model responsible for audio transcription tasks.
"""
self.model = model
self.model = gv.MODEL
def auto_transcribe(self, source,
num_speakers : int,
@@ -37,6 +38,8 @@ class GradioTranscriptionInterface:
tuple: Transcribed text (str), JSON output (dict)
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
"language": language if language != "None" else None,
@@ -46,9 +49,11 @@ class GradioTranscriptionInterface:
try:
result = self.model.autotranscribe(source, **kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return str(result), result.get_json()
elif isinstance(source, list):
@@ -74,10 +79,14 @@ class GradioTranscriptionInterface:
out_dict[source_names[i]] = r
else:
out_dict[source_names[i]] = r.get_dict()
gv.TRANSCRIBE_ACTIVE.clear()
return out, json.dumps(out_dict, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
@@ -88,14 +97,17 @@ class GradioTranscriptionInterface:
Returns:
str: Transcribed text.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
"language": language if language != "None" else None,
"task": 'translate' if translation == "Yes" else None
}
if isinstance(source, str):
result = self.model.transcribe(source, **kwargs)
gv.TRANSCRIBE_ACTIVE.clear()
return str(result)
elif isinstance(source, list):
@@ -111,9 +123,12 @@ class GradioTranscriptionInterface:
out += str(res)
out += "\n\n"
gv.TRANSCRIBE_ACTIVE.clear()
return out
else:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Please provide a valid audio file.")
def perform_diarisation(self, source, num_speakers):
@@ -123,6 +138,9 @@ class GradioTranscriptionInterface:
Returns:
str: JSON output of diarisation result.
"""
gv.TRANSCRIBE_ACTIVE.set()
kwargs = {
"num_speakers": num_speakers if num_speakers != 0 else None,
}
@@ -131,9 +149,10 @@ class GradioTranscriptionInterface:
try:
result = self.model.diarization(source, **kwargs)
except ValueError:
gv.TRANSCRIBE_ACTIVE.clear()
raise gr.Error("Couldn't detect any speech in the provided audio. \
Please try again!")
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(result, indent=2)
elif isinstance(source, list):
source_names = [s.split("/")[-1] for s in source]
@@ -142,6 +161,7 @@ class GradioTranscriptionInterface:
try:
res = self.model.diarization(s, **kwargs)
except ValueError:
res = f"NO DIARISATION FOUND FOR {s}"
gr.Warning(f"Couldn't detect any speech in {s} will skip this file.")
result.append(res)
@@ -150,8 +170,11 @@ class GradioTranscriptionInterface:
for i, res in enumerate(result):
out[source_names[i]] = res
gv.TRANSCRIBE_ACTIVE.clear()
return json.dumps(out, indent=4)
else:
gv.TRANSCRIBE_ACTIVE.clear()
gr.Error("Please provide a valid audio file.")