diff --git a/README.md b/README.md index 999dba3..8ffe9d1 100644 --- a/README.md +++ b/README.md @@ -1 +1,47 @@ -# transcriptor \ No newline at end of file + +# `AutoTranscript`: Fully Automated Transcription using AI + +`AutoTranscript` is a [PyTorch](https://pytorch.org/)-based interface that enables fully automated transcription using AI models, including speaker diarization models: + +- [whisper](https://github.com/openai/whisper): a general-purpose speech recognition model +- [pyannote-audio](https://github.com/pyannote/pyannote-audio): an open-source toolkit for speaker diarization + +`AutoTranscript` can be used as a command-line interface, a web server, or a Python API. + +## Setup: +This project was developed with Python 3.9 and [PyTorch](https://pytorch.org/) version 1.11.0. + +The following command will pull and install the latest commit from this repository, along with its Python dependencies. + + pip install git+https://github.com/JSchmie/autotranscript.git + +## Example Python usage + +```python +from autotranscript import AutoTranscribe + +model = AutoTranscribe() + +text = model.transcribe("audio.wav") + +print(f"Transcription: \n{text}") + +``` + +## Command-line usage + +If you do not want to control the process from Python, you can also use the command line: + + autotranscript audio.wav + +Run the following to view all available options: + + autotranscript -h + + +## License + +## Citation + + + diff --git a/app.py b/app.py new file mode 100644 index 0000000..3645d79 --- /dev/null +++ b/app.py @@ -0,0 +1,101 @@ +from dash import Dash, dcc, html, dash_table, Input, Output, State, callback + +import base64 +from autotranscript.app.qtfaststart import process +from autotranscript import AutoTranscribe +import io +import subprocess as sp +import numpy as np +from autotranscript.audio import SAMPLE_RATE + +# Setup auto-transcript +autot = AutoTranscribe() # whisper_model="tiny", whisper_kwargs={"local" : False} + +# Setup FFmpeg +PROBLEMATIC_FILE_TYPES : 
# File extensions whose container metadata may sit at the end of the file;
# such uploads are passed through qtfaststart before being piped to ffmpeg.
PROBLEMATIC_FILE_TYPES: tuple = ("mov", "mp4", "m4a", "3gp", "3g2", "mj2")


# Setup Dash
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = Dash(__name__, external_stylesheets=external_stylesheets)

app.layout = html.Div([
    dcc.Upload(
        id='upload-data',
        children=html.Div([
            'Drag and Drop or ',
            html.A('Select Files')
        ]),
        style={
            'width': '100%',
            'height': '60px',
            'lineHeight': '60px',
            'borderWidth': '1px',
            'borderStyle': 'dashed',
            'borderRadius': '5px',
            'textAlign': 'center',
            'margin': '10px'
        },
        # Allow multiple files to be uploaded
        multiple=True
    ),
    html.Div(id='output-data-upload'),
])


def parse_contents(contents, filename, date):
    """Decode one uploaded file, convert it to mono PCM and transcribe it.

    Args:
        contents: Upload payload of the form ``"<header>,<base64 data>"``.
        filename: Name of the uploaded file; its extension decides whether
            the qtfaststart pre-processing step is applied.
        date: Last-modified value supplied by the upload component (unused).

    Returns:
        An ``html.Div`` holding the file name and the transcript, or an
        error message when ffmpeg could not decode the upload.
    """
    # Split only at the first comma so a comma inside the header cannot
    # break the unpacking; the base64 payload itself never contains one.
    _content_type, content_string = contents.split(',', 1)

    file = base64.b64decode(content_string)

    # Compare case-insensitively so e.g. "clip.MP4" is handled like "clip.mp4".
    if filename.lower().endswith(PROBLEMATIC_FILE_TYPES):
        # mp4 and similar files need to be processed with qtfaststart since
        # their metadata is at the end of the file and ffmpeg, reading from
        # a pipe, needs it at the beginning.
        file = process(file)

    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", 'pipe:',
        "-f", "s16le",
        '-hide_banner',
        '-loglevel', 'error',
        "-c", "copy",
        "-vn",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(SAMPLE_RATE),
        "-"
    ]

    proc = sp.Popen(cmd, stdout=sp.PIPE, stdin=sp.PIPE)
    out = proc.communicate(input=file)[0]

    if proc.returncode != 0:
        # ffmpeg already printed the reason to stderr (loglevel "error");
        # do not feed empty or partial audio to the model.
        return html.Div([
            html.H5(f"File Name: {filename}"),
            html.P("Audio decoding failed; the file could not be read by ffmpeg."),
        ])

    out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
    # Pair the waveform with its sample rate. dtype=object is required:
    # NumPy >= 1.24 raises ValueError for implicit ragged arrays.
    # NOTE(review): presumably autot.transcribe expects this (waveform, rate)
    # pair - confirm against AutoTranscribe.transcribe.
    out = np.array([out, SAMPLE_RATE], dtype=object)

    transcript = str(autot.transcribe(out))

    return html.Div([
        html.H5(f"File Name: {filename} \n" \
                "Transcript: \n"
                ),
        html.P(transcript)
    ])


@callback(Output('output-data-upload', 'children'),
          Input('upload-data', 'contents'),
          State('upload-data', 'filename'),
          State('upload-data', 'last_modified'))
def update_output(list_of_contents, list_of_names, list_of_dates):
    """Transcribe every uploaded file and return one result Div per file.

    Returns ``None`` (leaving the output empty) while nothing is uploaded.
    """
    if list_of_contents is None:
        return None
    return [
        parse_contents(c, n, d)
        for c, n, d in zip(list_of_contents, list_of_names, list_of_dates)
    ]


if __name__ == '__main__':
    app.run_server()
== "": - savefolder = self.audiofilefolder - - if savename == "": - savename = self.coreaudiofile + f'.{type}' - else: - savename = savename + f'.{type}' - - savepath = os.path.join(savefolder, savename) - - self.audio_file.export(savepath, format=type) - - print(f'Converted {self.audiofilename} to {type}') - - if remove_orginal: - os.remove(self.audio_file_path) - print(f'File {self.audio_file_path} removed') - - self.audio_file_path = savepath - self.audio_file = AudioSegment.from_file(savepath, format=type) - - return self - - def to_mp3(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True): - """ - Convert audio file to mp3 file - :param file: audio file - :param remove_orginal: remove original file - :return: mp3 file path - """ - return self.convert_audio(savefolder = savefolder, savename = savename, type="mp3", remove_orginal=remove_orginal) - - def to_wav(self, savefolder: str = "", savename: str = "", remove_orginal: bool = True): - """ - Convert audio file to wav file - :param file: audio file - :param remove_orginal: remove original file - :return: wav file path - """ - return self.convert_audio(savefolder = savefolder, savename = savename,type="wav", remove_orginal=remove_orginal) - - def slower_mp3(self, savefolder: str = "", savename: str = "", speed: float = 0.75, type: str = "mp3"): - """ - Slow down mp3 file - :param file: mp3 file - :param speed: speed - :return: None - """ - if savefolder == "": - savefolder = self.audiofilefolder - else: - savefolder = savefolder - - sound = self.audio_file - slow_sound = sound._spawn(sound.raw_data, overrides={ - "frame_rate": int(sound.frame_rate * speed) - }) - - speedstr = str(speed).replace('.', '') - - file_out = self.coreaudiofile + f'_{speedstr}.{type}' - - save_path = os.path.join(savefolder, file_out) - - slow_sound.export(save_path, format=type) - - return slow_sound - -class WhisperTranscription: - def __init__(self, audio_file: str , model, language: str = "German"): - - 
self.audio_file = audio_file - self.model = model - self.language = language - - def transcribe(self, language:str = "German"): - """ - Transcribe audio file - - language: language of the audio file - :return: transcript as string - """ - - audiofilename = self.audio_file.split('/')[-1] - #print(f'Start transcribing Audio file: {audiofilename}') - - _stime = time() - result = self.model.transcribe(self.audio_file, language=self.language) - - #print(f'Transcription finished in {time() - _stime} seconds') - - self.transcript = result - - return result["text"] - - def save_transcript(self, transcript:str = "", savefolder : str = "", savename: str = ""): - """ - Save transcript to file - :param transcript: transcript as string - :param savefolder: folder to save transcript - :param savename: name of the transcript file - :return: None - """ - if savefolder == "": - savefolder = os.path.dirname(self.audio_file) - else: - savefolder = savefolder - - if savename == "": - savename = self.audio_file.split('/')[-1][:-4] + '.txt' - else: - savename = savename - - if transcript == "": - transcript = self.transcript["text"] - - savepath = os.path.join(savefolder, savename) - - with open(savepath, 'w') as f: - f.write(transcript) - - print(f'Transcript saved to {savepath}') - -class Diarisation(AudioProcessor): - def __init__(self, audio_file: str, model,**kwargs): - - super().__init__(audio_file=audio_file) - - self.model = model - - - def diarization(self, *args, **kwargs): - - if "num_speakers" in kwargs: - num_speakers = kwargs['num_speakers'] - kwargs.pop('num_speakers') - else: - num_speakers = 2 - - audiofilename = self.coreaudiofile - - print(f'Start diarization of audio file: {self.audiofilename}') - - _stime = time() - - diarization = self.model(self.audio_file_path, num_speakers=num_speakers) - - print(f'Diarization finished in {time() - _stime} seconds') - self.diarization = diarization - - return diarization - - def format_diarization_output(self, *args, **kwargs): 
- """ - Format diarization output to a list of tuples - :param args: - :param kwargs: - :return: dict with speaker names as keys and list of tuples as values and list of different speakers - """ - - diarization_output = {"speakers": [], "segments": []} - - if not hasattr(self, 'diarization'): - # ensure diarization is run before formatting - self.diarization = self.diarization() - - - for segment, _, speaker in self.diarization.itertracks(yield_label=True): - diarization_output["speakers"].append(speaker) - diarization_output["segments"].append(segment) - - normalized_output = [] - index_start_speaker = 0 - index_end_speaker = 0 - current_speaker = str() - - for i, speaker in enumerate(diarization_output["speakers"]): - - if i == 0: - current_speaker = speaker - - if speaker != current_speaker: - - index_end_speaker = i - 1 - - normalized_output.append([index_start_speaker, index_end_speaker, current_speaker]) - - index_start_speaker = i - current_speaker = speaker - - if i == len(diarization_output["speakers"]) - 1: - - index_end_speaker = i - normalized_output.append([index_start_speaker, index_end_speaker, current_speaker]) - - - self.normalized_output = normalized_output - self.diarization_output = diarization_output - - return diarization_output,normalized_output - - def create_temporary_wav(self,savefolder: str = "", savename: str = "", *args, **kwargs): - """ - Create temporary wav file for diarization - :param savefolder: folder to save the temporary wav file - :param savename: name of the temporary wav file prefix - :param audiofile: audio file - :return: temporary wav file - """ - - - if savefolder == "": - folder = '.temp' - if not os.path.exists(folder): - os.makedirs(folder) - else: - folder = savefolder - - folder = os.path.realpath(folder) - - if savename == "": - savename = self.coreaudiofile + '.wav' - else: - savename = savename - - - if not os.path.exists(folder): - os.makedirs(folder) - - if not hasattr(self, 'normalized_output') or not 
hasattr(self, 'diarization_output'): - self.format_diarization_output() - - - speaker = set(self.diarization_output["speakers"]) - num_speak_iter = [0 for _ in range(len(speaker))] - - for count, outp in enumerate(self.normalized_output): - start = self.diarization_output["segments"][outp[0]].start - end = self.diarization_output["segments"][outp[1]].end - - print("start: ", start) - print("end: ", end) - - start_milliseconds = start * 1000 - end_milliseconds = end * 1000 - - print("start_milliseconds: ", start_milliseconds) - print("end_milliseconds: ", end_milliseconds) - - print("cut audio") - - cut_audio = self.audio_file[start_milliseconds:end_milliseconds] - - print("save audio") - print(f".temp/{count}_speaker_" + str(outp[2]) + ".wav") - cut_audio.export(f".temp/{count}_speaker_" + str(outp[2]) + ".wav", format="wav") - - return os.path.realpath(folder) - - def __repr__(self): - return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" - def __str__(self): - return f"Diarization(audiofile={self.audiofile}, model={self.model}, language={self.language})" - - -class AutoTranscribe: - def __init__(self, audiofile: Union[str, bool, list] = None, - model: str = "medium", - language: str = "German", - diarisation: bool = False, - audioinput: str = "audiofiles", - transcriptionout: str = "transcriptions", - *args, **kwargs): - """ - AutoTranscribe - :param audiofile: audio file or list of audio files to transcribe - :param model: model name (default: medium) - :param language: language (default: German) - :param diarisation: diarisation (default: False) - """ - if audiofile is None: - audiofile = os.listdir(audioinput) # get all audio files in audioinput folder - audiofile = [os.path.realpath(os.path.join(audioinput, file)) for file in audiofile]# add path to audio files - - self.audiofile = audiofile - self.language = language - self.diarisation = diarisation - if diarisation: - print("Diarisation is enabled") - print("Load 
Diarisation model") - self.diarisation_model = Pipeline.from_pretrained("pyannote/speaker-diarization", - use_auth_token = self._get_token()) - print("Load Diarisation model done") - - print(f"Load Whisper model {model}") - self.model = whisper.load_model(model) - print(f"Load Whisper model {model} done") - - self.currentpath, \ - self.audiopath, \ - self.transcriptionpath, \ - self.audiofiles = self.create_folder_structure(audioinput, transcriptionout) # create folder structure - - - - def transcribe(self, *args, **kwargs): - - if isinstance(self.audiofile, str): - for i in range(len(self.audiofiles)): - if self.audiofile in self.audiofiles[i]: - self.audiofile = [self.audiofiles[i]] - break - - audiolist = self.audiofile - - elif isinstance(self.audiofile, list): - audiolist = self.audiofile - else: - audiolist = self.audiofiles - - if not set(audiolist).issubset(set(self.audiofiles)): - raise ValueError(f"Audio file {self.audiofile} not found in {self.audiopath}") - - - for audiofile in audiolist: - _start = time() - if not "/" in audiofile: - audiofile = os.path.join(self.audiopath, audiofile) - - if not self.check_if_already_transcribed (audiofile): - - audio = AudioProcessor(audiofile) - - if not audiofile.endswith('wav'): - audio = audio.to_wav() - self.audiofile = audio.audio_file_path - audiofile = audio.audio_file_path - - if "speed" in kwargs: - speed = kwargs['speed'] - kwargs.pop('speed') - - print('Creating slower version of the audio file with speed {}'.format(speed)) - slower_audio = os.path.join(self.transcriptionpath, 'slower_version') - if not os.path.exists(slower_audio): - os.makedirs(slower_audio) - audio.slower_mp3(savefolder=slower_audio,speed=speed) - - if not self.diarisation: - WhisperTranscription(audiofile, self.model, self.language - ).save_transcript(savefolder = self.transcriptionpath) - - else: - print("Start diarisation") - dia = Diarisation(audiofile, self.diarisation_model) - - if 'num_speakers' in kwargs: - num_speakers = 
kwargs['num_speakers'] - kwargs.pop('num_speakers') - dia.diarization(num_speakers=num_speakers) - else: - dia.diarization() - - temppath = dia.create_temporary_wav() - temppath_dict, _ = dia.format_diarization_output() - speakers = list(set(temppath_dict["speakers"])) - - - fstring = "\\begin{drama}" - - for speaker in speakers: - speaker = speaker.replace("SPEAKER_", "") - fstring += "\n\t\Character{S"+ str(speaker) + "}{S" + str(speaker) + "}" - - - files = glob.glob(temppath + "/*.wav") - - # Sort files according to the digits included in the filename - files = sorted(files, key=lambda x: float(re.findall("(\d+)", x)[0])) - - for file in tqdm(files): - - Whisper = WhisperTranscription(file, self.model, self.language).transcribe() - - for s in speakers: - if s in file: - s = s.replace("SPEAKER_", "") - fstring += f"\n\S{s}speaks: \n {Whisper}" - - fstring += "\n\end{drama}" - - print(fstring) - - with open(os.path.join(self.transcriptionpath, - os.path.basename(audiofile).split('.')[0] + '.tex'), 'w') as f: - f.write(fstring) - - print("Remove temporary files") - shutil.rmtree(temppath) - - print(f"Transcription of {audiofile} done in total of {time() - _start} seconds") - - def create_folder_structure(self, audiopath: str, transcriptionout: str): - """ - Create folder structure for audio and transcription files - - :return: currentpath, audiopath, transcriptionpath, audiofiles - """ - currentpath = os.path.dirname(sys.argv[0]) # get executable path - - if not os.path.exists(os.path.join(currentpath, audiopath)): - print('Creating audiofiles folder') - os.makedirs(os.path.join(currentpath, audiopath)) - if not os.path.exists(os.path.join(currentpath, transcriptionout)): - print('Creating transcription folder') - os.makedirs(os.path.join(currentpath, transcriptionout)) - - audiopath = os.path.join(currentpath, audiopath) # path to audio files - transcriptionpath = os.path.join(currentpath, transcriptionout) # path to transcription files - - - _audiofiles = 
os.listdir(audiopath) # list of audio files - audiofiles = [] - for i in _audiofiles: - audiofiles.append(os.path.join(audiopath, i)) - - return currentpath, audiopath, transcriptionpath, audiofiles - - def check_if_already_transcribed (self, filename: str): - """ - Check if all audio files are already transcribed - :param filename: audio file name - :return: bool - """ - purefilename = filename.split('/')[-1][:-4] - _files = os.listdir(self.transcriptionpath) - for i,f in enumerate(_files): - _files[i] = f[:-4] - - if purefilename in _files: - print(f'File {purefilename[:-4]} already transcribed') - return True - else: - return False - @classmethod - def _get_token(self): - # check ig .pyannotetoken.txt exists - path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '.pyannotetoken') - if os.path.exists(path): - with open(path, 'r') as f: - token = f.read() - else: - raise ValueError('No token found. Please create a token at https://huggingface.co/settings/token' - ' and save it in a file called .pyannotetoken.txt') - return token - - def __repr__(self): - return f"AutoTranscribe(audiofile={self.audiofile}, model={self.model}, language={self.language}, diarisation={self.diarisation})" - def __call__(self, *args, **kwargs): - return self.transcribe(*args, **kwargs) diff --git a/autotranscript/__pycache__/__init__.cpython-39.pyc b/autotranscript/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 04235a5..0000000 Binary files a/autotranscript/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/autotranscript/__pycache__/__main__.cpython-39.pyc b/autotranscript/__pycache__/__main__.cpython-39.pyc deleted file mode 100644 index d64ee0a..0000000 Binary files a/autotranscript/__pycache__/__main__.cpython-39.pyc and /dev/null differ diff --git a/autotranscript/app/__init__.py b/autotranscript/app/__init__.py new file mode 100644 index 0000000..c61a882 --- /dev/null +++ b/autotranscript/app/__init__.py @@ -0,0 +1 @@ +from 
"""
This file contains a modified version of qtfaststart by Daniel G. Taylor
https://github.com/danielgtaylor/qtfaststart/tree/master

All credit goes to the original author.
Copyright (C) 2008 - 2013 Daniel G. Taylor
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies
or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
"""

import logging
import os
import struct
import collections
import io


# define error classes
class FastStartException(Exception):
    """
    Raised when something bad happens during processing.
    """


class FastStartSetupError(FastStartException):
    """
    Raised when asked to process a file that does not need processing
    """


class MalformedFileError(FastStartException):
    """
    Raised when the input file is setup in an unexpected way
    """


class UnsupportedFormatError(FastStartException):
    """
    Raised when a movie file is recognized as a format not supported.
    """


# define constants
CHUNK_SIZE = 8192

log = logging.getLogger("qtfaststart")

# Older versions of Python require this to be defined
if not hasattr(os, 'SEEK_CUR'):
    os.SEEK_CUR = 1

Atom = collections.namedtuple('Atom', 'name position size')

# struct type code and byte width of one chunk-offset entry, per atom name.
_CHUNK_OFFSET_FORMATS = {"stco": ("L", 4), "co64": ("Q", 8)}


def read_atom(datastream):
    """
    Read an atom and return a tuple of (size, type) where size is the size
    in bytes (including the 8 bytes already read) and type is a "fourcc"
    like "ftyp" or "moov".

    Raises struct.error when fewer than 8 bytes are left in the stream and
    UnicodeDecodeError when the fourcc is not ASCII.
    """
    size, fourcc = struct.unpack(">L4s", datastream.read(8))
    return size, fourcc.decode('ascii')


def _read_atom_ex(datastream):
    """
    Read an Atom (name, absolute position, size) from datastream.

    A 32-bit size of 1 means the real size follows as a 64-bit integer.
    """
    pos = datastream.tell()
    atom_size, atom_type = read_atom(datastream)
    if atom_size == 1:
        atom_size, = struct.unpack(">Q", datastream.read(8))
    return Atom(atom_type, pos, atom_size)


def get_index(datastream):
    """
    Return an index of top level atoms, their absolute byte-position in the
    file and their size in a list:

    index = [
        ("ftyp", 0, 24),
        ("moov", 25, 2658),
        ("free", 2683, 8),
        ...
    ]

    The tuple elements will be in the order that they appear in the file.

    Raises FastStartException when no "moov" or no "mdat" atom is found.
    """
    log.debug("Getting index of top level atoms...")

    index = list(_read_atoms(datastream))
    _ensure_valid_index(index)

    return index


def _read_atoms(datastream):
    """
    Yield top level atoms until the stream is exhausted or unparsable.

    Only parse failures (short read, non-ASCII fourcc) end the iteration;
    other exceptions propagate to the caller.
    """
    while True:
        try:
            atom = _read_atom_ex(datastream)
        except (struct.error, UnicodeDecodeError):
            break
        log.debug("%s: %s", atom.name, atom.size)

        yield atom

        if atom.size == 0:
            if atom.name == "mdat":
                # Some files may end in mdat with no size set, which generally
                # means to seek to the end of the file. We can just stop indexing
                # as no more entries will be found!
                break
            # Weird, but just continue to try to find more atoms
            continue

        datastream.seek(atom.position + atom.size)


def _ensure_valid_index(index):
    """
    Ensure the minimum viable atoms are present in the index.

    Raise FastStartException if not.
    """
    top_level_atoms = {item.name for item in index}
    for key in ("moov", "mdat"):
        if key not in top_level_atoms:
            log.error("%s atom not found, is this a valid MOV/MP4 file?" % key)
            raise FastStartException(
                "%s atom not found, is this a valid MOV/MP4 file?" % key)


def find_atoms(size, datastream):
    """
    Compatibility interface for _find_atoms_ex: yields atom names only.

    datastream is expected to be positioned just after the 8-byte header of
    a parent atom whose payload is `size` bytes long.
    """
    fake_parent = Atom('fake', datastream.tell() - 8, size + 8)
    for atom in _find_atoms_ex(fake_parent, datastream):
        yield atom.name


def _find_atoms_ex(parent_atom, datastream):
    """
    Yield either "stco" or "co64" Atoms from datastream.
    datastream will be 8 bytes into the stco or co64 atom when the value
    is yielded.

    It is assumed that datastream will be at the end of the atom after
    the value has been yielded and processed.

    parent_atom is the parent atom, a 'moov' or other ancestor of CO
    atoms in the datastream.

    Raises FastStartException when an atom header cannot be read.
    """
    stop = parent_atom.position + parent_atom.size

    while datastream.tell() < stop:
        try:
            atom = _read_atom_ex(datastream)
        except Exception as err:
            log.exception("Error reading next atom!")
            raise FastStartException("Error reading next atom!") from err

        if atom.name in ("trak", "mdia", "minf", "stbl"):
            # Known ancestor atom of stco or co64, search within it!
            yield from _find_atoms_ex(atom, datastream)
        elif atom.name in ("stco", "co64"):
            yield atom
        else:
            # Ignore this atom, seek to the end of it.
            datastream.seek(atom.position + atom.size)


def process(infilename, limit=float('inf')):
    """
    Convert a Quicktime/MP4 file for streaming by moving the metadata to
    the front of the file and return the converted file as bytes.

    infilename may be a path, a bytes/bytearray object holding the whole
    file, or a seekable binary file-like object (which is left open).

    If limit is set to something other than zero it will be used as the
    number of bytes to write of the atoms following the moov atom. This
    is very useful to create a small sample of a file with full headers,
    which can then be used in bug reports and such.

    If the input is already set up for streaming, its content is returned
    unchanged.

    Raises TypeError for unsupported input types and FastStartException
    when the input is not a parsable MOV/MP4 file.
    """
    if isinstance(infilename, (str, os.PathLike)):
        datastream = open(infilename, "rb")
        owns_stream = True
    elif isinstance(infilename, (bytes, bytearray)):
        datastream = io.BytesIO(infilename)
        owns_stream = True
    elif hasattr(infilename, "read") and hasattr(infilename, "seek"):
        datastream = infilename
        owns_stream = False
    else:
        raise TypeError("infilename must be a filename, bytes or file-like object")

    try:
        return _process_stream(datastream, limit)
    finally:
        # Only close what this function opened; never a caller's stream.
        if owns_stream:
            datastream.close()


def _process_stream(datastream, limit):
    """Do the work of process() on an already opened, seekable stream."""
    # Get the top level atom index
    index = get_index(datastream)

    # Infinity (rather than a fixed byte position) so that every free atom
    # seen before the first mdat counts as "before mdat", however large the
    # file is.
    mdat_pos = float('inf')
    free_size = 0
    moov_atom = None

    # Make sure moov occurs AFTER mdat, otherwise no need to run!
    for atom in index:
        # The atoms are guaranteed to exist from get_index above!
        if atom.name == "moov":
            moov_atom = atom
        elif atom.name == "mdat":
            mdat_pos = atom.position
        elif atom.name == "free" and atom.position < mdat_pos:
            # This free atom is before the mdat!
            free_size += atom.size
            log.info("Removing free atom at %d (%d bytes)" % (atom.position, atom.size))
        elif atom.name == "\x00\x00\x00\x00" and atom.position < mdat_pos:
            # This is some strange zero atom with incorrect size
            free_size += 8
            log.info("Removing strange zero atom at %s (8 bytes)" % atom.position)

    # Offset to shift positions
    offset = moov_atom.size - free_size

    if moov_atom.position < mdat_pos:
        # moov appears to be in the proper place, don't shift by moov size
        offset -= moov_atom.size
        if not free_size:
            # No free atoms and moov is correct, we are done!
            log.error("This file appears to already be setup for streaming!")
            # Hand back the unprocessed input.
            datastream.seek(0)
            return datastream.read()

    # Read and fix moov
    moov = _patch_moov(datastream, moov_atom, offset)

    log.info("Writing output...")
    # bytearray: appending is amortised O(1), unlike repeated bytes "+=".
    outfile = bytearray()

    # Write ftype
    for atom in index:
        if atom.name == "ftyp":
            log.debug("Writing ftyp... (%d bytes)" % atom.size)
            datastream.seek(atom.position)
            outfile += datastream.read(atom.size)

    # Write moov
    _bytes = moov.getvalue()
    log.debug("Writing moov... (%d bytes)" % len(_bytes))
    outfile += _bytes

    # for compatability, allow '0' to mean no limit
    byte_limit = limit or float('inf')

    # Write the rest
    for atom in index:
        if atom.name in ("ftyp", "moov", "free"):
            continue
        log.debug("Writing %s... (%d bytes)" % (atom.name, atom.size))
        datastream.seek(atom.position)

        for chunk in get_chunks(datastream, CHUNK_SIZE, min(byte_limit, atom.size)):
            outfile += chunk

    return bytes(outfile)


def _patch_moov(datastream, atom, offset):
    """
    Return a BytesIO copy of the moov atom in which every stco/co64 chunk
    offset entry has been shifted by `offset` bytes.
    """
    datastream.seek(atom.position)
    moov = io.BytesIO(datastream.read(atom.size))

    # reload the atom from the fixed stream
    moov_atom = _read_atom_ex(moov)

    for co_atom in _find_atoms_ex(moov_atom, moov):
        # Read either 32-bit or 64-bit offsets
        ctype, csize = _CHUNK_OFFSET_FORMATS[co_atom.name]

        # Get number of entries
        _version, entry_count = struct.unpack(">2L", moov.read(8))

        log.info("Patching %s with %d entries" % (co_atom.name, entry_count))

        entries_pos = moov.tell()
        struct_fmt = f">{entry_count}{ctype}"

        # Read entries
        entries = struct.unpack(struct_fmt, moov.read(csize * entry_count))

        # Patch and write entries in place; this leaves the stream at the
        # end of the atom, as _find_atoms_ex expects.
        moov.seek(entries_pos)
        moov.write(struct.pack(struct_fmt, *(entry + offset for entry in entries)))
    return moov


def get_chunks(stream, chunk_size, limit):
    """
    Yield up to `limit` bytes from stream in pieces of at most chunk_size.

    Stops early when the stream is exhausted. limit may be float('inf').
    """
    remaining = limit
    while remaining:
        # int(): limit may arrive as a float, which read() would reject.
        chunk = stream.read(int(min(remaining, chunk_size)))
        if not chunk:
            return
        remaining -= len(chunk)
        yield chunk
"""
Audio Processor Module
=======================

This module provides the AudioProcessor class. Audio files are decoded with the
ffmpeg command line tool and held as PyTorch tensors; the class offers helpers
to load, cut, and manage audio waveforms.

Available Classes:
- AudioProcessor: Processes audio waveforms and provides methods for loading,
                  cutting, and handling audio.

Usage:
    from .audio import AudioProcessor

    processor = AudioProcessor.from_file("path/to/audiofile.wav")
    cut_waveform = processor.cut(start=1.0, end=5.0)

Constants:
- SAMPLE_RATE (int): Default sample rate for processing.
- NORMALIZATION_FACTOR (float): Normalization factor for audio waveform.
"""

import math
from subprocess import CalledProcessError, run
import numpy as np
import torch

SAMPLE_RATE = 16000
# Divisor that maps signed 16-bit PCM samples onto the range [-1.0, 1.0).
NORMALIZATION_FACTOR = 32768.0


class AudioProcessor:
    """
    Holds an audio waveform as a torch tensor together with its sample rate
    and provides helpers for loading and cutting it.

    Attributes:
        waveform: torch.Tensor
            The audio waveform tensor, moved to the selected device.
        sr: int
            The sample rate of the audio.
    """

    def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
                 *args, **kwargs) -> None:
        """
        Initialize the AudioProcessor object.

        Args:
            waveform (torch.Tensor): The audio waveform tensor.
            sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
            args: Additional arguments (ignored).
            kwargs: Additional keyword arguments, e.g., device to use for processing.
                If CUDA is available, it defaults to CUDA.

        Raises:
            ValueError: If the provided sample rate is not of type int.
        """
        # Validate before touching the device so bad input fails fast, and
        # do not call len() on the value: most non-int inputs have no length.
        if not isinstance(sr, int):
            raise ValueError("Sample rate should be a single value of type int, "
                             f"not {sr!r} of type {type(sr)}")

        device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")

        self.waveform = waveform.to(device)
        self.sr = sr

    @classmethod
    def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
        """
        Create an AudioProcessor instance from an audio file.

        Args:
            file (str): The audio file path.
            args: Forwarded to load_audio (e.g. the sample rate).
            kwargs: ``device`` is passed to the constructor; everything else
                is forwarded to load_audio.

        Returns:
            AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.

        Raises:
            RuntimeError: If ffmpeg fails to decode the file.
        """
        # load_audio has no "device" parameter, so route it to the constructor.
        device = kwargs.pop("device", None)

        audio, sr = cls.load_audio(file, *args, **kwargs)
        audio = torch.from_numpy(audio)

        if device is None:
            return cls(audio, sr)
        return cls(audio, sr, device=device)

    def cut(self, start: float, end: float) -> torch.Tensor:
        """
        Cut a segment from the audio waveform between the specified start and end times.

        Args:
            start (float): Start time in seconds.
            end (float): End time in seconds.

        Returns:
            torch.Tensor: The cut waveform segment.
        """
        # math.ceil works on plain numbers; torch.ceil only accepts tensors.
        # float() also admits single-element tensors as time stamps.
        first = int(float(start) * self.sr)
        last = math.ceil(float(end) * self.sr)
        return self.waveform[first:last]

    @staticmethod
    def load_audio(file: str, sr: int = SAMPLE_RATE):
        """
        Open an audio file and read it as a mono waveform, resampling if necessary.
        This method requires the ffmpeg CLI in PATH.

        Args:
            file (str): The audio file to open.
            sr (int, optional): The desired sample rate. Defaults to SAMPLE_RATE.

        Returns:
            tuple: A NumPy array containing the audio waveform in float32 dtype
            and the sample rate.

        Raises:
            RuntimeError: If failed to load audio.
        """
        # This launches a subprocess to decode audio while down-mixing
        # and resampling as necessary. Requires the ffmpeg CLI in PATH.
        # fmt: off
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads", "0",
            "-i", file,
            "-f", "s16le",
            "-ac", "1",
            "-acodec", "pcm_s16le",
            "-ar", str(sr),
            "-"
        ]
        # fmt: on
        try:
            out = run(cmd, capture_output=True, check=True).stdout
        except CalledProcessError as e:
            raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

        out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR

        return out, sr

    def __repr__(self) -> str:
        return f'{type(self).__name__}(waveform={len(self.waveform)}, sr={int(self.sr)})'
class AutoTranscribe:
    """
    Manage transcription and speaker diarization of audio files.

    Combines a ``Transcriber`` (speech-to-text) with a ``Diariser``
    (who spoke when): the audio is diarized first, then every speaker
    segment is transcribed separately.

    Attributes:
        transcriber (Transcriber): The transcriber object to handle transcription.
        diariser (Diariser): The diariser object to handle diarization.

    Methods:
        __init__: Initializes the AutoTranscribe class with appropriate models.
        transcribe: Diarizes and transcribes an audio file.
        remove_audio_file: Removes (optionally shreds) the original audio file.
        get_audio_file: Gets an audio input as an AudioProcessor object.
    """

    def __init__(self,
                 whisper_model: Union[bool, str, whisper] = None,
                 dia_model: Union[bool, str, DiarisationType] = None,
                 **kwargs) -> None:
        """Initializes the AutoTranscribe class.

        Args:
            whisper_model (Union[bool, str, whisper], optional):
                Name/path of a whisper model, or a ready Transcriber.
                ``None`` loads the "medium" model.
            dia_model (Union[bool, str, DiarisationType], optional):
                Path to a diarization model, or a ready Diariser.
                ``None`` loads the default diarization model.
            **kwargs: ``whisper_kwargs`` (dict) is forwarded to
                ``Transcriber.load_model`` and ``dia_kwargs`` (dict) to
                ``Diariser.load_model``. For backward compatibility, when
                one of them is absent and the matching model is given as a
                string, the remaining keyword arguments are forwarded to
                that loader instead.
        """
        # The two loaders take different parameters, so each needs its own
        # keyword dictionary.
        whisper_kwargs = kwargs.pop("whisper_kwargs", None)
        dia_kwargs = kwargs.pop("dia_kwargs", None)

        if whisper_model is None:
            self.transcriber = Transcriber.load_model("medium", **(whisper_kwargs or {}))
        elif isinstance(whisper_model, str):
            loader_kwargs = kwargs if whisper_kwargs is None else whisper_kwargs
            self.transcriber = Transcriber.load_model(whisper_model, **loader_kwargs)
        else:
            self.transcriber = whisper_model

        if dia_model is None:
            self.diariser = Diariser.load_model(**(dia_kwargs or {}))
        elif isinstance(dia_model, str):
            loader_kwargs = kwargs if dia_kwargs is None else dia_kwargs
            self.diariser = Diariser.load_model(dia_model, **loader_kwargs)
        else:
            self.diariser = dia_model

        print("AutoTranscribe initialized all models successfully loaded.")

    def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
                   remove_original: bool = False,
                   **kwargs) -> Transcript:
        """
        Diarize an audio input and transcribe each speaker segment.

        Args:
            audio_file (Union[str, torch.Tensor, ndarray]):
                Path to audio file or a tensor/array representing the audio.
            remove_original (bool, optional): If True and ``audio_file`` is a
                path, the file is removed after transcription. Pass
                ``shred=True`` to shred it instead of unlinking it.
            **kwargs: Additional keyword arguments for diarization and
                transcription; each backend filters the ones it accepts.

        Returns:
            Transcript: A Transcript object containing the transcription,
            which can be exported to different formats.
        """
        # Keep the caller's input: the path is needed again for removal.
        source = audio_file
        audio = self.get_audio_file(audio_file)
        waveform = audio.waveform
        sample_rate = audio.sr

        # The diarizer is called with a (channel, time) waveform plus rate.
        dia_audio = {
            "waveform": waveform.reshape(1, len(waveform)),
            "sample_rate": sample_rate,
        }

        print("Starting diarisation.")

        diarisation = self.diariser.diarization(dia_audio, **kwargs)

        if not diarisation["segments"]:
            warn("No segments found. Try to run transcription without diarisation.")
            text = self.transcriber.transcribe(waveform, **kwargs)

            # Same per-segment layout as the regular path, so Transcript can
            # read it; the single segment spans the whole file, in seconds.
            final_transcript = {0: {"speakers": "speaker01",
                                    "segments": [0.0, len(waveform) / sample_rate],
                                    "text": text}}
        else:
            print("Diarisation finished. Starting transcription.")

            # AudioProcessor.cut has historically been fed a tensor sample
            # rate; keep doing so, but restore the original value afterwards
            # so the AudioProcessor stays reusable.
            audio.sr = torch.Tensor([sample_rate]).to(waveform.device)

            final_transcript = dict()
            try:
                for i in trange(len(diarisation["segments"]), desc="Transcribing"):
                    seg = diarisation["segments"][i]
                    chunk = audio.cut(seg[0], seg[1])
                    text = self.transcriber.transcribe(chunk, **kwargs)

                    final_transcript[i] = {"speakers": diarisation["speakers"][i],
                                           "segments": seg,
                                           "text": text}
            finally:
                audio.sr = sample_rate

        if remove_original:
            if isinstance(source, str):
                self.remove_audio_file(source, shred=kwargs.get("shred") is True)
            else:
                warn("remove_original was requested but the audio was not "
                     "given as a file path; nothing was removed.")

        return Transcript(final_transcript)

    @staticmethod
    def remove_audio_file(audio_file: str,
                          shred: bool = False) -> None:
        """
        Removes the original audio file to avoid disk space issues or ensure data privacy.

        Args:
            audio_file (str): Path to the audio file.
            shred (bool, optional): If True, the audio file is overwritten
                and removed with the ``shred`` command instead of being
                unlinked.

        Raises:
            ValueError: If the path does not exist, or if shredding is
                requested for a directory.
            subprocess.CalledProcessError: If the ``shred`` command fails.
        """
        if not os.path.exists(audio_file):
            raise ValueError(f"Audiofile {audio_file} does not exist.")

        if not shred:
            os.remove(audio_file)
            print(f"Audiofile {audio_file} removed.")
            return

        if os.path.isdir(audio_file):
            raise ValueError(f"Audiofile {audio_file} is a directory.")

        warn("Shredding audiofile can take a long time.", RuntimeWarning)
        print(f'shredding {audio_file} now\n')

        # Shred the verified path directly. Expanding it as a glob pattern
        # would match nothing for names containing glob metacharacters and
        # leave the file in place without any error.
        run(['shred', '-zvu', '-n', '10', f'{audio_file}'], check=True)

    @staticmethod
    def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
                       *args, **kwargs) -> AudioProcessor:
        """Gets an audio input as an AudioProcessor.

        Args:
            audio_file (Union[str, torch.Tensor, ndarray]): Path to the audio
                file, or a one-dimensional mono waveform. Waveforms are
                wrapped with AudioProcessor's default sample rate. An object
                ndarray holding ``(waveform, sample_rate)`` is still accepted.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            AudioProcessor: An object containing the waveform and sample rate.

        Raises:
            ValueError: If the input cannot be converted to an AudioProcessor.
        """
        if isinstance(audio_file, str):
            audio_file = AudioProcessor.from_file(audio_file)
        elif isinstance(audio_file, torch.Tensor):
            # The tensor is the waveform itself; indexing it for a sample
            # rate can never yield the int AudioProcessor requires.
            audio_file = AudioProcessor(audio_file)
        elif isinstance(audio_file, ndarray):
            if audio_file.dtype == object:
                # Legacy packing: element 0 is the waveform, element 1 the rate.
                audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
                                            audio_file[1])
            else:
                audio_file = AudioProcessor(torch.from_numpy(audio_file).float())

        if not isinstance(audio_file, AudioProcessor):
            raise ValueError('Audiofile must be of type AudioProcessor, '
                             f'not {type(audio_file)}')
        return audio_file
def cli():
    """
    Command-line interface for AutoTranscribe.

    Parses the command line, loads the whisper and diarization models,
    transcribes every given audio file into the output directory and
    optionally starts the Gradio app afterwards.
    """
    from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
    from .transcriber import WHISPER_DEFAULT_PATH
    from .diarisation import PYANNOTE_DEFAULT_PATH

    def str2bool(string):
        """Map the literal strings "True"/"False" to booleans for argparse."""
        str2val = {"True": True, "False": False}
        if string in str2val:
            return str2val[string]
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("-f", "--audio_files", nargs="+", type=str,
                        help="List of audio files to transcribe.")

    parser.add_argument('--start_server', action='store_true',
                        help='Start the Gradio app.')

    parser.add_argument("--whisper_model_name", default="medium",
                        help="Name of the Whisper model to use.")

    parser.add_argument("--whisper_model_directory", type=str, default=WHISPER_DEFAULT_PATH,
                        help="Path to save Whisper model files.")

    parser.add_argument("--diarization_directory", type=str, default=PYANNOTE_DEFAULT_PATH,
                        help="Path to the diarization model directory.")

    parser.add_argument("--huggingface_token", default="", type=str,
                        help="HuggingFace token for private model download.")

    # NOTE(review): parsed for compatibility, but none of the loaders called
    # below takes a matching parameter, so it is currently not consumed.
    parser.add_argument("--allow_download", type=str2bool, default=False,
                        help="Allow model download if not found locally.")

    parser.add_argument("--inference_device",
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use for PyTorch inference.")

    parser.add_argument("--num_threads", type=int, default=0,
                        help="Number of threads used by torch for CPU inference; "
                             "overrides MKL_NUM_THREADS/OMP_NUM_THREADS.")

    parser.add_argument("--output_directory", "-o", type=str, default=".",
                        help="Directory to save the transcription outputs.")

    # No short alias here: "-f" already belongs to --audio_files, and
    # registering it twice makes argparse raise while building the parser.
    parser.add_argument("--output_format", type=str, default="txt",
                        choices=["txt", "json", "md", "html"],
                        help="Format of the output file.")

    parser.add_argument("--verbose_output", type=str2bool, default=True,
                        help="Enable or disable progress and debug messages.")

    parser.add_argument("--transcription_task", type=str, default="transcribe",
                        choices=["transcribe", "diarize", "wtranscribe"],
                        help="Choose to perform transcription, diarization, "
                             "or Whisper transcription.")

    parser.add_argument("--spoken_language", type=str, default=None,
                        choices=sorted(LANGUAGES.keys())
                        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
                        help="Language spoken in the audio. "
                             "Specify None to perform language detection.")

    args = parser.parse_args()

    os.makedirs(args.output_directory, exist_ok=True)

    if args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    # Build both backends here and hand the instances over: each loader
    # then only receives parameters it actually declares.
    transcriber = Transcriber.load_model(args.whisper_model_name,
                                         download_root=args.whisper_model_directory,
                                         device=args.inference_device)
    diariser = Diariser.load_model(args.diarization_directory,
                                   token=args.huggingface_token or None)

    model = AutoTranscribe(whisper_model=transcriber, dia_model=diariser)

    if args.transcription_task == "transcribe":
        # --audio_files is optional, e.g. when only the server is started.
        for audio in args.audio_files or []:
            out = model.transcribe(audio, language=args.spoken_language)
            # splitext keeps dots inside the file name ("talk.v2.wav" -> "talk.v2").
            basename = os.path.splitext(os.path.basename(audio))[0]
            spath = os.path.join(args.output_directory,
                                 f"{basename}.{args.output_format}")
            out.save(spath)
    else:
        # "diarize" and "wtranscribe" are accepted but have no implementation yet.
        warn(f"Task '{args.transcription_task}' is not implemented yet; "
             "no files were processed.")

    if args.start_server:
        from .gradio_app import gradio_app
        gradio_app(model)


if __name__ == "__main__":
    cli()
class Diariser:
    """
    Handles the diarization process of an audio file using a pretrained model
    from pyannote.audio. Diarization is the task of determining "who spoke when."

    Args:
        model: The pretrained model (pipeline) to use for diarization.
    """

    def __init__(self, model) -> None:

        self.model = model

    def diarization(self, audiofile: Union[str, Tensor, dict],
                    *args, **kwargs) -> dict:
        """
        Perform speaker diarization on the provided audio.

        Args:
            audiofile: The path to the audio file, a torch.Tensor, or a dict
                       describing the audio, as accepted by the loaded model.
            args: Additional positional arguments for the diarization model.
            kwargs: Additional keyword arguments; those the diarization model
                    does not declare are dropped.

        Returns:
            dict: ``{"speakers": [...], "segments": [[start, end], ...]}``
                  with one entry per uninterrupted speaker turn.
        """
        kwargs = self._get_diarisation_kwargs(**kwargs)

        diarization = self.model(audiofile, *args, **kwargs)

        return self.format_diarization_output(diarization)

    @staticmethod
    def format_diarization_output(dia: Annotation) -> dict:
        """
        Formats the raw diarization output into a more usable structure for this project.

        Consecutive tracks of the same speaker are merged into one turn that
        starts with the first track and ends with the last track of the run.

        Args:
            dia: Raw diarization output; must provide
                 ``itertracks(yield_label=True)``.

        Returns:
            dict: ``"speakers"`` holds one label per turn and ``"segments"``
                  the matching ``[start, end]`` pairs, in track order.
        """
        diarization_output = {"speakers": [], "segments": []}
        speakers = diarization_output["speakers"]
        segments = diarization_output["segments"]

        for segment, _, speaker in dia.itertracks(yield_label=True):
            if speakers and speakers[-1] == speaker:
                # Same speaker keeps talking: extend the running turn.
                segments[-1][1] = segment.end
            else:
                # Speaker change (or very first track): open a new turn, so
                # the last turn is always recorded as well.
                speakers.append(speaker)
                segments.append([segment.start, segment.end])

        return diarization_output

    @staticmethod
    def _get_token():
        """
        Retrieves the Huggingface token from the local token file.

        Raises:
            ValueError: If the token file does not exist.

        Returns:
            str: The Huggingface token without surrounding whitespace.
        """
        if not os.path.exists(TOKEN_PATH):
            raise ValueError('No token found. '
                             'Please create a token at https://huggingface.co/settings/token '
                             f'and save it in a file called {TOKEN_PATH}')

        with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
            # Editors usually append a newline, which is not part of the token.
            return file.read().strip()

    @staticmethod
    def _save_token(token):
        """
        Saves the provided Huggingface token to the local token file.

        Args:
            token: The Huggingface token to save.
        """
        with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
            file.write(token)

    @classmethod
    def load_model(cls,
                   model: str = PYANNOTE_DEFAULT_CONFIG,
                   token: str = None,
                   cache_token: bool = False,
                   cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
                   hparams_file: Union[str, Path] = None
                   ) -> 'Diariser':
        """
        Loads a pretrained pyannote.audio pipeline and wraps it in a Diariser.

        Args:
            model: Path or identifier for the pyannote model.
                Defaults to PYANNOTE_DEFAULT_CONFIG.
            token: Optional HUGGINGFACE_TOKEN for authenticated access.
            cache_token: Whether to store the given token locally for future use.
            cache_dir: Directory for caching models.
            hparams_file: Path to a YAML file containing hyperparameters.

        Returns:
            Diariser: A Diariser wrapping the loaded pipeline.

        Raises:
            ValueError: If no pipeline could be loaded, or if a stored token
                is needed but missing.
        """
        if cache_token and token is not None:
            cls._save_token(token)

        # No local model and no explicit token: fall back to the hosted
        # pipeline, authenticated with the stored token.
        if not os.path.exists(model) and token is None:
            token = cls._get_token()
            model = 'pyannote/speaker-diarization'

        _model = Pipeline.from_pretrained(model,
                                          use_auth_token=token,
                                          cache_dir=cache_dir,
                                          hparams_file=hparams_file,)

        if _model is None:
            raise ValueError('Unable to load model either from local cache '
                             'or from huggingface.co models. Please check your token '
                             'or your local model path')

        return cls(_model)

    @staticmethod
    def _get_diarisation_kwargs(**kwargs) -> dict:
        """
        Filters keyword arguments down to those the diarization model declares.

        Returns:
            dict: The keyword arguments whose names are parameters of
                  ``SpeakerDiarization.apply``.
        """
        code = SpeakerDiarization.apply.__code__
        # co_varnames lists the parameters first and local variables after
        # them; only the leading parameter names are valid keyword arguments.
        accepted = code.co_varnames[:code.co_argcount + code.co_kwonlyargcount]

        return {k: v for k, v in kwargs.items() if k in accepted and k != "self"}

    def __repr__(self):
        return f"{type(self).__name__}(model={self.model})"
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
    """Point a diarization pipeline config at a local segmentation model.

    Reads the YAML file, sets ``pipeline.params.segmentation`` to the
    segmentation model path and writes the file back.

    Args:
        file_path (str): Path to the YAML file; it is rewritten in place.
        path_to_segmentation (str, optional): Path to the segmentation model.
            When omitted (or empty), ``pytorch_model.bin`` inside
            PYANNOTE_DEFAULT_PATH is used.

    Raises:
        FileNotFoundError: If the segmentation model file is not found; the
            YAML file is left untouched in that case.
    """
    if path_to_segmentation:
        model_file = path_to_segmentation
    else:
        model_file = os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")

    with open(file_path, "r") as handle:
        config = yaml.safe_load(handle)

    config["pipeline"]["params"]["segmentation"] = model_file

    # Checked before writing so a bad path never reaches the file on disk.
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Segmentation model not found at {model_file}")

    with open(file_path, "w") as handle:
        yaml.dump(config, handle)
class Transcriber:
    """
    Thin wrapper around a Whisper model for audio transcription.

    Attributes:
        model (whisper): The Whisper model used for transcription.

    Methods:
        transcribe: Transcribes the given audio.
        save_transcript: Saves a transcript to a file.
        load_model: Loads a specific Whisper model.
        _get_whisper_kwargs: Filters keyword arguments for the whisper model.

    Examples:
        >>> transcriber = Transcriber.load_model(model="medium")
        >>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
        >>> transcriber.save_transcript(transcript, "path/to/save.txt")
    """

    def __init__(self, model: whisper) -> None:
        """
        Initialize the Transcriber class with a Whisper model.

        Args:
            model (whisper): The Whisper model to use for transcription.
        """
        self.model = model

    def transcribe(self, audio: Union[str, Tensor, ndarray],
                   *args, **kwargs) -> str:
        """
        Transcribe audio.

        Args:
            audio (Union[str, Tensor, ndarray]): The audio to transcribe.
            *args: Additional positional arguments for the model.
            **kwargs: Additional keyword arguments, such as the language of
                the audio; names the filter does not recognise are dropped.
                ``verbose`` defaults to False.

        Returns:
            str: The transcript as a string.
        """
        options = self._get_whisper_kwargs(**kwargs)
        options.setdefault("verbose", False)

        result = self.model.transcribe(audio, *args, **options)
        return result["text"]

    @staticmethod
    def save_transcript(transcript: str, save_path: str) -> None:
        """
        Save a transcript to a file.

        Args:
            transcript (str): The transcript as a string.
            save_path (str): The path to save the transcript.
        """
        # Transcripts can be in any language; an explicit UTF-8 encoding
        # avoids encode errors on platforms whose locale encoding is narrower.
        with open(save_path, 'w', encoding="utf-8") as f:
            f.write(transcript)

        print(f'Transcript saved to {save_path}')

    @classmethod
    def load_model(cls,
                   model: str = "medium",
                   download_root: str = WHISPER_DEFAULT_PATH,
                   device: Optional[Union[str, device]] = None,
                   in_memory: bool = False,
                   ) -> 'Transcriber':
        """
        Load whisper model.

        Args:
            model (str): Whisper model. Available models include:
                'tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small',
                'medium.en', 'medium', 'large-v1', 'large-v2', 'large'.
            download_root (str, optional): Path to download the model.
                Defaults to WHISPER_DEFAULT_PATH.
            device (Optional[Union[str, torch.device]], optional):
                Device to load model on. Defaults to None.
            in_memory (bool, optional): Whether to load model in memory.
                Defaults to False.

        Returns:
            Transcriber: A Transcriber object initialized with the specified model.
        """
        _model = load_model(model, download_root=download_root,
                            device=device, in_memory=in_memory)

        return cls(_model)

    @staticmethod
    def _get_whisper_kwargs(**kwargs) -> dict:
        """
        Get kwargs for whisper model, dropping names it does not know.

        Returns:
            dict: Keyword arguments for whisper model.
        """
        # NOTE(review): co_varnames contains local variable names as well as
        # parameters, so this filter is wider than the real signature.
        # Narrowing it could drop options that currently pass through;
        # confirm against whisper's transcribe signature before changing.
        _possible_kwargs = Whisper.transcribe.__code__.co_varnames

        return {k: v for k, v in kwargs.items() if k in _possible_kwargs}

    def __repr__(self) -> str:
        return f"Transcriber(model={self.model})"
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]


class Transcript:
    """
    Class for storing transcript data, including speaker information and text segments,
    and exporting it to various file formats such as JSON, HTML, and LaTeX.
    """

    def __init__(self, transcript: dict) -> None:
        """
        Initializes the Transcript object with the given transcript data.

        Args:
            transcript (dict): Maps segment IDs to dicts with the keys
                "speakers" (speaker label), "segments" ([start, end] in
                seconds) and "text".
        """
        self.transcript = transcript
        self.speakers = self._extract_speakers()
        self.segments = self._extract_segments()
        self.annotation = {}

    def annotate(self, *args, **kwargs) -> dict:
        """
        Annotates the transcript to associate specific names with speakers.

        Args:
            args (list): Speaker names, mapped to ``self.speakers`` in order
                of first appearance.
            kwargs (dict): Speaker labels as keys and their names as values.

        Returns:
            dict: Speaker labels mapped to the assigned names.

        Raises:
            ValueError: If the number of speaker names does not match the number
                        of speakers, or if an unknown speaker is found.
        """
        annotations = {}
        if args and len(args) != len(self.speakers):
            raise ValueError("Number of speaker names does not match number of speakers")

        if args:
            for arg, speaker in zip(args, self.speakers):
                annotations[speaker] = arg

        invalid_speakers = set(kwargs.keys()) - set(self.speakers)
        if invalid_speakers:
            raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}")

        annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})

        self.annotation = annotations
        return annotations

    def _extract_speakers(self) -> list:
        """
        Extracts the unique speaker labels from the transcript.

        Returns:
            list: Unique speaker labels in order of first appearance. The
            order is deterministic so positional annotation is reproducible.
        """
        labels = (entry["speakers"] for entry in self.transcript.values())
        return list(dict.fromkeys(labels))

    def _extract_segments(self) -> list:
        """
        Extracts all segments from the transcript.

        Returns:
            list: The "segments" value (start and end time) of every entry.
        """
        return [entry["segments"] for entry in self.transcript.values()]

    def __str__(self) -> str:
        """
        Converts the transcript to a string representation.

        Returns:
            str: One line per segment: speaker name, start and end time as
            HH:MM:SS, and the text. Speakers without an annotation keep
            their original label.
        """
        lines = []
        for seq in self.transcript.values():
            label = seq["speakers"]
            speaker = self.annotation.get(label, label)

            segm = seq["segments"]
            sseg = time.strftime("%H:%M:%S", time.gmtime(segm[0]))
            eseg = time.strftime("%H:%M:%S", time.gmtime(segm[1]))

            lines.append(f"{speaker} ({sseg} ; {eseg}): {seq['text']}\n")
        return "".join(lines)

    def __repr__(self) -> str:
        """Return a string representation of the Transcript object."""
        return f"Transcript(speakers = {self.speakers},"\
               f"segments = {self.segments}, annotation = {self.annotation})"

    def get_dict(self) -> dict:
        """Get transcript as dict."""
        return self.transcript

    def get_json(self, *args, **kwargs) -> str:
        """
        Get transcript as json string.

        Extra arguments are forwarded to ``json.dumps``; ``indent`` defaults to 4.
        """
        kwargs.setdefault("indent", 4)
        return json.dumps(self.transcript, *args, **kwargs)

    def get_html(self) -> str:
        """
        Get transcript as html string.

        Returns:
            str: The string form of the transcript inside ``<html><p>``,
            with HTML special characters escaped, line breaks as ``<br>``
            and tabs as four non-breaking spaces.
        """
        from html import escape

        # Transcribed text is arbitrary; escape it so "<" or "&" in speech
        # cannot break the markup.
        body = escape(str(self), quote=False).replace("\n", "<br>")
        body = body.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")

        return "<html><p>" + body + "</p></html>"

    def get_md(self) -> str:
        """Get transcript as Markdown string, using HTML formatting."""
        return self.get_html()

    @staticmethod
    def _tex_label(index: int) -> str:
        """Return the alphabetic label for a speaker index (a..z, aa, ab, ...)."""
        label = ""
        index += 1
        while index:
            index, rest = divmod(index - 1, len(ALPHABET))
            label = ALPHABET[rest] + label
        return label

    def get_tex(self) -> str:
        """Get transcript as LaTeX string.

        Speakers without an annotation are named with letters of the
        alphabet. The stored annotation is not modified.

        Returns:
            str: Transcript as LaTeX string.
        """
        names = {speaker: self.annotation.get(speaker, self._tex_label(i))
                 for i, speaker in enumerate(self.speakers)}

        parts = ["\\begin{drama}"]
        for speaker in self.speakers:
            name = str(names[speaker])
            parts.append("\n\t\\Character{" + name + "}{" + name + "}")

        for seq in self.transcript.values():
            parts.append(f"\n\\{names[seq['speakers']]}speaks:\n{seq['text']}")

        parts.append("\n\\end{drama}")
        return "".join(parts)

    def to_json(self, path, *args, **kwargs) -> None:
        """Save transcript as json file.

        Args:
            path (str): path to save file
            *args, **kwargs: forwarded to ``json.dump``.
        """
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.transcript, f, *args, **kwargs)

    def to_txt(self, path: str) -> None:
        """Save the string form of the transcript as a text file.

        Args:
            path (str): Path to save the text file.
        """
        with open(path, "w", encoding="utf-8") as f:
            f.write(str(self))

    def to_md(self, path: str) -> None:
        """Save transcript as a Markdown file, using HTML formatting.

        Args:
            path (str): Path to save the Markdown file.
        """
        return self.to_html(path)

    def to_html(self, path: str) -> None:
        """Save transcript as html file.

        Args:
            path (str): Path to save the HTML file.
        """
        with open(path, "w", encoding="utf-8") as file:
            file.write(self.get_html())

    def to_tex(self, path: str) -> None:
        """Save transcript as a LaTeX file.

        Args:
            path (str): Path to save the LaTeX file.
        """
        with open(path, "w", encoding="utf-8") as file:
            file.write(self.get_tex())

    def to_pdf(self, path: str) -> None:
        """Save transcript as a PDF file (not implemented yet).

        Args:
            path (str): Path to save the PDF file.

        Raises:
            NotImplementedError: Always; failing loudly is preferable to
                reporting success without writing a file.
        """
        raise NotImplementedError("PDF export is not implemented yet")

    def save(self, path: str, *args, **kwargs) -> None:
        """Save transcript to file with the given path and file format.

        The format is chosen from the file extension (case-insensitive):
        json, txt, md, html, tex or pdf.

        Args:
            path (str): Path to save the file, including the desired file extension.
            *args: Additional positional arguments; used by the JSON writer only.
            **kwargs: Additional keyword arguments; used by the JSON writer only.

        Raises:
            ValueError: If the file format specified in the path is unknown.
            NotImplementedError: For the pdf format.
        """
        extension = path.rsplit(".", 1)[-1].lower() if "." in path else ""

        if extension == "json":
            self.to_json(path, *args, **kwargs)
        elif extension in ("txt", "md", "html", "tex", "pdf"):
            # These writers accept only the path.
            getattr(self, f"to_{extension}")(path)
        else:
            raise ValueError("Unknown file format")
"Estonian", + "Finnish", "French", "Galician", "German", "Greek", + "Hebrew", "Hindi", "Hungarian", "Icelandic", "Indonesian", + "Italian", "Japanese", "Kannada", "Kazakh", "Korean", + "Latvian", "Lithuanian", "Macedonian", "Malay", "Marathi", + "Maori", "Nepali", "Norwegian", "Persian", "Polish", + "Portuguese", "Romanian", "Russian", "Serbian", "Slovak", + "Slovenian", "Spanish", "Swahili", "Swedish", "Tagalog", + "Tamil", "Thai", "Turkish", "Ukrainian", "Urdu", + "Vietnamese", "Welsh" +] + + +def gradio_server(model : AutoTranscribe): + + def transcribe(audio, microphone, number_of_speakers, language): + kwargs = {} + if number_of_speakers != 0: + kwargs["num_speakers"] = number_of_speakers + if language != "None": + kwargs["language"] = language + + if audio is not None: + out = model.transcribe(audio, **kwargs) + elif microphone is not None: + out = model.transcribe(microphone , **kwargs) + else: + out = "Please upload an audio file or record one." + + + return str(out) + + gr.Interface( + fn=transcribe, + inputs=[ + gr.Audio(source= "upload", type="filepath", label="Upload Your Audio File", interactive=True), + gr.Audio(source= "microphone", type="filepath", label="Record Your Audio", interactive=True), + gr.Number(value=0, label= "Number of speakers", + info = "Number of speakers in the audio file. If you don't know, leave it at 0."), + # gr.Number(value=0, label= "Minimal number of speakers", + # info = "Minimal number of speakers in the audio file. If you don't know or you have specified Numspeakers, leave it at 0."), + gr.Dropdown(LANGUAGES, + label="Languages", default="None", + info="Language of the audio file. If you don't know, leave it at None.") + ], + outputs=[ + "text" + ], + title="Audio Transcription", + thumbnail = "Logo_KIDA.png", + description="Upload an audio file to transcribe its content. 
Powered by AutoTranscribe!", + theme="soft", # Example of a more modern theme + ).launch(share=True) + + +if __name__ == "__main__": + + model = AutoTranscribe() + gradio_server(model) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 619d0c4..b81b23c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,152 +1,17 @@ -absl-py==1.3.0 -aiohttp==3.8.3 -aiosignal==1.3.1 -alembic==1.9.1 -antlr4-python3-runtime==4.9.3 -appdirs==1.4.4 -asteroid-filterbanks==0.4.0 -async-timeout==4.0.2 -attrs==22.2.0 -audioread==3.0.0 -autopage==0.5.1 -backports.cached-property==1.0.2 -brotlipy==0.7.0 -cachetools==5.2.0 -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==2.1.1 -click==8.1.3 -cliff==4.1.0 -cmaes==0.9.0 -cmake==3.26.4 -cmd2==2.4.2 -colorama==0.4.6 -colorlog==6.7.0 -commonmark==0.9.1 -contourpy==1.0.6 -cryptography==39.0.1 -cycler==0.11.0 -decorator==4.4.2 -docopt==0.6.2 -einops==0.3.2 -ffmpeg-python==0.2.0 -filelock==3.8.0 -flit_core==3.8.0 -fonttools==4.38.0 -frozenlist==1.3.3 -fsspec==2022.11.0 -future==0.18.2 -google-auth==2.15.0 -google-auth-oauthlib==0.4.6 -greenlet==2.0.1 -grpcio==1.51.1 -hmmlearn==0.2.8 -huggingface-hub==0.11.0 -HyperPyYAML==1.1.0 -idna==3.4 -imageio==2.23.0 -imageio-ffmpeg==0.4.7 -importlib-metadata==4.13.0 -joblib==1.2.0 -julius==0.2.7 -kiwisolver==1.4.4 -librosa==0.9.2 -lit==16.0.5.post0 -llvmlite==0.39.1 -Mako==1.2.4 -Markdown==3.4.1 -MarkupSafe==2.1.1 -matplotlib==3.6.2 -mkl-fft==1.3.1 -mkl-random==1.2.2 -mkl-service==2.4.0 -more-itertools==9.0.0 -moviepy==1.0.3 -mpmath==1.2.1 -multidict==6.0.4 -networkx==2.8.8 -numba==0.56.4 -numpy==1.23.5 -oauthlib==3.2.2 -omegaconf==2.3.0 openai-whisper==20230314 -optuna==3.0.5 -packaging==21.3 -pandas==1.5.2 -pbr==5.11.0 -Pillow==9.4.0 -pip==23.0.1 -pooch==1.6.0 -prettytable==3.5.0 -primePy==1.3 -proglog==0.1.10 -protobuf==3.20.1 -pyannote.audio==2.1.1 -pyannote.core==4.5 -pyannote.database==4.1.3 -pyannote.metrics==3.2.1 -pyannote.pipeline==2.3 -pyasn1==0.4.8 
-pyasn1-modules==0.2.8 -pycparser==2.21 -pyDeprecate==0.3.2 -pydub==0.25.1 -Pygments==2.13.0 -pyOpenSSL==23.0.0 -pyparsing==3.0.9 -pyperclip==1.8.2 -PySocks==1.7.1 -python-dateutil==2.8.2 -pytorch-lightning==1.6.5 -pytorch-metric-learning==1.6.3 -pytz==2022.7 -PyYAML==6.0 -regex==2022.10.31 -requests==2.28.1 -requests-oauthlib==1.3.1 -resampy==0.4.2 -rich==12.6.0 -rsa==4.9 -ruamel.yaml==0.17.21 -ruamel.yaml.clib==0.2.7 -scikit-learn==1.2.0 -scipy==1.8.1 -semantic-version==2.10.0 -semver==2.13.0 -sentencepiece==0.1.97 -setuptools==65.6.3 -setuptools-rust==1.5.2 -shellingham==1.5.0 -simplejson==3.18.0 -singledispatchmethod==1.0 -six==1.16.0 -sortedcontainers==2.4.0 -SoundFile==0.10.3.post1 -speechbrain==0.5.13 -SQLAlchemy==1.4.45 -stevedore==4.1.1 -sympy==1.11.1 -tabulate==0.9.0 -tensorboard==2.11.0 -tensorboard-data-server==0.6.1 -tensorboard-plugin-wit==1.8.1 -threadpoolctl==3.1.0 -tiktoken==0.3.1 -tokenizers==0.13.2 -torch==1.11.0 -torch-audiomentations==0.11.0 -torch-pitch-shift==1.2.2 -torchaudio==0.11.0 -torchmetrics==0.11.0 -torchvision==0.12.0 -tqdm==4.65.0 -transformers==4.24.0 -triton==2.0.0 -typer==0.7.0 -typing_extensions==4.4.0 -urllib3==1.26.15 -wcwidth==0.2.5 -Werkzeug==2.2.2 -wheel==0.38.4 -yarl==1.8.2 -zipp==3.11.0 + +pyannote.audio~=2.1.1 +pyannote.core~=4.5 +pyannote.database~=4.1.3 +pyannote.metrics~=3.2.1 +pyannote.pipeline~=2.3 + +setuptools~=65.6.3 +setuptools-rust~=1.5.2 + +tqdm>=4.65.0 + +#optional: +#dash~=2.10.2 + + diff --git a/setup.py b/setup.py index d6884d3..e7da608 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import os from setuptools import setup, find_packages module_name = "autotranscript" -github_url = "https://github.com/Jaikinator/transcriptor" +github_url = "https://github.com/JSchmie/autotranscript" file_dir = os.path.dirname(os.path.realpath(__file__)) absdir = lambda p: os.path.join(file_dir, p) @@ -15,24 +15,28 @@ version = {"__file__": verfile} with open(verfile, "r") as fp: exec(fp.read(), version) + 
############### setup ############### -build_version = "OPTB_BUILD" in os.environ +build_version = "AUTOTRANSCRIPT_BUILD" in os.environ -setup( - name=module_name, - version=version["get_version"](build_version), - packages=find_packages(), - python_requires="~=3.9", - readme="README.md", - install_requires = [str(r) for r in pkg_resources.parse_requirements( - open(os.path.join(os.path.dirname(__file__), "requirements.txt")) - ) - ], - url= github_url, - license='', - author='Jacob Schmieder', - author_email='', - description='Transcription tool for audio files based on Whisper', - #entry_points={'console_scripts': ['autotranscript = autotranscript.__main__:main']} -) +if __name__ == "__main__": + + setup( + name=module_name, + version=version["get_version"](build_version), + packages=find_packages(), + python_requires="~=3.9", + readme="README.md", + install_requires = [str(r) for r in pkg_resources.parse_requirements( + open(os.path.join(os.path.dirname(__file__), "requirements.txt")) + ) + ], + url= github_url, + license='', + author='Jacob Schmieder', + author_email='', + description='Transcription tool for audio files based on Whisper and Pyannote', + entry_points={'console_scripts': + ['autotranscript = autotranscript.autotranscript:cli']} + ) diff --git a/test_autotranscript.py b/test_autotranscript.py new file mode 100644 index 0000000..8f745a0 --- /dev/null +++ b/test_autotranscript.py @@ -0,0 +1,120 @@ +import pytest +from autotranscript import Transcriber +from unittest.mock import patch, mock_open +import os + +def test_load_pyannote_model(): + """ + Test load_pyannote_test + """ + from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization + from pyannote.audio import Pipeline + + pipeline = Pipeline.from_pretrained("models/pyannote/speaker_diarization/config.yaml") + assert isinstance(pipeline, SpeakerDiarization) + +# Test Transcribtion class + + +@pytest.fixture +def transcriber(): + """ + Prepare Transcriber for testing + Returns: 
Transcriber Object + """ + + return Transcriber.load_model("medium", local=True) + + +def test_Transcriber_init(transcriber): + """ + Test Transcriber initialization with a whisper model + """ + + assert isinstance(transcriber, Transcriber) + +def test_transcription(transcriber): + """ + Test transcription + """ + + transcript = transcriber.transcribe("tests/test.wav") + assert isinstance(transcript, str) + +def test_save_transcript_to_file(transcriber): + """ + Test save_transcript_to_file + """ + transcript = transcriber.transcribe("tests/test.wav") + + Transcriber.save_transcript(transcript, "tests/output.txt") + + assert os.path.exists("tests/output.txt") + + os.remove("tests/output.txt") + +# Test Diaraization class + +from autotranscript import Diariser + +@pytest.fixture +def diarisation(): + """ + Prepare Diarisation for testing + Returns: Diarisation Object + """ + + return Diariser.load_model("models/pyannote/speaker_diarization/config.yaml", local=True) + +def test_Diarisation_init(diarisation): + """ + Test Diarisation initialization with a pyannote model + """ + + assert isinstance(diarisation, Diariser) + +def test_diarisation(diarisation): + """ + Test diarisation + """ + + diarisation = diarisation.diarization("tests/test.wav") + assert isinstance(diarisation, dict) + +# Test AudioProcessor + +from autotranscript import AudioProcessor , TorchAudioProcessor + + +def test_AudioProcessor_init(): + """ + Test AudioProcessor initialization + """ + audio = AudioProcessor("tests/test.wav") + assert isinstance(audio, AudioProcessor) + +def test_AudioProcessor_convert(): + """ + Test AudioProcessor convert + """ + audio = AudioProcessor("tests/test.wav") + audio.convert_audio("tests/test.mp3", format="mp3") + assert os.path.exists("tests/test.mp3") + +def test_TorchAudioProcessor_from_file(): + """ + Test TorchAudioProcessor initialization + """ + audio = TorchAudioProcessor.from_file("tests/test.wav") + + assert isinstance(audio, TorchAudioProcessor) + + 
os.remove("tests/test.mp3") + + +def test_TorchAudioProcessor_from_ffmpeg(): + """ + Test TorchAudioProcessor initialization + """ + audio = TorchAudioProcessor.from_ffmpeg("tests/test.wav") + assert isinstance(audio, TorchAudioProcessor) diff --git a/transcribe.py b/transcribe.py index e7c62fa..73d8838 100644 --- a/transcribe.py +++ b/transcribe.py @@ -1,3 +1,38 @@ +# import os +# import sys +# import traceback + +# class TracePrints(object): +# def __init__(self): +# self.stdout = sys.stdout +# def write(self, s): +# self.stdout.write("Writing %r\n" % s) +# traceback.print_stack(file=self.stdout) + +# sys.stdout = TracePrints() + +# os.environ["PYANNOTE_CACHE"] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models/pyannote") +# import os + +# os.environ['TRANSFORMERS_CACHE'] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models") +# os.environ['HF_HOME'] = os.path.expanduser("~/PycharmProjects/autotranscript/autotranscript/models") + + from autotranscript import AutoTranscribe -AutoTranscribe(diarisation=True).transcribe() +model = AutoTranscribe() + +text = model.transcribe("test.mp4") + +print("Transcription:\n") +print(text) + + +# from autotranscript.misc import * +# import os + +# print(os.path.exists(CACHE_DIR)) +# print(os.path.exists(WHISPER_DEFAULT_PATH)) +# print(os.path.exists(PYANNOTE_DEFAULT_PATH)) + +# print(os.path.exists(PYANNOTE_DEFAULT_CONFIG))