diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..1155cba --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +scraibe/*__pycache__ +scraibe/app/*__pycache__ +scraibe/.pyannotetoken +.git +.gitignore +.github diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..18c7986 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +transcibe.py +scraibe/*__pycache__ +scraibe/app/*__pycache__ +scraibe/.pyannotetoken + + diff --git a/scraibe/__init__.py b/scraibe/__init__.py index a3a2b17..233cd4f 100644 --- a/scraibe/__init__.py +++ b/scraibe/__init__.py @@ -7,9 +7,6 @@ from .diarisation import * from .version import get_version as _get_version from .misc import * -from .app.gradio_app import * -from .app.qtfaststart import * - from .cli import * __version__ = _get_version() diff --git a/scraibe/app/__init__.py b/scraibe/app/__init__.py deleted file mode 100644 index dc00e7a..0000000 --- a/scraibe/app/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .qtfaststart import * -from .gradio_app import * \ No newline at end of file diff --git a/scraibe/app/gradio_app.py b/scraibe/app/gradio_app.py deleted file mode 100644 index cf80b7e..0000000 --- a/scraibe/app/gradio_app.py +++ /dev/null @@ -1,441 +0,0 @@ -""" -Gradio Audio Transcription App. --------------------------------- - -This module provides an interface to transcribe audio files using the -Scraibe model. Users can either upload an audio file or record their speech -live for transcription. The application supports multiple languages and provides -options to specify the number of speakers and the language of the audio. - -Attributes: - LANGUAGES (list): A list of supported languages for transcription. - -Usage: - Run this script to start the Gradio web interface for audio transcription. - -""" - -""" -Gradio Audio Transcription App. --------------------------------- - -This module provides an interface to transcribe audio files using the -Scraibe model. Users can either upload an audio file or record their speech -live for transcription. The application supports multiple languages and provides -options to specify the number of speakers and the language of the audio. - -Attributes: - LANGUAGES (list): A list of supported languages for transcription. - -Usage: - Run this script to start the Gradio web interface for audio transcription. - -""" - -import json -import os - -import gradio as gr -from tqdm import tqdm - -from scraibe import Scraibe, Transcript - -theme = gr.themes.Soft( - primary_hue="green", - secondary_hue='orange', - neutral_hue="gray", -) - -LANGUAGES = [ - "Afrikaans", "Arabic", "Armenian", "Azerbaijani", "Belarusian", - "Bosnian", "Bulgarian", "Catalan", "Chinese", "Croatian", - "Czech", "Danish", "Dutch", "English", "Estonian", - "Finnish", "French", "Galician", "German", "Greek", - "Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian", - "Italian", "Japanese", "Kannada", "Kazakh", "Korean", - "Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi", - "Maori", "Nepali", "Norwegian", "Persian", "Polish", - "Portuguese", "Romanian", "Russian", "Serbian", "Slovak", - "Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog", - "Tamil", "Thai", "Turkish", "Ukrainian", "Urdu", - "Vietnamese", "Welsh" -] - -CURRENT_PATH = os.path.dirname(os.path.realpath(__file__)) - -class GradioTranscriptionInterface: - """ - Interface handling the interaction between Gradio UI and the Audio Transcription system. - """ - - def __init__(self, model: Scraibe): - """ - Initializes the GradioTranscriptionInterface with a transcription model. - - Args: - model (Scraibe): Model responsible for audio transcription tasks. - """ - self.model = model - - def auto_transcribe(self, source, - num_speakers : int, - translation : bool, - language : str): - """ - Shortcut method for the Scraibe task. - - Returns: - tuple: Transcribed text (str), JSON output (dict) - """ - - kwargs = { - "num_speakers": num_speakers if num_speakers != 0 else None, - "language": language if language != "None" else None, - "task": 'translate' if translation else None - } - if isinstance(source, str): - try: - result = self.model.autotranscribe(source, **kwargs) - except ValueError: - raise gr.Error("Couldn't detect any speech in the provided audio. \ - Please try again!") - - return str(result), result.get_json() - - elif isinstance(source, list): - source_names = [s.split("/")[-1] for s in source] - result = [] - for s in tqdm(source, total=len(source),desc = "Transcribing audio files"): - try: - res = self.model.autotranscribe(s, **kwargs) - except ValueError: - _name = s.split("/")[-1] - res = f"NO TRANSCRIPT FOUND FOR {_name}" - gr.Warning(f"Couldn't detect any speech in {_name} will skip this file.") - result.append(res) - - out = '' - out_dict = {} - for i, r in enumerate(result): - out += f"TRANSCRIPT FOR {source_names[i]}:\n\n" - out += str(r) - out += "\n\n" - - if isinstance(r, str): - out_dict[source_names[i]] = r - else: - out_dict[source_names[i]] = r.get_dict() - - return out, json.dumps(out_dict, indent=4) - - else: - raise gr.Error("Please provide a valid audio file.") - - - def transcribe(self, source, translation, language): - """ - Shortcut method for the Transcribe task. - - Returns: - str: Transcribed text. - """ - kwargs = { - "language": language if language != "None" else None, - "task": 'translate' if translation == "Yes" else None - } - - if isinstance(source, str): - result = self.model.transcribe(source, **kwargs) - - return str(result) - - elif isinstance(source, list): - source_names = [s.split("/")[-1] for s in source] - result = [] - for s in tqdm(source, total=len(source),desc = "Transcribing audio files"): - res = self.model.transcribe(s, **kwargs) - result.append(res) - - out = '' - for i, res in enumerate(result): - out += f"TRANSCRIPT FOR {source_names[i]}:\n\n" - out += str(res) - out += "\n\n" - - return out - - else: - raise gr.Error("Please provide a valid audio file.") - - def perform_diarisation(self, source, num_speakers): - """ - Shortcut method for the Diarisation task. - - Returns: - str: JSON output of diarisation result. - """ - kwargs = { - "num_speakers": num_speakers if num_speakers != 0 else None, - } - - if isinstance(source, str): - try: - result = self.model.diarization(source, **kwargs) - except ValueError: - raise gr.Error("Couldn't detect any speech in the provided audio. \ - Please try again!") - - return json.dumps(result, indent=2) - elif isinstance(source, list): - source_names = [s.split("/")[-1] for s in source] - result = [] - for s in tqdm(source, total=len(source),desc = "Performing diarisation"): - try: - res = self.model.diarization(s, **kwargs) - except ValueError: - res = f"NO DIARISATION FOUND FOR {s}" - gr.Warning(f"Couldn't detect any speech in {s} will skip this file.") - result.append(res) - - out = {} - - for i, res in enumerate(result): - out[source_names[i]] = res - - return json.dumps(out, indent=4) - - else: - gr.Error("Please provide a valid audio file.") - - -#### -# Gradio Interface -#### - -def gradio_Interface(model : Scraibe = None): - - if model is None: - model = Scraibe() - - pipe = GradioTranscriptionInterface(model) - - def select_task(choice): - if choice == 'Auto Transcribe': - - return (gr.update(visible = True), - gr.update(visible = True), - gr.update(visible = True)) - - - elif choice == 'Transcribe': - - return (gr.update(visible = False), - gr.update(visible = True), - gr.update(visible = True)) - - - elif choice == 'Diarisation': - - return (gr.update(visible = True), - gr.update(visible = False), - gr.update(visible = False)) - - def select_origin(choice): - if choice == "Upload Audio": - - return (gr.update(visible = True), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None)) - - elif choice == "Record Audio": - - return (gr.update(visible = False, value = None), - gr.update(visible = True), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None)) - - elif choice == "Upload Video": - - return (gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = True), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None)) - - elif choice == "Record Video": - - return (gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = True), - gr.update(visible = False, value = None)) - - elif choice == "File or Files": - - return (gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = False, value = None), - gr.update(visible = True)) - - def run_scribe(task, - num_speakers, - translate, - language, - audio1, - audio2, - video1, - video2, - file_in, - progress = gr.Progress(track_tqdm= True)): - # get *args which are not None - progress(0, desc='Starting task...') - source = audio1 or audio2 or video1 or video2 or file_in - - if isinstance(source, list): - source = [s.name for s in source] - if len(source) == 1: - source = source[0] - - if task == 'Auto Transcribe': - - out_str , out_json = pipe.auto_transcribe(source = source, - num_speakers = num_speakers, - translation = translate, - language = language) - - if isinstance(source, str): - return (gr.update(value = out_str, visible = True), - gr.update(value = out_json, visible = True), - gr.update(visible = True), - gr.update(visible = True)) - else: - return (gr.update(value = out_str, visible = True), - gr.update(value = out_json, visible = True), - gr.update(visible = False), - gr.update(visible = False)) - - elif task == 'Transcribe': - - out = pipe.transcribe(source = source, - translation = translate, - language = language) - - return (gr.update(value = out, visible = True), - gr.update(value = None, visible = False), - gr.update(visible = False), - gr.update(visible = False)) - - elif task == 'Diarisation': - - out = pipe.perform_diarisation(source = source, - num_speakers = num_speakers) - - return (gr.update(value = None, visible = False), - gr.update(value = out, visible = True), - gr.update(visible = False), - gr.update(visible = False)) - - def annotate_output(annoation : str, out_json : dict): - # get *args which are not None - - trans = Transcript.from_json(out_json) - trans = trans.annotate(*annoation.split(",")) - - return gr.update(value = str(trans)),gr.update(value = trans.get_json()) - - - with gr.Blocks(theme=theme,title='ScrAIbe: Automatic Audio Transcription') as demo: - - # Define components - hname = os.path.join(CURRENT_PATH, "header.html") - header = open(hname, "r").read() - - # ugly hack to get the logo to work - header = header.replace("/file=logo.svg", f"/file={CURRENT_PATH}/logo.svg" ) - - gr.HTML(header, visible= True, show_label=False) - - with gr.Row(): - - with gr.Column(): - - task = gr.Radio(["Auto Transcribe", "Transcribe", "Diarisation"], label="Task", - value= 'Auto Transcribe') - - num_speakers = gr.Number(value=0, label= "Number of speakers (optional)", - info = "Number of speakers in the audio file. If you don't know,\ - leave it at 0.", visible= True) - - translate = gr.Checkbox(label="Translation", choices=[True, False], value = False, - info="Select 'Yes' to have the output translated into English.", - visible= True) - - language = gr.Dropdown(LANGUAGES, - label="Language (optional)", value = "None", - info="Language of the audio file. If you don't know,\ - leave it at None.", visible= True) - - input = gr.Radio(["Upload Audio", "Record Audio", "Upload Video","Record Video" - ,"File or Files"], label="Input Type", value="Upload Audio") - - audio1 = gr.Audio(source="upload", type="filepath", label="Upload Audio", - interactive= True, visible= True) - audio2 = gr.Audio(source="microphone", label="Record Audio", type="filepath", - interactive= True, visible= False) - video1 = gr.Video(source="upload", type="filepath", label="Upload Video", - interactive= True, visible= False) - video2 = gr.Video(source="webcam", label="Record Video", type="filepath",include_audio= True, - interactive= True, visible= False) - file_in = gr.Files(label="Upload File or Files", interactive= True, visible= False) - - submit = gr.Button() - - with gr.Column(): - - out_txt = gr.Textbox(label="Output", - visible= True, show_copy_button=True) - - out_json = gr.JSON(label="JSON Output", - visible= False, show_copy_button=True) - - annoation = gr.Textbox(label="Name your speaker's", - info= "Please provide a list of the speakers arranged \ - in the order in which they appear in the input. Use comma ',' \ - as a seperator. Be aware that the first name is given \ - to SPEAKER_00 the second to SPEAKER_01 and so on.", - visible= False, interactive= True) - - annotate = gr.Button(value="Annotate", visible= False, interactive= True) - - # Define usage of components - input.change(fn=select_origin, inputs=[input], - outputs=[audio1, audio2, video1, video2, file_in]) - - task.change(fn=select_task, inputs=[task], - outputs=[num_speakers, translate, language]) - - translate.change(fn= lambda x : gr.update(value = x), - inputs=[translate], outputs=[translate]) - num_speakers.change(fn= lambda x : gr.update(value = x), - inputs=[num_speakers], outputs=[num_speakers]) - language.change(fn= lambda x : gr.update(value = x), - inputs=[language], outputs=[language]) - - submit.click(fn = run_scribe, - inputs=[task, num_speakers, translate, language, audio1, - audio2, video1, video2, file_in], - outputs=[out_txt, out_json, annoation, annotate]) - - annotate.click(fn = annotate_output, inputs=[annoation, out_json], - outputs=[out_txt, out_json]) - - return demo - - -if __name__ == "__main__": - - gradio_Interface().queue().launch() \ No newline at end of file diff --git a/scraibe/app/header.html b/scraibe/app/header.html deleted file mode 100644 index 4b12136..0000000 --- a/scraibe/app/header.html +++ /dev/null @@ -1,66 +0,0 @@ - - - - - -
-

ScrAIbe

-
- - - -
-
-
-

- Upload, record, or provide a video with audio for transcription. Our toolkit is designed to transcribe content from multiple languages accurately. The integrated speaker diarisation feature identifies different speakers, ensuring a smooth transcription experience. For optimal results, indicate the number of speakers and the original language of the content. -

-

What would you like to do next?

-
diff --git a/scraibe/app/logo.svg b/scraibe/app/logo.svg deleted file mode 100644 index 54d12d7..0000000 --- a/scraibe/app/logo.svg +++ /dev/null @@ -1,37 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/scraibe/app/qtfaststart.py b/scraibe/app/qtfaststart.py deleted file mode 100644 index e57eb20..0000000 --- a/scraibe/app/qtfaststart.py +++ /dev/null @@ -1,319 +0,0 @@ -""" -This file contains a modified version of qtfaststart by qtfaststart -https://github.com/danielgtaylor/qtfaststart/tree/master - -All credit goes to the original author. -Copyright (C) 2008 - 2013 Daniel G. Taylor -Permission is hereby granted, free of charge, to any person obtaining a copy of this -software and associated documentation files (the "Software"), -to deal in the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the -Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies -or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, -INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. -""" - -import logging -import os -import struct -import collections -import io - -# define error classes -class FastStartException(Exception): - """ - Raised when something bad happens during processing. - """ - pass - -class FastStartSetupError(FastStartException): - """ - Rasised when asked to process a file that does not need processing - """ - pass - -class MalformedFileError(FastStartException): - """ - Raised when the input file is setup in an unexpected way - """ - pass - -class UnsupportedFormatError(FastStartException): - """ - Raised when a movie file is recognized as a format not supported. - """ - pass - -# define constants -CHUNK_SIZE = 8192 - -log = logging.getLogger("qtfaststart") - -# Older versions of Python require this to be defined -if not hasattr(os, 'SEEK_CUR'): - os.SEEK_CUR = 1 - -Atom = collections.namedtuple('Atom', 'name position size') - -def read_atom(datastream): - """ - Read an atom and return a tuple of (size, type) where size is the size - in bytes (including the 8 bytes already read) and type is a "fourcc" - like "ftyp" or "moov". - """ - size, type = struct.unpack(">L4s", datastream.read(8)) - type = type.decode('ascii') - return size, type - - -def _read_atom_ex(datastream): - """ - Read an Atom from datastream - """ - pos = datastream.tell() - atom_size, atom_type = read_atom(datastream) - if atom_size == 1: - atom_size, = struct.unpack(">Q", datastream.read(8)) - return Atom(atom_type, pos, atom_size) - - -def get_index(datastream): - """ - Return an index of top level atoms, their absolute byte-position in the - file and their size in a list: - - index = [ - ("ftyp", 0, 24), - ("moov", 25, 2658), - ("free", 2683, 8), - ... - ] - - The tuple elements will be in the order that they appear in the file. - """ - log.debug("Getting index of top level atoms...") - - index = list(_read_atoms(datastream)) - _ensure_valid_index(index) - - return index - - -def _read_atoms(datastream): - """ - Read atoms until an error occurs - """ - while datastream: - try: - atom = _read_atom_ex(datastream) - log.debug("%s: %s" % (atom.name, atom.size)) - except: - break - - yield atom - - if atom.size == 0: - if atom.name == "mdat": - # Some files may end in mdat with no size set, which generally - # means to seek to the end of the file. We can just stop indexing - # as no more entries will be found! - break - else: - # Weird, but just continue to try to find more atoms - continue - - datastream.seek(atom.position + atom.size) - - -def _ensure_valid_index(index): - """ - Ensure the minimum viable atoms are present in the index. - - Raise FastStartException if not. - """ - top_level_atoms = set([item.name for item in index]) - for key in ["moov", "mdat"]: - if key not in top_level_atoms: - log.error("%s atom not found, is this a valid MOV/MP4 file?" % key) - raise FastStartException() - - -def find_atoms(size, datastream): - """ - Compatibilty interface for _find_atoms_ex - """ - fake_parent = Atom('fake', datastream.tell()-8, size+8) - for atom in _find_atoms_ex(fake_parent, datastream): - yield atom.name - - -def _find_atoms_ex(parent_atom, datastream): - """ - Yield either "stco" or "co64" Atoms from datastream. - datastream will be 8 bytes into the stco or co64 atom when the value - is yielded. - - It is assumed that datastream will be at the end of the atom after - the value has been yielded and processed. - - parent_atom is the parent atom, a 'moov' or other ancestor of CO - atoms in the datastream. - """ - stop = parent_atom.position + parent_atom.size - - while datastream.tell() < stop: - try: - atom = _read_atom_ex(datastream) - except: - log.exception("Error reading next atom!") - raise FastStartException() - - if atom.name in ["trak", "mdia", "minf", "stbl"]: - # Known ancestor atom of stco or co64, search within it! - for res in _find_atoms_ex(atom, datastream): - yield res - elif atom.name in ["stco", "co64"]: - yield atom - else: - # Ignore this atom, seek to the end of it. - datastream.seek(atom.position + atom.size) - - -def process(infilename, limit=float('inf')): - """ - Convert a Quicktime/MP4 file for streaming by moving the metadata to - the front of the file. This method writes a new file. - - If limit is set to something other than zero it will be used as the - number of bytes to write of the atoms following the moov atom. This - is very useful to create a small sample of a file with full headers, - which can then be used in bug reports and such. - """ - if isinstance(infilename, str): - datastream = open(infilename, "rb") - elif isinstance(infilename, bytes): - datastream = io.BytesIO(infilename) - else: - raise TypeError("infilename must be a filename, bytes or file-like object") - # Get the top level atom index - index = get_index(datastream) - - mdat_pos = 999999 - free_size = 0 - - # Make sure moov occurs AFTER mdat, otherwise no need to run! - for atom in index: - # The atoms are guaranteed to exist from get_index above! - if atom.name == "moov": - moov_atom = atom - moov_pos = atom.position - elif atom.name == "mdat": - mdat_pos = atom.position - elif atom.name == "free" and atom.position < mdat_pos: - # This free atom is before the mdat! - free_size += atom.size - log.info("Removing free atom at %d (%d bytes)" % (atom.position, atom.size)) - elif atom.name == "\x00\x00\x00\x00" and atom.position < mdat_pos: - # This is some strange zero atom with incorrect size - free_size += 8 - log.info("Removing strange zero atom at %s (8 bytes)" % atom.position) - - # Offset to shift positions - offset = moov_atom.size - free_size - - if moov_pos < mdat_pos: - # moov appears to be in the proper place, don't shift by moov size - offset -= moov_atom.size - if not free_size: - # No free atoms and moov is correct, we are done! - log.error("This file appears to already be setup for streaming!") - # Stupid hack to retrun the non-processed file: - if isinstance(infilename, str): - return open(infilename, "rb").read() - elif isinstance(infilename, bytes): - return io.BytesIO(infilename).read() - - # Read and fix moov - moov = _patch_moov(datastream, moov_atom, offset) - - log.info("Writing output...") - outfile = b'' - - # Write ftype - for atom in index: - if atom.name == "ftyp": - log.debug("Writing ftyp... (%d bytes)" % atom.size) - datastream.seek(atom.position) - outfile += datastream.read(atom.size) - - # Write moov - _bytes = moov.getvalue() - log.debug("Writing moov... (%d bytes)" % len(_bytes)) - outfile += _bytes - - # Write the rest - atoms = [item for item in index if item.name not in ["ftyp", "moov", "free"]] - for atom in atoms: - log.debug("Writing %s... (%d bytes)" % (atom.name, atom.size)) - datastream.seek(atom.position) - - # for compatability, allow '0' to mean no limit - cur_limit = limit or float('inf') - cur_limit = min(cur_limit, atom.size) - - for chunk in get_chunks(datastream, CHUNK_SIZE, cur_limit): - outfile += chunk - - return outfile - - -def _patch_moov(datastream, atom, offset): - datastream.seek(atom.position) - moov = io.BytesIO(datastream.read(atom.size)) - - # reload the atom from the fixed stream - atom = _read_atom_ex(moov) - - for atom in _find_atoms_ex(atom, moov): - # Read either 32-bit or 64-bit offsets - ctype, csize = dict( - stco=('L', 4), - co64=('Q', 8), - )[atom.name] - - # Get number of entries - version, entry_count = struct.unpack(">2L", moov.read(8)) - - log.info("Patching %s with %d entries" % (atom.name, entry_count)) - - entries_pos = moov.tell() - - struct_fmt = ">%(entry_count)s%(ctype)s" % vars() - - # Read entries - entries = struct.unpack(struct_fmt, moov.read(csize * entry_count)) - - # Patch and write entries - offset_entries = [entry + offset for entry in entries] - moov.seek(entries_pos) - moov.write(struct.pack(struct_fmt, *offset_entries)) - return moov - -def get_chunks(stream, chunk_size, limit): - remaining = limit - while remaining: - chunk = stream.read(min(remaining, chunk_size)) - if not chunk: - return - remaining -= len(chunk) - yield chunk diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index b3545e4..2664e3f 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -75,6 +75,11 @@ class Scraibe: Path to pyannote diarization model or model itself. **kwargs: Additional keyword arguments for whisper and pyannote diarization models. + e.g.: + + - verbose: If True, the class will print additional information. + - save_kwargs: If True, the keyword arguments will be saved + for autotranscribe. So you can unload the class and reload it again. """ @@ -98,6 +103,15 @@ class Scraibe: else: self.verbose = False + # Save kwargs for autotranscribe if you want to unload the class and load it again. + if kwargs.get('save_setup'): + self.params = dict(whisper_model = whisper_model, + dia_model = dia_model, + **kwargs) + else: + self.params = {} + + def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], remove_original : bool = False, **kwargs) -> Transcript: diff --git a/scraibe/cli.py b/scraibe/cli.py index b05da92..7cc7b1d 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -9,7 +9,8 @@ from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json from .autotranscript import Scraibe -from .app.gradio_app import gradio_Interface +from .misc import ParseKwargs + from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from torch.cuda import is_available @@ -41,13 +42,15 @@ def cli(): help="List of audio files to transcribe.") group.add_argument('--start-server', action='store_true', - help='Start the Gradio app.') + help='Start the Gradio app.' \ + 'If set, all other arguments are ignored' \ + 'besides --server-config or --server-kwargs.') - parser.add_argument("--port", type=int, default= None, - help="Port to run the Gradio app on. Defaults to 7860.") + parser.add_argument("--server-config", type=str, default= None, + help="Path to the configy.yml file.") - parser.add_argument("--server-name", type=str, default= None, - help="Name of the Gradio app. If empty 127.0.0.1 or 0.0.0.0 will be used.") + parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, + help='Keyword arguments for the Gradio app.') parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") @@ -66,7 +69,8 @@ def cli(): help="Device to use for PyTorch inference.") parser.add_argument("--num-threads", type=int, default=0, - help="Number of threads used by torch for CPU inference; overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") + help="Number of threads used by torch for CPU inference; '\ + 'overrides MKL_NUM_THREADS/OMP_NUM_THREADS.") parser.add_argument("--output-directory", "-o", type=str, default=".", help="Directory to save the transcription outputs.") @@ -113,55 +117,70 @@ def cli(): if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") - model = Scraibe(**class_kwargs) - - - if arg_dict["audio_files"]: - audio_files = arg_dict.pop("audio_files") + if not start_server: - if task == "autotranscribe" or task == "autotranscribe+translate": - for audio in audio_files: - if task == "autotranscribe+translate": - task = "translate" - else: - task = "transcribe" - - out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - print(f'Saving {basename}.{out_format} to {out_folder}') - out.save(os.path.join(out_folder, f"{basename}.{out_format}")) - - elif task == "diarization": - for audio in audio_files: - if arg_dict.pop("verbose_output"): - print(f"Verbose not implemented for diarization.") - - out = model.diarization(audio) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") - - print(f'Saving {basename}.{out_format} to {out_folder}') - - with open(path, "w") as f: - json.dump(json.dumps(out, indent= 1), f) + model = Scraibe(**class_kwargs) - elif task == "transcribe" or task == "translate": + if arg_dict["audio_files"]: + audio_files = arg_dict.pop("audio_files") - for audio in audio_files: - - out = model.transcribe(audio, task = task, - language= arg_dict.pop("language"), - verbose = arg_dict.pop("verbose_output")) - basename = audio.split("/")[-1].split(".")[0] - path = os.path.join(out_folder, f"{basename}.{out_format}") - with open(path, "w") as f: - f.write(out) + if task == "autotranscribe" or task == "autotranscribe+translate": + for audio in audio_files: + if task == "autotranscribe+translate": + task = "translate" + else: + task = "transcribe" + + out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + print(f'Saving {basename}.{out_format} to {out_folder}') + out.save(os.path.join(out_folder, f"{basename}.{out_format}")) + + elif task == "diarization": + for audio in audio_files: + if arg_dict.pop("verbose_output"): + print(f"Verbose not implemented for diarization.") + + out = model.diarization(audio) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + + print(f'Saving {basename}.{out_format} to {out_folder}') + + with open(path, "w") as f: + json.dump(json.dumps(out, indent= 1), f) + + elif task == "transcribe" or task == "translate": + for audio in audio_files: - if start_server: # unfinished code + out = model.transcribe(audio, task = task, + language= arg_dict.pop("language"), + verbose = arg_dict.pop("verbose_output")) + basename = audio.split("/")[-1].split(".")[0] + path = os.path.join(out_folder, f"{basename}.{out_format}") + with open(path, "w") as f: + f.write(out) + + + else: # unfinished code + raise NotImplementedError("Currently not Working") + import subprocess + import sys - gradio_Interface(model).queue().launch(server_port=args.port, server_name=args.server_name) + execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py") + config = arg_dict.pop("server_config") + server_kwargs = arg_dict.pop("server_kwargs") + + if not config: + subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"]) + elif not server_kwargs: + subprocess.run([sys.executable, execute_path, f"--server-config={config}"]) + elif not config and not server_kwargs: + subprocess.run([sys.executable, execute_path]) + else: + subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) if __name__ == "__main__": cli() \ No newline at end of file diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index 1a33817..0f0e14a 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -27,7 +27,9 @@ Usage: diarisation_output = model.diarization("path/to/audiofile.wav") """ +import warnings import os +import yaml from pathlib import Path from typing import TypeVar, Union @@ -215,7 +217,42 @@ class Diariser: if not os.path.exists(model) and use_auth_token is None: use_auth_token = cls._get_token() + + elif os.path.exists(model) and not use_auth_token: + # check if model can be found locally nearby the config file + with open(model, 'r') as file: + config = yaml.safe_load(file) + + path_to_model = config['pipeline']['params']['segmentation'] + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. "\ + "Trying to find it nearby the config file.") + + pwd = model.split("/")[:-1] + pwd = "/".join(pwd) + + path_to_model = os.path.join(pwd, "pytorch_model.bin") + + if not os.path.exists(path_to_model): + warnings.warn(f"Model not found at {path_to_model}. \ + 'Trying to find it nearby .bin files instead.") + # list elementes with the ending .bin + bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + if len(bin_files) == 1: + path_to_model = os.path.join(pwd, bin_files[0]) + else: + warnings.warn("Found more than one .bin file. "\ + "or none. Please specify the path to the model " \ + "or setup a huggingface token.") + + warnings.warn(f"Found model at {path_to_model} overwriting config file.") + + config['pipeline']['params']['segmentation'] = path_to_model + + with open(model, 'w') as file: + yaml.dump(config, file) + _model = Pipeline.from_pretrained(model, use_auth_token = use_auth_token, cache_dir = cache_dir, diff --git a/scraibe/misc.py b/scraibe/misc.py index c912478..992e40c 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -1,6 +1,7 @@ import os import yaml from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR +from argparse import Action CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -40,3 +41,17 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> with open(file_path, "w") as stream: yaml.dump(yml, stream) + +class ParseKwargs(Action): + """ + Custom argparse action to parse keyword arguments. + """ + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for value in values: + key, value = value.split('=') + try: + value = eval(value) + except: + pass + getattr(namespace, self.dest)[key] = value \ No newline at end of file diff --git a/setup.py b/setup.py index 64d30b9..1e2c641 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ -from calendar import c import pkg_resources import os from setuptools import setup, find_packages @@ -21,6 +20,8 @@ with open(verfile, "r") as fp: build_version = "SCRAIBE_BUILD" in os.environ +version["ISRELEASED"] = True if "ISRELEASED" in os.environ else False + if __name__ == "__main__": setup( @@ -53,7 +54,7 @@ if __name__ == "__main__": keywords = ['transcription', 'speech recognition', 'whisper', 'pyannote', 'audio', 'ScrAIbe', 'scraibe', 'speech-to-text', 'speech-to-text transcription', 'speech-to-text recognition', 'voice-to-speech'], - package_data={'scraibe.app' : ["*.html", "*.svg"]}, + package_data={'scraibe.app' : ["*.html", "*.svg","*.yml"]}, entry_points={'console_scripts': ['scraibe = scraibe.cli:cli']}