diff --git a/scraibe/__init__.py b/scraibe/__init__.py index 233cd4f..eb0cc68 100644 --- a/scraibe/__init__.py +++ b/scraibe/__init__.py @@ -8,5 +8,5 @@ from .version import get_version as _get_version from .misc import * from .cli import * - + __version__ = _get_version() diff --git a/scraibe/audio.py b/scraibe/audio.py index 4d457b6..7fbc6fb 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -28,6 +28,7 @@ import torch SAMPLE_RATE = 16000 NORMALIZATION_FACTOR = 32768.0 + class AudioProcessor: """ Audio Processor class that leverages PyTorchaudio to provide functionalities @@ -39,10 +40,9 @@ class AudioProcessor: sr: int The sample rate of the audio. """ - - def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE, + + def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, *args, **kwargs) -> None: - """ Initialize the AudioProcessor object. @@ -56,16 +56,17 @@ class AudioProcessor: Raises: ValueError: If the provided sample rate is not of type int. """ - - device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") - + + device = kwargs.get( + "device", "cuda" if torch.cuda.is_available() else "cpu") + self.waveform = waveform.to(device) self.sr = sr - + if not isinstance(self.sr, int): - raise ValueError("Sample rate should be a single value of type int," \ + raise ValueError("Sample rate should be a single value of type int," f"not {len(self.sr)} and type {type(self.sr)}") - + @classmethod def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor': """ @@ -77,14 +78,13 @@ class AudioProcessor: Returns: AudioProcessor: An instance of the AudioProcessor class containing the loaded audio. """ - - audio, sr = cls.load_audio(file , *args, **kwargs) + + audio, sr = cls.load_audio(file, *args, **kwargs) audio = torch.from_numpy(audio) - + return cls(audio, sr) - - + def cut(self, start: float, end: float) -> torch.Tensor: """ Cut a segment from the audio waveform between the specified start and end times. @@ -96,7 +96,7 @@ class AudioProcessor: Returns: torch.Tensor: The cut waveform segment. """ - + start = int(start * self.sr) if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int): end = int(np.ceil(end * self.sr)) @@ -140,11 +140,13 @@ class AudioProcessor: try: out = run(cmd, capture_output=True, check=True).stdout except CalledProcessError as e: - raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + raise RuntimeError( + f"Failed to load audio: {e.stderr.decode()}") from e + + out = np.frombuffer(out, np.int16).flatten().astype( + np.float32) / NORMALIZATION_FACTOR + + return out, sr - out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR - - return out , sr - def __repr__(self) -> str: - return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' \ No newline at end of file + return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index cf77a62..14d2451 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -55,18 +55,19 @@ class Scraibe: Attributes: transcriber (Transcriber): The transcriber object to handle transcription. diariser (Diariser): The diariser object to handle diarization. - + Methods: __init__: Initializes the Scraibe class with appropriate models. transcribe: Transcribes an audio file using the whisper model and pyannote diarization model. remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. get_audio_file: Gets an audio file as an AudioProcessor object. """ + def __init__(self, - whisper_model: Union[bool, str, whisper] = None, - whisper_type: str = "whisper", - dia_model : Union[bool, str, DiarisationType] = None, - **kwargs) -> None: + whisper_model: Union[bool, str, whisper] = None, + whisper_type: str = "whisper", + dia_model: Union[bool, str, DiarisationType] = None, + **kwargs) -> None: """Initializes the Scraibe class. Args: @@ -84,12 +85,13 @@ class Scraibe: - save_kwargs: If True, the keyword arguments will be saved for autotranscribe. So you can unload the class and reload it again. """ - - + if whisper_model is None: - self.transcriber = Transcriber.load_model("medium", whisper_type, **kwargs) + self.transcriber = Transcriber.load_model( + "medium", whisper_type, **kwargs) elif isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model(whisper_model, whisper_type, **kwargs) + self.transcriber = Transcriber.load_model( + whisper_model, whisper_type, **kwargs) else: self.transcriber = whisper_model @@ -98,26 +100,25 @@ class Scraibe: elif isinstance(dia_model, str): self.diariser = Diariser.load_model(dia_model, **kwargs) else: - self.diariser : Diariser = dia_model + self.diariser: Diariser = dia_model if kwargs.get("verbose"): print("Scraibe initialized all models successfully loaded.") self.verbose = True else: self.verbose = False - + # Save kwargs for autotranscribe if you want to unload the class and load it again. - if kwargs.get('save_setup'): - self.params = dict(whisper_model = whisper_model, - dia_model = dia_model, + if kwargs.get('save_setup'): + self.params = dict(whisper_model=whisper_model, + dia_model=dia_model, **kwargs) else: self.params = {} - - - def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], - remove_original : bool = False, - **kwargs) -> Transcript: + + def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], + remove_original: bool = False, + **kwargs) -> Transcript: """ Transcribes an audio file using the whisper model and pyannote diarization model. @@ -136,60 +137,62 @@ class Scraibe: if kwargs.get("verbose"): self.verbose = kwargs.get("verbose") # Get audio file as an AudioProcessor object - audio_file : AudioProcessor = self.get_audio_file(audio_file) - + audio_file: AudioProcessor = self.get_audio_file(audio_file) + # Prepare waveform and sample rate for diarization dia_audio = { - "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "sample_rate": audio_file.sr - } + } if self.verbose: print("Starting diarisation.") - + diarisation = self.diariser.diarization(dia_audio, **kwargs) - + if not diarisation["segments"]: print("No segments found. Try to run transcription without diarisation.") - - transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) - - final_transcript= {0 : {"speakers" : 'SPEAKER_01', - "segments" : [0, len(audio_file.waveform)], - "text" : transcript}} - + + transcript = self.transcriber.transcribe( + audio_file.waveform, **kwargs) + + final_transcript = {0: {"speakers": 'SPEAKER_01', + "segments": [0, len(audio_file.waveform)], + "text": transcript}} + return Transcript(final_transcript) - + if self.verbose: print("Diarisation finished. Starting transcription.") - - audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) - + + audio_file.sr = torch.Tensor([audio_file.sr]).to( + audio_file.waveform.device) + # Transcribe each segment and store the results final_transcript = dict() - - for i in trange(len(diarisation["segments"]), desc= "Transcribing", disable = not self.verbose): - + + for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose): + seg = diarisation["segments"][i] - + audio = audio_file.cut(seg[0], seg[1]) - + transcript = self.transcriber.transcribe(audio, **kwargs) - - final_transcript[i] = {"speakers" : diarisation["speakers"][i], - "segments" : seg, - "text" : transcript} - - # Remove original file if needed + + final_transcript[i] = {"speakers": diarisation["speakers"][i], + "segments": seg, + "text": transcript} + + # Remove original file if needed if remove_original: if kwargs.get("shred") is True: self.remove_audio_file(audio_file, shred=True) else: self.remove_audio_file(audio_file, shred=False) - + return Transcript(final_transcript) - def diarization(self, audio_file : Union[str, torch.Tensor, ndarray], + def diarization(self, audio_file: Union[str, torch.Tensor, ndarray], **kwargs) -> dict: """ Perform diarization on an audio file using the pyannote diarization model. @@ -204,24 +207,24 @@ class Scraibe: dict: A dictionary containing the results of the diarization process. """ - + # Get audio file as an AudioProcessor object - audio_file : AudioProcessor = self.get_audio_file(audio_file) - + audio_file: AudioProcessor = self.get_audio_file(audio_file) + # Prepare waveform and sample rate for diarization dia_audio = { - "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "sample_rate": audio_file.sr - } - + } + print("Starting diarisation.") - + diarisation = self.diariser.diarization(dia_audio, **kwargs) - + return diarisation - - def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], - **kwargs): + + def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray], + **kwargs): """ Transcribe the provided audio file. @@ -235,11 +238,11 @@ class Scraibe: str: The transcribed text from the audio source. """ - audio_file : AudioProcessor = self.get_audio_file(audio_file) - - return self.transcriber.transcribe(audio_file.waveform, **kwargs) - - def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None: + audio_file: AudioProcessor = self.get_audio_file(audio_file) + + return self.transcriber.transcribe(audio_file.waveform, **kwargs) + + def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None: """ Update the transcriber model. @@ -248,22 +251,23 @@ class Scraibe: The new whisper model to use for transcription. **kwargs: Additional keyword arguments for the transcriber model. - + Returns: None """ _old_model = self.transcriber.model_name - + if isinstance(whisper_model, str): self.transcriber = Transcriber.load_model(whisper_model, **kwargs) elif isinstance(whisper_model, Transcriber): self.transcriber = whisper_model else: - warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) + warn( + f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) return None - def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None: + def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None: """ Update the diariser model. @@ -272,7 +276,7 @@ class Scraibe: The new diariser model to use for diarization. **kwargs: Additional keyword arguments for the diariser model. - + Returns: None """ @@ -281,13 +285,13 @@ class Scraibe: elif isinstance(dia_model, Diariser): self.diariser = dia_model else: - warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) - + warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) + return None - + @staticmethod - def remove_audio_file(audio_file : str, - shred : bool = False) -> None: + def remove_audio_file(audio_file: str, + shred: bool = False) -> None: """ Removes the original audio file to avoid disk space issues or ensure data privacy. @@ -298,30 +302,29 @@ class Scraibe: """ if not os.path.exists(audio_file): raise ValueError(f"Audiofile {audio_file} does not exist.") - + if shred: - + warn("Shredding audiofile can take a long time.", RuntimeWarning) - + gen = iglob(f'{audio_file}', recursive=True) cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}'] - + if os.path.isdir(audio_file): raise ValueError(f"Audiofile {audio_file} is a directory.") - + for file in gen: print(f'shredding {file} now\n') - - run(cmd , check=True) + + run(cmd, check=True) else: os.remove(audio_file) print(f"Audiofile {audio_file} removed.") - - + @staticmethod - def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], + *args, **kwargs) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: @@ -334,20 +337,20 @@ class Scraibe: AudioProcessor: An object containing the waveform and sample rate in torch.Tensor format. """ - + if isinstance(audio_file, str): - audio_file = AudioProcessor.from_file(audio_file) - + audio_file = AudioProcessor.from_file(audio_file) + elif isinstance(audio_file, torch.Tensor): audio_file = AudioProcessor(audio_file[0], audio_file[1]) elif isinstance(audio_file, ndarray): audio_file = AudioProcessor(torch.Tensor(audio_file[0]), - audio_file[1]) - + audio_file[1]) + if not isinstance(audio_file, AudioProcessor): - raise ValueError(f'Audiofile must be of type AudioProcessor,' \ - f'not {type(audio_file)}') - + raise ValueError(f'Audiofile must be of type AudioProcessor,' + f'not {type(audio_file)}') + return audio_file def __repr__(self): diff --git a/scraibe/cli.py b/scraibe/cli.py index 7cc7b1d..b6f2c17 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -4,7 +4,7 @@ allowing for user interaction to transcribe and diarize audio files. The function includes arguments for specifying the audio files, model paths, output formats, and other options necessary for transcription. """ -import os +import os from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json @@ -12,7 +12,7 @@ from .autotranscript import Scraibe from .misc import ParseKwargs -from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE +from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from torch.cuda import is_available from torch import set_num_threads @@ -26,42 +26,43 @@ def cli(): This function can be executed from the command line to perform transcription tasks, providing a user-friendly way to access the Scraibe class functionalities. """ - + def str2bool(string): str2val = {"True": True, "False": False} if string in str2val: return str2val[string] else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + raise ValueError( + f"Expected one of {set(str2val.keys())}, got {string}") - parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter) + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) group = parser.add_mutually_exclusive_group() - - parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None, + + parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None, help="List of audio files to transcribe.") - + group.add_argument('--start-server', action='store_true', - help='Start the Gradio app.' \ - 'If set, all other arguments are ignored' \ - 'besides --server-config or --server-kwargs.') - - parser.add_argument("--server-config", type=str, default= None, + help='Start the Gradio app.' + 'If set, all other arguments are ignored' + 'besides --server-config or --server-kwargs.') + + parser.add_argument("--server-config", type=str, default=None, help="Path to the configy.yml file.") - + parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, help='Keyword arguments for the Gradio app.') - + parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") - parser.add_argument("--whisper-model-directory", type=str, default= None, + parser.add_argument("--whisper-model-directory", type=str, default=None, help="Path to save Whisper model files; defaults to ./models/whisper.") - parser.add_argument("--diarization-directory", type=str, default= None, + parser.add_argument("--diarization-directory", type=str, default=None, help="Path to the diarization model directory.") - parser.add_argument("--hf-token", default= None, type=str, + parser.add_argument("--hf-token", default=None, type=str, help="HuggingFace token for private model download.") parser.add_argument("--inference-device", @@ -82,105 +83,112 @@ def cli(): parser.add_argument("--verbose-output", type=str2bool, default=True, help="Enable or disable progress and debug messages.") - parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code + parser.add_argument("--task", type=str, default='autotranscribe', # unifinished code choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate", 'transcribe'], help="Choose to perform transcription, diarization, or translation. \ If set to translate, the output will be translated to English.") parser.add_argument("--language", type=str, default=None, - choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), + choices=sorted( + LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="Language spoken in the audio. Specify None to perform language detection.") args = parser.parse_args() - + arg_dict = vars(args) - + # configure output out_folder = arg_dict.pop("output_directory") os.makedirs(out_folder, exist_ok=True) out_format = arg_dict.pop("output_format") - - # seup server arg: + + # seup server arg: start_server = arg_dict.pop("start_server") - + task = arg_dict.pop("task") - + if args.num_threads > 0: set_num_threads(arg_dict.pop("num_threads")) - - class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"), + + class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), 'dia_model': arg_dict.pop("diarization_directory"), - 'use_auth_token' : arg_dict.pop("hf_token")} - + 'use_auth_token': arg_dict.pop("hf_token")} + if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") if not start_server: - + model = Scraibe(**class_kwargs) if arg_dict["audio_files"]: audio_files = arg_dict.pop("audio_files") - + if task == "autotranscribe" or task == "autotranscribe+translate": for audio in audio_files: if task == "autotranscribe+translate": task = "translate" else: task = "transcribe" - - out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) + + out = model.autotranscribe(audio, task=task, language=arg_dict.pop( + "language"), verbose=arg_dict.pop("verbose_output")) basename = audio.split("/")[-1].split(".")[0] print(f'Saving {basename}.{out_format} to {out_folder}') - out.save(os.path.join(out_folder, f"{basename}.{out_format}")) - + out.save(os.path.join( + out_folder, f"{basename}.{out_format}")) + elif task == "diarization": for audio in audio_files: if arg_dict.pop("verbose_output"): - print(f"Verbose not implemented for diarization.") - + print("Verbose not implemented for diarization.") + out = model.diarization(audio) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") - + print(f'Saving {basename}.{out_format} to {out_folder}') - + with open(path, "w") as f: - json.dump(json.dumps(out, indent= 1), f) + json.dump(json.dumps(out, indent=1), f) elif task == "transcribe" or task == "translate": - + for audio in audio_files: - - out = model.transcribe(audio, task = task, - language= arg_dict.pop("language"), - verbose = arg_dict.pop("verbose_output")) + + out = model.transcribe(audio, task=task, + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output")) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") with open(path, "w") as f: - f.write(out) - - - else: # unfinished code + f.write(out) + + else: # unfinished code raise NotImplementedError("Currently not Working") import subprocess import sys - - execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py") - + + execute_path = os.path.join( + os.path.dirname(__file__), "app/app_starter.py") + config = arg_dict.pop("server_config") server_kwargs = arg_dict.pop("server_kwargs") - + if not config: - subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"]) + subprocess.run([sys.executable, execute_path, + f"--server-kwargs={server_kwargs}"]) elif not server_kwargs: - subprocess.run([sys.executable, execute_path, f"--server-config={config}"]) + subprocess.run([sys.executable, execute_path, + f"--server-config={config}"]) elif not config and not server_kwargs: subprocess.run([sys.executable, execute_path]) else: - subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) + subprocess.run([sys.executable, execute_path, + f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) + if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index ade9220..d70df99 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -37,15 +37,16 @@ from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device -from torch.cuda import is_available, current_device +from torch.cuda import is_available from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG -Annotation = TypeVar('Annotation') +Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( - os.path.realpath(__file__)), '.pyannotetoken') + os.path.realpath(__file__)), '.pyannotetoken') + class Diariser: """ @@ -55,12 +56,12 @@ class Diariser: Args: model: The pretrained model to use for diarization. """ - + def __init__(self, model) -> None: self.model = model - def diarization(self, audiofile : Union[str, Tensor, dict] , + def diarization(self, audiofile: Union[str, Tensor, dict], *args, **kwargs) -> Annotation: """ Perform speaker diarization on the provided audio file, @@ -79,15 +80,15 @@ class Diariser: to the diarization process. """ kwargs = self._get_diarisation_kwargs(**kwargs) - - diarization = self.model(audiofile,*args, **kwargs) + + diarization = self.model(audiofile, *args, **kwargs) out = self.format_diarization_output(diarization) return out @staticmethod - def format_diarization_output(dia : Annotation) -> dict: + def format_diarization_output(dia: Annotation) -> dict: """ Formats the raw diarization output into a more usable structure for this project. @@ -99,14 +100,14 @@ class Diariser: as keys and a list of tuples representing segments as values. """ - dia_list = list(dia.itertracks(yield_label=True)) + dia_list = list(dia.itertracks(yield_label=True)) diarization_output = {"speakers": [], "segments": []} normalized_output = [] index_start_speaker = 0 index_end_speaker = 0 current_speaker = str() - + ### # Sometimes two consecutive speakers are the same # This loop removes these duplicates @@ -115,40 +116,39 @@ class Diariser: if len(dia_list) == 1: normalized_output.append([0, 0, dia_list[0][2]]) else: - + for i, (_, _, speaker) in enumerate(dia_list): - + if i == 0: current_speaker = speaker - + if speaker != current_speaker: index_end_speaker = i - 1 normalized_output.append([index_start_speaker, - index_end_speaker, - current_speaker]) + index_end_speaker, + current_speaker]) index_start_speaker = i current_speaker = speaker - if i == len(dia_list) - 1: index_end_speaker = i - - normalized_output.append([index_start_speaker, - index_end_speaker, - current_speaker]) - + + normalized_output.append([index_start_speaker, + index_end_speaker, + current_speaker]) + for outp in normalized_output: - start = dia_list[outp[0]][0].start - end = dia_list[outp[1]][0].end + start = dia_list[outp[0]][0].start + end = dia_list[outp[1]][0].end diarization_output["segments"].append([start, end]) diarization_output["speakers"].append(outp[2]) return diarization_output - + @staticmethod def _get_token(): """ @@ -161,14 +161,14 @@ class Diariser: Returns: str: The Huggingface token. """ - + if os.path.exists(TOKEN_PATH): with open(TOKEN_PATH, 'r', encoding="utf-8") as file: token = file.read() else: - raise ValueError('No token found.' \ - 'Please create a token at https://huggingface.co/settings/token' \ - f'and save it in a file called {TOKEN_PATH}') + raise ValueError('No token found.' + 'Please create a token at https://huggingface.co/settings/token' + f'and save it in a file called {TOKEN_PATH}') return token @staticmethod @@ -182,18 +182,17 @@ class Diariser: """ with open(TOKEN_PATH, 'w', encoding="utf-8") as file: file.write(token) - + @classmethod - def load_model(cls, - model: str = PYANNOTE_DEFAULT_CONFIG, - use_auth_token: str = None, - cache_token: bool = False, - cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, - hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs - ) -> Pipeline: - + def load_model(cls, + model: str = PYANNOTE_DEFAULT_CONFIG, + use_auth_token: str = None, + cache_token: bool = False, + cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, + hparams_file: Union[str, Path] = None, + device: str = None, + *args, **kwargs + ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, either from a local cache or some online repository. @@ -237,16 +236,18 @@ class Diariser: 'deprecated and will be removed in future versions.', category=DeprecationWarning) # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + bin_files = [f for f in os.listdir( + pwd) if f.endswith(".bin")] if len(bin_files) == 1: path_to_model = os.path.join(pwd, bin_files[0]) else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") + warnings.warn("Found more than one .bin file. " + "or none. Please specify the path to the model " + "or setup a huggingface token.") raise FileNotFoundError - warnings.warn(f"Found model at {path_to_model} overwriting config file.") + warnings.warn( + f"Found model at {path_to_model} overwriting config file.") config['pipeline']['params']['segmentation'] = path_to_model @@ -270,22 +271,24 @@ class Diariser: if use_auth_token is None: use_auth_token = cls._get_token() else: - raise FileNotFoundError(f'No local model or directory found at {model}.') + raise FileNotFoundError( + f'No local model or directory found at {model}.') _model = Pipeline.from_pretrained(model, use_auth_token=use_auth_token, cache_dir=cache_dir, hparams_file=hparams_file,) if _model is None: - raise ValueError('Unable to load model either from local cache' \ - 'or from huggingface.co models. Please check your token' \ - 'or your local model path') + raise ValueError('Unable to load model either from local cache' + 'or from huggingface.co models. Please check your token' + 'or your local model path') # try to move the model to the device if device is None: device = "cuda" if is_available() else "cpu" - _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + # torch_device is renamed from torch.device to avoid name conflict + _model = _model.to(torch_device(device)) return cls(_model) @@ -302,9 +305,10 @@ class Diariser: """ _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames - diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} - + diarisation_kwargs = {k: v for k, + v in kwargs.items() if k in _possible_kwargs} + return diarisation_kwargs - + def __repr__(self): return f"Diarisation(model={self.model})" diff --git a/scraibe/hallucinations.py b/scraibe/hallucinations.py index a337ec0..249cce5 100644 --- a/scraibe/hallucinations.py +++ b/scraibe/hallucinations.py @@ -1,6 +1,6 @@ # List of known hallucinations - adapted from: # https://github.com/openai/whisper/discussions/928 -KNOWN_HALLUCINATIONS=[ +KNOWN_HALLUCINATIONS = [ # en " www.mooji.org" # nl @@ -73,7 +73,7 @@ KNOWN_HALLUCINATIONS=[ " Sous-titres réalisés para la communauté d'Amara.org" # ln " Sous-titres réalisés para la communauté d'Amara.org" - # pl + # pl " Napisy stworzone przez społeczność Amara.org", " Napisy wykonane przez społeczność Amara.org", " Zdjęcia i napisy stworzone przez społeczność Amara.org", @@ -92,4 +92,4 @@ KNOWN_HALLUCINATIONS=[ # zh "字幕由Amara.org社区提供", "小編字幕由Amara.org社區提供" -] \ No newline at end of file +] diff --git a/scraibe/misc.py b/scraibe/misc.py index c1d5484..e8837d3 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -14,8 +14,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR: WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ - if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ + else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. @@ -33,25 +34,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> with open(file_path, "r") as stream: yml = yaml.safe_load(stream) - segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") + segmentation_path = path_to_segmentation or os.path.join( + PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") yml["pipeline"]["params"]["segmentation"] = segmentation_path if not os.path.exists(segmentation_path): - raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}") + raise FileNotFoundError( + f"Segmentation model not found at {segmentation_path}") with open(file_path, "w") as stream: yaml.dump(yml, stream) + class ParseKwargs(Action): """ Custom argparse action to parse keyword arguments. """ + def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for value in values: key, value = value.split('=') try: - value = eval(value) + value = ast.literal_eval(value) except: pass - getattr(namespace, self.dest)[key] = value \ No newline at end of file + getattr(namespace, self.dest)[key] = value diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 977cd94..8802cf6 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -28,11 +28,11 @@ from whisper import Whisper from whisper import load_model as whisper_load_model from whisperx.asr import WhisperModel from whisperx import load_model as whisperx_load_model -from typing import TypeVar , Union , Optional +from typing import TypeVar, Union, Optional from torch import Tensor, device from numpy import ndarray from inspect import getfullargspec -from abc import ABC, abstractmethod +from abc import abstractmethod from .misc import WHISPER_DEFAULT_PATH whisper = TypeVar('whisper') @@ -66,6 +66,7 @@ class Transcriber: The class supports various sizes and versions of Whisper models. Please refer to the load_model method for available options. """ + def __init__(self, model: whisper, model_name: str) -> None: """ Initialize the Transcriber class with a Whisper model. @@ -74,13 +75,13 @@ class Transcriber: model (whisper): The Whisper model to use for transcription. model_name (str): The name of the model. """ - + self.model = model - + self.model_name = model_name @abstractmethod - def transcribe(self, audio: Union[str, Tensor, ndarray] , + def transcribe(self, audio: Union[str, Tensor, ndarray], *args, **kwargs) -> str: """ Transcribe an audio file. @@ -95,9 +96,9 @@ class Transcriber: str: The transcript as a string. """ pass - + @staticmethod - def save_transcript(transcript : str , save_path : str) -> None: + def save_transcript(transcript: str, save_path: str) -> None: """ Save a transcript to a file. @@ -111,7 +112,7 @@ class Transcriber: with open(save_path, 'w') as f: f.write(transcript) - + print(f'Transcript saved to {save_path}') @classmethod @@ -176,10 +177,10 @@ class Transcriber: dict: Keyword arguments for whisper model. """ pass - + def __repr__(self) -> str: return f"Transcriber(model_name={self.model_name}, model={self.model})" - + class WhisperTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: @@ -233,10 +234,10 @@ class WhisperTranscriber(Transcriber): - 'large-v2' - 'large-v3' - 'large' - + download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - + device (Optional[Union[str, torch.device]], optional): Device to load model on. Defaults to None. in_memory (bool, optional): Whether to load model in memory. @@ -266,7 +267,8 @@ class WhisperTranscriber(Transcriber): _kwargs = getfullargspec(Whisper.transcribe).kwonlyargs _possible_kwargs = _args + _kwargs - whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} + whisper_kwargs = {k: v for k, + v in kwargs.items() if k in _possible_kwargs} if (task := kwargs.get("task")): whisper_kwargs["task"] = task @@ -280,7 +282,7 @@ class WhisperTranscriber(Transcriber): class WhisperXTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: super().__init__(model, model_name) - + def transcribe(self, audio: Union[str, Tensor, ndarray], *args, **kwargs) -> str: """ @@ -296,7 +298,7 @@ class WhisperXTranscriber(Transcriber): str: The transcript as a string. """ kwargs = self._get_whisper_kwargs(**kwargs) - + if isinstance(audio, Tensor): audio = audio.cpu().numpy() result = self.model.transcribe(audio, *args, **kwargs) @@ -304,8 +306,7 @@ class WhisperXTranscriber(Transcriber): for seg in result['segments']: text += seg['text'] return text - - + @classmethod def load_model(cls, model: str = "medium", @@ -330,10 +331,10 @@ class WhisperXTranscriber(Transcriber): - 'large-v2' - 'large-v3' - 'large' - + download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - + device (Optional[Union[str, torch.device]], optional): Device to load model on. Defaults to None. in_memory (bool, optional): Whether to load model in memory. @@ -364,7 +365,8 @@ class WhisperXTranscriber(Transcriber): _kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs _possible_kwargs = _args + _kwargs - whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} + whisper_kwargs = {k: v for k, + v in kwargs.items() if k in _possible_kwargs} if (task := kwargs.get("task")): whisper_kwargs["task"] = task diff --git a/scraibe/transcript_exporter.py b/scraibe/transcript_exporter.py index 1ce43d4..5222d58 100644 --- a/scraibe/transcript_exporter.py +++ b/scraibe/transcript_exporter.py @@ -1,5 +1,6 @@ import json import time +from json.decoder import JSONDecodeError from typing import Union @@ -8,13 +9,12 @@ from .hallucinations import KNOWN_HALLUCINATIONS ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] - class Transcript: """ Class for storing transcript data, including speaker information and text segments, and exporting it to various file formats such as JSON, HTML, and LaTeX. """ - + def __init__(self, transcript: dict) -> None: """ Initializes the Transcript object with the given transcript data. @@ -30,7 +30,7 @@ class Transcript: self.speakers = self._extract_speakers() self.segments = self._extract_segments() self.annotation = {} - + def annotate(self, *args, **kwargs) -> dict: """ Annotates the transcript to associate specific names with speakers. @@ -46,36 +46,41 @@ class Transcript: ValueError: If the number of speaker names does not match the number of speakers, or if an unknown speaker is found. """ - + annotations = {} if args and len(args) != len(self.speakers): - raise ValueError("Number of speaker names does not match number of speakers") - + raise ValueError( + "Number of speaker names does not match number of speakers") + if args: for arg, speaker in zip(args, sorted(self.speakers)): - + annotations[speaker] = arg - + invalid_speakers = set(kwargs.keys()) - set(self.speakers) if invalid_speakers: - raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}") + raise ValueError( + f"These keys are not speakers: {', '.join(invalid_speakers)}") - annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) + annotations.update({key: kwargs[key] + for key in self.speakers if key in kwargs}) self.annotation = annotations - + return self - + def _remove_hallucinations(self) -> None: """ Removes all occurances of known hallucinations from all segments of the transcript. Segments that are identical to empty strings afterwards are removed from the transcript. """ - segments_to_drop=[] + segments_to_drop = [] for id in self.transcript: for snippet in KNOWN_HALLUCINATIONS: - self.transcript[id]['text']=self.transcript[id]['text'].replace(snippet,'') - if self.transcript[id]['text'] == '': segments_to_drop.append(id) + self.transcript[id]['text'] = self.transcript[id]['text'].replace( + snippet, '') + if self.transcript[id]['text'] == '': + segments_to_drop.append(id) for id in segments_to_drop: del self.transcript[id] @@ -87,9 +92,9 @@ class Transcript: Returns: list: List of unique speaker names in the transcript. """ - + return list(set([self.transcript[id]["speakers"] for id in self.transcript])) - + def _extract_segments(self) -> list: """ Extracts all the text segments from the transcript. @@ -109,23 +114,23 @@ class Transcript: time stamps for each segment. """ fstring = "" - + for _id in self.transcript: seq = self.transcript[_id] - + if self.annotation: speaker = self.annotation[seq["speakers"]] else: speaker = seq["speakers"] - + segm = seq["segments"] - sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0])) - eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) - + sseg = time.strftime("%H:%M:%S", time.gmtime(segm[0])) + eseg = time.strftime("%H:%M:%S", time.gmtime(segm[1])) + fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n" - + return fstring - + def __repr__(self) -> str: """Return a string representation of the Transcript object. @@ -133,8 +138,8 @@ class Transcript: str: A string that provides an informative description of the object. """ return f"Transcript(speakers = {self.speakers},"\ - f"segments = {self.segments}, annotation = {self.annotation})" - + f"segments = {self.segments}, annotation = {self.annotation})" + def get_dict(self) -> dict: """ Get transcript as dict @@ -142,10 +147,10 @@ class Transcript: :return: transcript as dict :rtype: dict """ - + return self.transcript - - def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str: + + def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str: """ Get transcript as json string :return: transcript as json string @@ -153,14 +158,14 @@ class Transcript: """ if "indent" not in kwargs: kwargs["indent"] = 3 - + if use_annotation and self.annotation: for _id in self.transcript: seq = self.transcript[_id] seq["speakers"] = self.annotation[seq["speakers"]] - + return json.dumps(self.transcript, *args, **kwargs) - + def get_html(self) -> str: """ Get transcript as html string @@ -171,9 +176,9 @@ class Transcript: html = "
" + self.__str__().replace("\n", "
") + "