diff --git a/requirements.txt b/requirements.txt index 5872774..d1bdccc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ tqdm>=4.65.0 numpy>=1.26.4 openai-whisper==20231117 +whisperx~=3.1.3 pyannote.audio~=3.1.1 pyannote.core~=5.0.0 diff --git a/scraibe/__init__.py b/scraibe/__init__.py index 4338879..399023a 100644 --- a/scraibe/__init__.py +++ b/scraibe/__init__.py @@ -8,5 +8,4 @@ from .misc import * from .cli import * -from ._version import __version__ - +from ._version import __version__ diff --git a/scraibe/audio.py b/scraibe/audio.py index 4d457b6..7fbc6fb 100644 --- a/scraibe/audio.py +++ b/scraibe/audio.py @@ -28,6 +28,7 @@ import torch SAMPLE_RATE = 16000 NORMALIZATION_FACTOR = 32768.0 + class AudioProcessor: """ Audio Processor class that leverages PyTorchaudio to provide functionalities @@ -39,10 +40,9 @@ class AudioProcessor: sr: int The sample rate of the audio. """ - - def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE, + + def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, *args, **kwargs) -> None: - """ Initialize the AudioProcessor object. @@ -56,16 +56,17 @@ class AudioProcessor: Raises: ValueError: If the provided sample rate is not of type int. """ - - device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") - + + device = kwargs.get( + "device", "cuda" if torch.cuda.is_available() else "cpu") + self.waveform = waveform.to(device) self.sr = sr - + if not isinstance(self.sr, int): - raise ValueError("Sample rate should be a single value of type int," \ + raise ValueError("Sample rate should be a single value of type int," f"not {len(self.sr)} and type {type(self.sr)}") - + @classmethod def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor': """ @@ -77,14 +78,13 @@ class AudioProcessor: Returns: AudioProcessor: An instance of the AudioProcessor class containing the loaded audio. """ - - audio, sr = cls.load_audio(file , *args, **kwargs) + + audio, sr = cls.load_audio(file, *args, **kwargs) audio = torch.from_numpy(audio) - + return cls(audio, sr) - - + def cut(self, start: float, end: float) -> torch.Tensor: """ Cut a segment from the audio waveform between the specified start and end times. @@ -96,7 +96,7 @@ class AudioProcessor: Returns: torch.Tensor: The cut waveform segment. """ - + start = int(start * self.sr) if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int): end = int(np.ceil(end * self.sr)) @@ -140,11 +140,13 @@ class AudioProcessor: try: out = run(cmd, capture_output=True, check=True).stdout except CalledProcessError as e: - raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + raise RuntimeError( + f"Failed to load audio: {e.stderr.decode()}") from e + + out = np.frombuffer(out, np.int16).flatten().astype( + np.float32) / NORMALIZATION_FACTOR + + return out, sr - out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR - - return out , sr - def __repr__(self) -> str: - return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' \ No newline at end of file + return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 7d54ba8..7391f1a 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -38,7 +38,7 @@ from tqdm import trange # Application-Specific Imports from .audio import AudioProcessor from .diarisation import Diariser -from .transcriber import Transcriber, whisper +from .transcriber import Transcriber, load_transcriber, whisper from .transcript_exporter import Transcript @@ -55,22 +55,26 @@ class Scraibe: Attributes: transcriber (Transcriber): The transcriber object to handle transcription. diariser (Diariser): The diariser object to handle diarization. - + Methods: __init__: Initializes the Scraibe class with appropriate models. transcribe: Transcribes an audio file using the whisper model and pyannote diarization model. remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. get_audio_file: Gets an audio file as an AudioProcessor object. """ + def __init__(self, - whisper_model: Union[bool, str, whisper] = None, - dia_model : Union[bool, str, DiarisationType] = None, - **kwargs) -> None: + whisper_model: Union[bool, str, whisper] = None, + whisper_type: str = "whisper", + dia_model: Union[bool, str, DiarisationType] = None, + **kwargs) -> None: """Initializes the Scraibe class. Args: whisper_model (Union[bool, str, whisper], optional): Path to whisper model or whisper model itself. + whisper_type (str): + Type of whisper model to load. "whisper" or "whisperx". diarisation_model (Union[bool, str, DiarisationType], optional): Path to pyannote diarization model or model itself. **kwargs: Additional keyword arguments for whisper @@ -81,12 +85,13 @@ class Scraibe: - save_kwargs: If True, the keyword arguments will be saved for autotranscribe. So you can unload the class and reload it again. """ - - + if whisper_model is None: - self.transcriber = Transcriber.load_model("medium", **kwargs) + self.transcriber = load_transcriber( + "medium", whisper_type, **kwargs) elif isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model(whisper_model, **kwargs) + self.transcriber = load_transcriber( + whisper_model, whisper_type, **kwargs) else: self.transcriber = whisper_model @@ -95,26 +100,25 @@ class Scraibe: elif isinstance(dia_model, str): self.diariser = Diariser.load_model(dia_model, **kwargs) else: - self.diariser : Diariser = dia_model + self.diariser: Diariser = dia_model if kwargs.get("verbose"): print("Scraibe initialized all models successfully loaded.") self.verbose = True else: self.verbose = False - + # Save kwargs for autotranscribe if you want to unload the class and load it again. - if kwargs.get('save_setup'): - self.params = dict(whisper_model = whisper_model, - dia_model = dia_model, + if kwargs.get('save_setup'): + self.params = dict(whisper_model=whisper_model, + dia_model=dia_model, **kwargs) else: self.params = {} - - - def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], - remove_original : bool = False, - **kwargs) -> Transcript: + + def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], + remove_original: bool = False, + **kwargs) -> Transcript: """ Transcribes an audio file using the whisper model and pyannote diarization model. @@ -133,60 +137,62 @@ class Scraibe: if kwargs.get("verbose"): self.verbose = kwargs.get("verbose") # Get audio file as an AudioProcessor object - audio_file : AudioProcessor = self.get_audio_file(audio_file) - + audio_file: AudioProcessor = self.get_audio_file(audio_file) + # Prepare waveform and sample rate for diarization dia_audio = { - "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "sample_rate": audio_file.sr - } + } if self.verbose: print("Starting diarisation.") - + diarisation = self.diariser.diarization(dia_audio, **kwargs) - + if not diarisation["segments"]: print("No segments found. Try to run transcription without diarisation.") - - transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) - - final_transcript= {0 : {"speakers" : 'SPEAKER_01', - "segments" : [0, len(audio_file.waveform)], - "text" : transcript}} - + + transcript = self.transcriber.transcribe( + audio_file.waveform, **kwargs) + + final_transcript = {0: {"speakers": 'SPEAKER_01', + "segments": [0, len(audio_file.waveform)], + "text": transcript}} + return Transcript(final_transcript) - + if self.verbose: print("Diarisation finished. Starting transcription.") - - audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) - + + audio_file.sr = torch.Tensor([audio_file.sr]).to( + audio_file.waveform.device) + # Transcribe each segment and store the results final_transcript = dict() - - for i in trange(len(diarisation["segments"]), desc= "Transcribing", disable = not self.verbose): - + + for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose): + seg = diarisation["segments"][i] - + audio = audio_file.cut(seg[0], seg[1]) - + transcript = self.transcriber.transcribe(audio, **kwargs) - - final_transcript[i] = {"speakers" : diarisation["speakers"][i], - "segments" : seg, - "text" : transcript} - - # Remove original file if needed + + final_transcript[i] = {"speakers": diarisation["speakers"][i], + "segments": seg, + "text": transcript} + + # Remove original file if needed if remove_original: if kwargs.get("shred") is True: self.remove_audio_file(audio_file, shred=True) else: self.remove_audio_file(audio_file, shred=False) - + return Transcript(final_transcript) - def diarization(self, audio_file : Union[str, torch.Tensor, ndarray], + def diarization(self, audio_file: Union[str, torch.Tensor, ndarray], **kwargs) -> dict: """ Perform diarization on an audio file using the pyannote diarization model. @@ -201,24 +207,24 @@ class Scraibe: dict: A dictionary containing the results of the diarization process. """ - + # Get audio file as an AudioProcessor object - audio_file : AudioProcessor = self.get_audio_file(audio_file) - + audio_file: AudioProcessor = self.get_audio_file(audio_file) + # Prepare waveform and sample rate for diarization dia_audio = { - "waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), + "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)), "sample_rate": audio_file.sr - } - + } + print("Starting diarisation.") - + diarisation = self.diariser.diarization(dia_audio, **kwargs) - + return diarisation - - def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], - **kwargs): + + def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray], + **kwargs): """ Transcribe the provided audio file. @@ -232,11 +238,11 @@ class Scraibe: str: The transcribed text from the audio source. """ - audio_file : AudioProcessor = self.get_audio_file(audio_file) - - return self.transcriber.transcribe(audio_file.waveform, **kwargs) - - def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None: + audio_file: AudioProcessor = self.get_audio_file(audio_file) + + return self.transcriber.transcribe(audio_file.waveform, **kwargs) + + def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None: """ Update the transcriber model. @@ -245,22 +251,23 @@ class Scraibe: The new whisper model to use for transcription. **kwargs: Additional keyword arguments for the transcriber model. - + Returns: None """ _old_model = self.transcriber.model_name - + if isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model(whisper_model, **kwargs) + self.transcriber = load_transcriber(whisper_model, **kwargs) elif isinstance(whisper_model, Transcriber): self.transcriber = whisper_model else: - warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) + warn( + f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) return None - def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None: + def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None: """ Update the diariser model. @@ -269,7 +276,7 @@ class Scraibe: The new diariser model to use for diarization. **kwargs: Additional keyword arguments for the diariser model. - + Returns: None """ @@ -278,13 +285,13 @@ class Scraibe: elif isinstance(dia_model, Diariser): self.diariser = dia_model else: - warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) - + warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) + return None - + @staticmethod - def remove_audio_file(audio_file : str, - shred : bool = False) -> None: + def remove_audio_file(audio_file: str, + shred: bool = False) -> None: """ Removes the original audio file to avoid disk space issues or ensure data privacy. @@ -295,30 +302,29 @@ class Scraibe: """ if not os.path.exists(audio_file): raise ValueError(f"Audiofile {audio_file} does not exist.") - + if shred: - + warn("Shredding audiofile can take a long time.", RuntimeWarning) - + gen = iglob(f'{audio_file}', recursive=True) cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}'] - + if os.path.isdir(audio_file): raise ValueError(f"Audiofile {audio_file} is a directory.") - + for file in gen: print(f'shredding {file} now\n') - - run(cmd , check=True) + + run(cmd, check=True) else: os.remove(audio_file) print(f"Audiofile {audio_file} removed.") - - + @staticmethod - def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray], - *args, **kwargs) -> AudioProcessor: + def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], + *args, **kwargs) -> AudioProcessor: """Gets an audio file as TorchAudioProcessor. Args: @@ -331,20 +337,20 @@ class Scraibe: AudioProcessor: An object containing the waveform and sample rate in torch.Tensor format. """ - + if isinstance(audio_file, str): - audio_file = AudioProcessor.from_file(audio_file) - + audio_file = AudioProcessor.from_file(audio_file) + elif isinstance(audio_file, torch.Tensor): audio_file = AudioProcessor(audio_file[0], audio_file[1]) elif isinstance(audio_file, ndarray): audio_file = AudioProcessor(torch.Tensor(audio_file[0]), - audio_file[1]) - + audio_file[1]) + if not isinstance(audio_file, AudioProcessor): - raise ValueError(f'Audiofile must be of type AudioProcessor,' \ - f'not {type(audio_file)}') - + raise ValueError(f'Audiofile must be of type AudioProcessor,' + f'not {type(audio_file)}') + return audio_file def __repr__(self): diff --git a/scraibe/cli.py b/scraibe/cli.py index 7cc7b1d..b6f2c17 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -4,7 +4,7 @@ allowing for user interaction to transcribe and diarize audio files. The function includes arguments for specifying the audio files, model paths, output formats, and other options necessary for transcription. """ -import os +import os from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import json @@ -12,7 +12,7 @@ from .autotranscript import Scraibe from .misc import ParseKwargs -from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE +from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE from torch.cuda import is_available from torch import set_num_threads @@ -26,42 +26,43 @@ def cli(): This function can be executed from the command line to perform transcription tasks, providing a user-friendly way to access the Scraibe class functionalities. """ - + def str2bool(string): str2val = {"True": True, "False": False} if string in str2val: return str2val[string] else: - raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + raise ValueError( + f"Expected one of {set(str2val.keys())}, got {string}") - parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter) + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) group = parser.add_mutually_exclusive_group() - - parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None, + + parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None, help="List of audio files to transcribe.") - + group.add_argument('--start-server', action='store_true', - help='Start the Gradio app.' \ - 'If set, all other arguments are ignored' \ - 'besides --server-config or --server-kwargs.') - - parser.add_argument("--server-config", type=str, default= None, + help='Start the Gradio app.' + 'If set, all other arguments are ignored' + 'besides --server-config or --server-kwargs.') + + parser.add_argument("--server-config", type=str, default=None, help="Path to the configy.yml file.") - + parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, help='Keyword arguments for the Gradio app.') - + parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") - parser.add_argument("--whisper-model-directory", type=str, default= None, + parser.add_argument("--whisper-model-directory", type=str, default=None, help="Path to save Whisper model files; defaults to ./models/whisper.") - parser.add_argument("--diarization-directory", type=str, default= None, + parser.add_argument("--diarization-directory", type=str, default=None, help="Path to the diarization model directory.") - parser.add_argument("--hf-token", default= None, type=str, + parser.add_argument("--hf-token", default=None, type=str, help="HuggingFace token for private model download.") parser.add_argument("--inference-device", @@ -82,105 +83,112 @@ def cli(): parser.add_argument("--verbose-output", type=str2bool, default=True, help="Enable or disable progress and debug messages.") - parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code + parser.add_argument("--task", type=str, default='autotranscribe', # unifinished code choices=["autotranscribe", "diarization", "autotranscribe+translate", "translate", 'transcribe'], help="Choose to perform transcription, diarization, or translation. \ If set to translate, the output will be translated to English.") parser.add_argument("--language", type=str, default=None, - choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), + choices=sorted( + LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="Language spoken in the audio. Specify None to perform language detection.") args = parser.parse_args() - + arg_dict = vars(args) - + # configure output out_folder = arg_dict.pop("output_directory") os.makedirs(out_folder, exist_ok=True) out_format = arg_dict.pop("output_format") - - # seup server arg: + + # seup server arg: start_server = arg_dict.pop("start_server") - + task = arg_dict.pop("task") - + if args.num_threads > 0: set_num_threads(arg_dict.pop("num_threads")) - - class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"), + + class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"), 'dia_model': arg_dict.pop("diarization_directory"), - 'use_auth_token' : arg_dict.pop("hf_token")} - + 'use_auth_token': arg_dict.pop("hf_token")} + if arg_dict["whisper_model_directory"]: class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") if not start_server: - + model = Scraibe(**class_kwargs) if arg_dict["audio_files"]: audio_files = arg_dict.pop("audio_files") - + if task == "autotranscribe" or task == "autotranscribe+translate": for audio in audio_files: if task == "autotranscribe+translate": task = "translate" else: task = "transcribe" - - out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) + + out = model.autotranscribe(audio, task=task, language=arg_dict.pop( + "language"), verbose=arg_dict.pop("verbose_output")) basename = audio.split("/")[-1].split(".")[0] print(f'Saving {basename}.{out_format} to {out_folder}') - out.save(os.path.join(out_folder, f"{basename}.{out_format}")) - + out.save(os.path.join( + out_folder, f"{basename}.{out_format}")) + elif task == "diarization": for audio in audio_files: if arg_dict.pop("verbose_output"): - print(f"Verbose not implemented for diarization.") - + print("Verbose not implemented for diarization.") + out = model.diarization(audio) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") - + print(f'Saving {basename}.{out_format} to {out_folder}') - + with open(path, "w") as f: - json.dump(json.dumps(out, indent= 1), f) + json.dump(json.dumps(out, indent=1), f) elif task == "transcribe" or task == "translate": - + for audio in audio_files: - - out = model.transcribe(audio, task = task, - language= arg_dict.pop("language"), - verbose = arg_dict.pop("verbose_output")) + + out = model.transcribe(audio, task=task, + language=arg_dict.pop("language"), + verbose=arg_dict.pop("verbose_output")) basename = audio.split("/")[-1].split(".")[0] path = os.path.join(out_folder, f"{basename}.{out_format}") with open(path, "w") as f: - f.write(out) - - - else: # unfinished code + f.write(out) + + else: # unfinished code raise NotImplementedError("Currently not Working") import subprocess import sys - - execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py") - + + execute_path = os.path.join( + os.path.dirname(__file__), "app/app_starter.py") + config = arg_dict.pop("server_config") server_kwargs = arg_dict.pop("server_kwargs") - + if not config: - subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"]) + subprocess.run([sys.executable, execute_path, + f"--server-kwargs={server_kwargs}"]) elif not server_kwargs: - subprocess.run([sys.executable, execute_path, f"--server-config={config}"]) + subprocess.run([sys.executable, execute_path, + f"--server-config={config}"]) elif not config and not server_kwargs: subprocess.run([sys.executable, execute_path]) else: - subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) + subprocess.run([sys.executable, execute_path, + f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) + if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index ade9220..d70df99 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -37,15 +37,16 @@ from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor from torch import device as torch_device -from torch.cuda import is_available, current_device +from torch.cuda import is_available from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG -Annotation = TypeVar('Annotation') +Annotation = TypeVar('Annotation') TOKEN_PATH = os.path.join(os.path.dirname( - os.path.realpath(__file__)), '.pyannotetoken') + os.path.realpath(__file__)), '.pyannotetoken') + class Diariser: """ @@ -55,12 +56,12 @@ class Diariser: Args: model: The pretrained model to use for diarization. """ - + def __init__(self, model) -> None: self.model = model - def diarization(self, audiofile : Union[str, Tensor, dict] , + def diarization(self, audiofile: Union[str, Tensor, dict], *args, **kwargs) -> Annotation: """ Perform speaker diarization on the provided audio file, @@ -79,15 +80,15 @@ class Diariser: to the diarization process. """ kwargs = self._get_diarisation_kwargs(**kwargs) - - diarization = self.model(audiofile,*args, **kwargs) + + diarization = self.model(audiofile, *args, **kwargs) out = self.format_diarization_output(diarization) return out @staticmethod - def format_diarization_output(dia : Annotation) -> dict: + def format_diarization_output(dia: Annotation) -> dict: """ Formats the raw diarization output into a more usable structure for this project. @@ -99,14 +100,14 @@ class Diariser: as keys and a list of tuples representing segments as values. """ - dia_list = list(dia.itertracks(yield_label=True)) + dia_list = list(dia.itertracks(yield_label=True)) diarization_output = {"speakers": [], "segments": []} normalized_output = [] index_start_speaker = 0 index_end_speaker = 0 current_speaker = str() - + ### # Sometimes two consecutive speakers are the same # This loop removes these duplicates @@ -115,40 +116,39 @@ class Diariser: if len(dia_list) == 1: normalized_output.append([0, 0, dia_list[0][2]]) else: - + for i, (_, _, speaker) in enumerate(dia_list): - + if i == 0: current_speaker = speaker - + if speaker != current_speaker: index_end_speaker = i - 1 normalized_output.append([index_start_speaker, - index_end_speaker, - current_speaker]) + index_end_speaker, + current_speaker]) index_start_speaker = i current_speaker = speaker - if i == len(dia_list) - 1: index_end_speaker = i - - normalized_output.append([index_start_speaker, - index_end_speaker, - current_speaker]) - + + normalized_output.append([index_start_speaker, + index_end_speaker, + current_speaker]) + for outp in normalized_output: - start = dia_list[outp[0]][0].start - end = dia_list[outp[1]][0].end + start = dia_list[outp[0]][0].start + end = dia_list[outp[1]][0].end diarization_output["segments"].append([start, end]) diarization_output["speakers"].append(outp[2]) return diarization_output - + @staticmethod def _get_token(): """ @@ -161,14 +161,14 @@ class Diariser: Returns: str: The Huggingface token. """ - + if os.path.exists(TOKEN_PATH): with open(TOKEN_PATH, 'r', encoding="utf-8") as file: token = file.read() else: - raise ValueError('No token found.' \ - 'Please create a token at https://huggingface.co/settings/token' \ - f'and save it in a file called {TOKEN_PATH}') + raise ValueError('No token found.' + 'Please create a token at https://huggingface.co/settings/token' + f'and save it in a file called {TOKEN_PATH}') return token @staticmethod @@ -182,18 +182,17 @@ class Diariser: """ with open(TOKEN_PATH, 'w', encoding="utf-8") as file: file.write(token) - + @classmethod - def load_model(cls, - model: str = PYANNOTE_DEFAULT_CONFIG, - use_auth_token: str = None, - cache_token: bool = False, - cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, - hparams_file: Union[str, Path] = None, - device: str = None, - *args, **kwargs - ) -> Pipeline: - + def load_model(cls, + model: str = PYANNOTE_DEFAULT_CONFIG, + use_auth_token: str = None, + cache_token: bool = False, + cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, + hparams_file: Union[str, Path] = None, + device: str = None, + *args, **kwargs + ) -> Pipeline: """ Loads a pretrained model from pyannote.audio, either from a local cache or some online repository. @@ -237,16 +236,18 @@ class Diariser: 'deprecated and will be removed in future versions.', category=DeprecationWarning) # list elementes with the ending .bin - bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] + bin_files = [f for f in os.listdir( + pwd) if f.endswith(".bin")] if len(bin_files) == 1: path_to_model = os.path.join(pwd, bin_files[0]) else: - warnings.warn("Found more than one .bin file. "\ - "or none. Please specify the path to the model " \ - "or setup a huggingface token.") + warnings.warn("Found more than one .bin file. " + "or none. Please specify the path to the model " + "or setup a huggingface token.") raise FileNotFoundError - warnings.warn(f"Found model at {path_to_model} overwriting config file.") + warnings.warn( + f"Found model at {path_to_model} overwriting config file.") config['pipeline']['params']['segmentation'] = path_to_model @@ -270,22 +271,24 @@ class Diariser: if use_auth_token is None: use_auth_token = cls._get_token() else: - raise FileNotFoundError(f'No local model or directory found at {model}.') + raise FileNotFoundError( + f'No local model or directory found at {model}.') _model = Pipeline.from_pretrained(model, use_auth_token=use_auth_token, cache_dir=cache_dir, hparams_file=hparams_file,) if _model is None: - raise ValueError('Unable to load model either from local cache' \ - 'or from huggingface.co models. Please check your token' \ - 'or your local model path') + raise ValueError('Unable to load model either from local cache' + 'or from huggingface.co models. Please check your token' + 'or your local model path') # try to move the model to the device if device is None: device = "cuda" if is_available() else "cpu" - _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + # torch_device is renamed from torch.device to avoid name conflict + _model = _model.to(torch_device(device)) return cls(_model) @@ -302,9 +305,10 @@ class Diariser: """ _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames - diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} - + diarisation_kwargs = {k: v for k, + v in kwargs.items() if k in _possible_kwargs} + return diarisation_kwargs - + def __repr__(self): return f"Diarisation(model={self.model})" diff --git a/scraibe/hallucinations.py b/scraibe/hallucinations.py index a337ec0..249cce5 100644 --- a/scraibe/hallucinations.py +++ b/scraibe/hallucinations.py @@ -1,6 +1,6 @@ # List of known hallucinations - adapted from: # https://github.com/openai/whisper/discussions/928 -KNOWN_HALLUCINATIONS=[ +KNOWN_HALLUCINATIONS = [ # en " www.mooji.org" # nl @@ -73,7 +73,7 @@ KNOWN_HALLUCINATIONS=[ " Sous-titres réalisés para la communauté d'Amara.org" # ln " Sous-titres réalisés para la communauté d'Amara.org" - # pl + # pl " Napisy stworzone przez społeczność Amara.org", " Napisy wykonane przez społeczność Amara.org", " Zdjęcia i napisy stworzone przez społeczność Amara.org", @@ -92,4 +92,4 @@ KNOWN_HALLUCINATIONS=[ # zh "字幕由Amara.org社区提供", "小編字幕由Amara.org社區提供" -] \ No newline at end of file +] diff --git a/scraibe/misc.py b/scraibe/misc.py index c1d5484..f12335f 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -2,6 +2,7 @@ import os import yaml from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from argparse import Action +from ast import literal_eval CACHE_DIR = os.getenv( "AUTOT_CACHE", @@ -14,8 +15,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR: WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ - if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ + else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: """Configure diarization pipeline from a YAML file. @@ -33,25 +35,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> with open(file_path, "r") as stream: yml = yaml.safe_load(stream) - segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") + segmentation_path = path_to_segmentation or os.path.join( + PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") yml["pipeline"]["params"]["segmentation"] = segmentation_path if not os.path.exists(segmentation_path): - raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}") + raise FileNotFoundError( + f"Segmentation model not found at {segmentation_path}") with open(file_path, "w") as stream: yaml.dump(yml, stream) + class ParseKwargs(Action): """ Custom argparse action to parse keyword arguments. """ + def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for value in values: key, value = value.split('=') try: - value = eval(value) + value = literal_eval(value) except: pass - getattr(namespace, self.dest)[key] = value \ No newline at end of file + getattr(namespace, self.dest)[key] = value diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 910ea59..0301955 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -24,16 +24,20 @@ Usage: >>> transcriber.save_transcript(transcript, "path/to/save.txt") """ -from whisper import Whisper, load_model -from typing import TypeVar , Union , Optional +from whisper import Whisper +from whisper import load_model as whisper_load_model +from whisperx.asr import WhisperModel +from whisperx import load_model as whisperx_load_model +from typing import TypeVar, Union, Optional from torch import Tensor, device +from torch.cuda import is_available as cuda_is_available from numpy import ndarray - +from inspect import signature +from abc import abstractmethod +import warnings from .misc import WHISPER_DEFAULT_PATH -whisper = TypeVar('whisper') - - +whisper = TypeVar('whisper') class Transcriber: @@ -64,7 +68,8 @@ class Transcriber: The class supports various sizes and versions of Whisper models. Please refer to the load_model method for available options. """ - def __init__(self, model: whisper , model_name: str ) -> None: + + def __init__(self, model: whisper, model_name: str) -> None: """ Initialize the Transcriber class with a Whisper model. @@ -72,12 +77,13 @@ class Transcriber: model (whisper): The Whisper model to use for transcription. model_name (str): The name of the model. """ - + self.model = model - + self.model_name = model_name - def transcribe(self, audio : Union[str, Tensor, ndarray] , + @abstractmethod + def transcribe(self, audio: Union[str, Tensor, ndarray], *args, **kwargs) -> str: """ Transcribe an audio file. @@ -91,17 +97,10 @@ class Transcriber: Returns: str: The transcript as a string. """ - - kwargs = self._get_whisper_kwargs(**kwargs) - - if not kwargs.get("verbose"): - kwargs["verbose"] = None + pass - result = self.model.transcribe(audio, *args, **kwargs) - return result["text"] - @staticmethod - def save_transcript(transcript : str , save_path : str) -> None: + def save_transcript(transcript: str, save_path: str) -> None: """ Save a transcript to a file. @@ -115,17 +114,19 @@ class Transcriber: with open(save_path, 'w') as f: f.write(transcript) - + print(f'Transcript saved to {save_path}') @classmethod + @abstractmethod def load_model(cls, - model: str = "medium", - download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, - in_memory: bool = False, - *args, **kwargs - ) -> 'Transcriber': + model: str = "medium", + whisper_type: str = 'whisper', + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + in_memory: bool = False, + *args, **kwargs + ) -> None: """ Load whisper model. @@ -143,10 +144,92 @@ class Transcriber: - 'large-v2' - 'large-v3' - 'large' - + whisper_type (str): + Type of whisper model to load. "whisper" or "whisperx". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. + + Returns: + None: abscract method. + """ + pass + + @staticmethod + def _get_whisper_kwargs(**kwargs) -> dict: + """ + Get kwargs for whisper model. Ensure that kwargs are valid. + + Returns: + dict: Keyword arguments for whisper model. + """ + pass + + def __repr__(self) -> str: + return f"Transcriber(model_name={self.model_name}, model={self.model})" + + +class WhisperTranscriber(Transcriber): + def __init__(self, model: whisper, model_name: str) -> None: + super().__init__(model, model_name) + + def transcribe(self, audio: Union[str, Tensor, ndarray], + *args, **kwargs) -> str: + """ + Transcribe an audio file. + + Args: + audio (Union[str, Tensor, nparray]): The audio file to transcribe. + *args: Additional arguments. + **kwargs: Additional keyword arguments, + such as the language of the audio file. + + Returns: + str: The transcript as a string. + """ + + kwargs = self._get_whisper_kwargs(**kwargs) + + if not kwargs.get("verbose"): + kwargs["verbose"] = None + + result = self.model.transcribe(audio, *args, **kwargs) + return result["text"] + + @classmethod + def load_model(cls, + model: str = "medium", + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + in_memory: bool = False, + *args, **kwargs + ) -> 'WhisperTranscriber': + """ + Load whisper model. + + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large-v3' + - 'large' + + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + device (Optional[Union[str, torch.device]], optional): Device to load model on. Defaults to None. in_memory (bool, optional): Whether to load model in memory. @@ -158,8 +241,8 @@ class Transcriber: Transcriber: A Transcriber object initialized with the specified model. """ - _model = load_model(model, download_root=download_root, - device=device, in_memory=in_memory) + _model = whisper_load_model(model, download_root=download_root, + device=device, in_memory=in_memory) return cls(_model, model_name=model) @@ -171,17 +254,177 @@ class Transcriber: Returns: dict: Keyword arguments for whisper model. """ - _possible_kwargs = Whisper.transcribe.__code__.co_varnames - - whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} - + # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames + _possible_kwargs = signature(Whisper.transcribe).parameters.keys() + + whisper_kwargs = {k: v for k, + v in kwargs.items() if k in _possible_kwargs} + if (task := kwargs.get("task")): whisper_kwargs["task"] = task - + if (language := kwargs.get("language")): - whisper_kwargs["language"] = language - + whisper_kwargs["language"] = language + return whisper_kwargs - + def __repr__(self) -> str: - return f"Transcriber(model_name={self.model_name}, model={self.model})" \ No newline at end of file + return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" + + +class WhisperXTranscriber(Transcriber): + def __init__(self, model: whisper, model_name: str) -> None: + super().__init__(model, model_name) + + def transcribe(self, audio: Union[str, Tensor, ndarray], + *args, **kwargs) -> str: + """ + Transcribe an audio file. + + Args: + audio (Union[str, Tensor, nparray]): The audio file to transcribe. + *args: Additional arguments. + **kwargs: Additional keyword arguments, + such as the language of the audio file. + + Returns: + str: The transcript as a string. + """ + kwargs = self._get_whisper_kwargs(**kwargs) + + if isinstance(audio, Tensor): + audio = audio.cpu().numpy() + result = self.model.transcribe(audio, *args, **kwargs) + text = "" + for seg in result['segments']: + text += seg['text'] + return text + + @classmethod + def load_model(cls, + model: str = "medium", + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + *args, **kwargs + ) -> 'WhisperXTranscriber': + """ + Load whisper model. + + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large-v3' + - 'large' + + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. + + Returns: + Transcriber: A Transcriber object initialized with the specified model. + """ + if device is None: + device = "cuda" if cuda_is_available() else "cpu" + if not isinstance(device, str): + device = str(device) + compute_type = kwargs.get('compute_type', 'float16') + if device == 'cpu' and compute_type == 'float16': + warnings.warn(f'Compute type {compute_type} not compatible with ' + f'device {device}! Changing compute type to int8.') + compute_type = 'int8' + _model = whisperx_load_model(model, download_root=download_root, + device=device, compute_type=compute_type) + + return cls(_model, model_name=model) + + @staticmethod + def _get_whisper_kwargs(**kwargs) -> dict: + """ + Get kwargs for whisper model. Ensure that kwargs are valid. + + Returns: + dict: Keyword arguments for whisper model. + """ + # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames + _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() + + whisper_kwargs = {k: v for k, + v in kwargs.items() if k in _possible_kwargs} + + if (task := kwargs.get("task")): + whisper_kwargs["task"] = task + + if (language := kwargs.get("language")): + whisper_kwargs["language"] = language + + return whisper_kwargs + + def __repr__(self) -> str: + return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" + + +def load_transcriber(model: str = "medium", + whisper_type: str = 'whisper', + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + in_memory: bool = False, + *args, **kwargs + ) -> Union[WhisperTranscriber, WhisperXTranscriber]: + """ + Load whisper model. + + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large-v3' + - 'large' + whisper_type (str): + Type of whisper model to load. "whisper" or "whisperx". + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. + + Returns: + Union[WhisperTranscriber, WhisperXTranscriber]: + One of the Whisper variants as Transcrbier object initialized with the specified model. + """ + if whisper_type.lower() == 'whisper': + _model = WhisperTranscriber.load_model( + model, download_root, device, in_memory, *args, **kwargs) + return _model + elif whisper_type.lower() == 'whisperx': + _model = WhisperXTranscriber.load_model( + model, download_root, device, *args, **kwargs) + return _model + else: + raise ValueError(f'Model type not recognized, exptected "whisper" ' + f'or "whisperx", got {whisper_type}.') diff --git a/scraibe/transcript_exporter.py b/scraibe/transcript_exporter.py index 1ce43d4..5222d58 100644 --- a/scraibe/transcript_exporter.py +++ b/scraibe/transcript_exporter.py @@ -1,5 +1,6 @@ import json import time +from json.decoder import JSONDecodeError from typing import Union @@ -8,13 +9,12 @@ from .hallucinations import KNOWN_HALLUCINATIONS ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] - class Transcript: """ Class for storing transcript data, including speaker information and text segments, and exporting it to various file formats such as JSON, HTML, and LaTeX. """ - + def __init__(self, transcript: dict) -> None: """ Initializes the Transcript object with the given transcript data. @@ -30,7 +30,7 @@ class Transcript: self.speakers = self._extract_speakers() self.segments = self._extract_segments() self.annotation = {} - + def annotate(self, *args, **kwargs) -> dict: """ Annotates the transcript to associate specific names with speakers. @@ -46,36 +46,41 @@ class Transcript: ValueError: If the number of speaker names does not match the number of speakers, or if an unknown speaker is found. """ - + annotations = {} if args and len(args) != len(self.speakers): - raise ValueError("Number of speaker names does not match number of speakers") - + raise ValueError( + "Number of speaker names does not match number of speakers") + if args: for arg, speaker in zip(args, sorted(self.speakers)): - + annotations[speaker] = arg - + invalid_speakers = set(kwargs.keys()) - set(self.speakers) if invalid_speakers: - raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}") + raise ValueError( + f"These keys are not speakers: {', '.join(invalid_speakers)}") - annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) + annotations.update({key: kwargs[key] + for key in self.speakers if key in kwargs}) self.annotation = annotations - + return self - + def _remove_hallucinations(self) -> None: """ Removes all occurances of known hallucinations from all segments of the transcript. Segments that are identical to empty strings afterwards are removed from the transcript. """ - segments_to_drop=[] + segments_to_drop = [] for id in self.transcript: for snippet in KNOWN_HALLUCINATIONS: - self.transcript[id]['text']=self.transcript[id]['text'].replace(snippet,'') - if self.transcript[id]['text'] == '': segments_to_drop.append(id) + self.transcript[id]['text'] = self.transcript[id]['text'].replace( + snippet, '') + if self.transcript[id]['text'] == '': + segments_to_drop.append(id) for id in segments_to_drop: del self.transcript[id] @@ -87,9 +92,9 @@ class Transcript: Returns: list: List of unique speaker names in the transcript. """ - + return list(set([self.transcript[id]["speakers"] for id in self.transcript])) - + def _extract_segments(self) -> list: """ Extracts all the text segments from the transcript. @@ -109,23 +114,23 @@ class Transcript: time stamps for each segment. """ fstring = "" - + for _id in self.transcript: seq = self.transcript[_id] - + if self.annotation: speaker = self.annotation[seq["speakers"]] else: speaker = seq["speakers"] - + segm = seq["segments"] - sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0])) - eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) - + sseg = time.strftime("%H:%M:%S", time.gmtime(segm[0])) + eseg = time.strftime("%H:%M:%S", time.gmtime(segm[1])) + fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n" - + return fstring - + def __repr__(self) -> str: """Return a string representation of the Transcript object. @@ -133,8 +138,8 @@ class Transcript: str: A string that provides an informative description of the object. """ return f"Transcript(speakers = {self.speakers},"\ - f"segments = {self.segments}, annotation = {self.annotation})" - + f"segments = {self.segments}, annotation = {self.annotation})" + def get_dict(self) -> dict: """ Get transcript as dict @@ -142,10 +147,10 @@ class Transcript: :return: transcript as dict :rtype: dict """ - + return self.transcript - - def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str: + + def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str: """ Get transcript as json string :return: transcript as json string @@ -153,14 +158,14 @@ class Transcript: """ if "indent" not in kwargs: kwargs["indent"] = 3 - + if use_annotation and self.annotation: for _id in self.transcript: seq = self.transcript[_id] seq["speakers"] = self.annotation[seq["speakers"]] - + return json.dumps(self.transcript, *args, **kwargs) - + def get_html(self) -> str: """ Get transcript as html string @@ -171,9 +176,9 @@ class Transcript: html = "

" + self.__str__().replace("\n", "
") + "

" html = "" + html + "" html = html.replace("\t", "    ") - - return html - + + return html + def get_md(self) -> str: """Get transcript as Markdown string, using HTML formatting. @@ -181,7 +186,7 @@ class Transcript: str: Transcript as a Markdown string. """ return self.get_html() - + def get_tex(self) -> str: """Get transcript as LaTeX string. If no annotations are present, the speakers will be annotated with the first letters of the alphabet. @@ -192,43 +197,42 @@ class Transcript: if not self.annotation: self.annotate(*ALPHABET[:len(self.speakers)]) - - fstring ="\\begin{drama}" - + + fstring = "\\begin{drama}" + for speaker in self.speakers: - - fstring += "\n\t\\Character{"+ str(self.annotation[speaker]) + "}" \ - "{"+ str(self.annotation[speaker]) + "}" - + + fstring += "\n\t\\Character{" + str(self.annotation[speaker]) + "}" \ + "{" + str(self.annotation[speaker]) + "}" + for id in self.transcript: seq = self.transcript[id] speaker = self.annotation[seq["speakers"]] fstring += f"\n\\{speaker}speaks:\n{seq['text']}" - + fstring += "\n\\end{drama}" - + return fstring - - - def to_json(self,path, *args, **kwargs) -> None: + + def to_json(self, path, *args, **kwargs) -> None: """Save transcript as json file - + Args: path (str): path to save file """ with open(path, "w") as f: json.dump(self.transcript, f, *args, **kwargs) - + def to_txt(self, path: str) -> None: """Save transcript as a LaTeX file (placeholder function, implementation needed). Args: path (str): Path to save the LaTeX file. """ - + with open(path, "w") as f: f.write(self.__str__()) - + def to_md(self, path: str) -> None: """Get transcript as Markdown string, using HTML formatting. @@ -236,7 +240,7 @@ class Transcript: str: Transcript as a Markdown string. """ return self.to_html(path) - + def to_html(self, path: str) -> None: """ Save transcript as html file @@ -244,10 +248,10 @@ class Transcript: :param path: path to save file :type path: str """ - + with open(path, "w") as file: file.write(self.get_html()) - + def to_tex(self, path: str) -> None: """Save transcript as a LaTeX file (placeholder function, implementation needed). @@ -255,7 +259,7 @@ class Transcript: path (str): Path to save the LaTeX file. """ pass - + def to_pdf(self, path: str) -> None: """Save transcript as a PDF file (placeholder function, implementation needed). @@ -263,7 +267,7 @@ class Transcript: path (str): Path to save the PDF file. """ pass - + def save(self, path: str, *args, **kwargs) -> None: """Save transcript to file with the given path and file format. @@ -279,7 +283,7 @@ class Transcript: Raises: ValueError: If the file format specified in the path is unknown. """ - + if path.endswith(".json"): self.to_json(path, *args, **kwargs) elif path.endswith(".txt"): @@ -294,7 +298,7 @@ class Transcript: self.to_pdf(path, *args, **kwargs) else: raise ValueError("Unknown file format") - + @classmethod def from_json(cls, json: Union[dict, str]) -> "Transcript": """Load transcript from json file @@ -310,10 +314,8 @@ class Transcript: else: try: transcript = json.loads(json) - except: + except (TypeError, JSONDecodeError): with open(json, "r") as f: transcript = json.load(f) - - return cls(transcript) - \ No newline at end of file + return cls(transcript) diff --git a/source/conf.py b/source/conf.py index ba51ab3..43fe803 100644 --- a/source/conf.py +++ b/source/conf.py @@ -31,16 +31,16 @@ release = '0.1.1' # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = ['sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinx.ext.napoleon', - 'myst_parser'] + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.coverage', + 'sphinx.ext.mathjax', + 'sphinx.ext.ifconfig', + 'sphinx.ext.viewcode', + 'sphinx.ext.githubpages', + 'sphinx.ext.napoleon', + 'myst_parser'] # Napoleon settings napoleon_google_docstring = True diff --git a/test/test_audio.py b/test/test_audio.py index 311a472..aee6cb3 100644 --- a/test/test_audio.py +++ b/test/test_audio.py @@ -3,7 +3,6 @@ from scraibe.audio import AudioProcessor import torch - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) TEST_SR = 16000 @@ -14,21 +13,17 @@ NORMALIZATION_FACTOR = 32768 @pytest.fixture def probe_audio_processor(): """Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate. - + This fixture is used to create an instance of the AudioProcessor class with a predfined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor , which can bes used as a dependency in other test functions. Returns: AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate. - """ + """ return AudioProcessor(TEST_WAVEFORM, TEST_SR) - - - - def test_AudioProcessor_init(probe_audio_processor): """ Test the initialization of the AudioProcessor class. @@ -43,20 +38,19 @@ def test_AudioProcessor_init(probe_audio_processor): Returns: None - - - """ + + + """ assert isinstance(probe_audio_processor, AudioProcessor) assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM) assert probe_audio_processor.sr == TEST_SR - def test_cut(probe_audio_processor): """Test the cut function of the AudioProcessor class. - + This test verifies that the cut function correctly extracts a segment of audio data from the waveform, given start and end indices. It checks whether the size of the extracted segment matches the expected size based on the provided start and end indices and the sample rate. @@ -65,63 +59,38 @@ def test_cut(probe_audio_processor): None - """ - + """ + start = 4 end = 7 trimmed_waveform = probe_audio_processor.cut(start, end) expected_size = int((end - start) * TEST_SR) real_size = trimmed_waveform.size(0) assert real_size == expected_size - #assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) - - - - - - - - + # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) def test_audio_processor_invalid_sr(): """Test the behavior of AudioProcessor when an invalid smaple rate is provided. - + This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an AudioProcessor object with an invalid sample rate. Returns: None - """ + """ with pytest.raises(ValueError): - AudioProcessor(TEST_WAVEFORM, [44100,48000]) + AudioProcessor(TEST_WAVEFORM, [44100, 48000]) def test_audio_processor_SAMPLE_RATE(): """Test the default sample rate of the AudioProcessor class. - + This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE. Returns: None - """ + """ probe_audio_processor = AudioProcessor(TEST_WAVEFORM) - assert probe_audio_processor.sr == SAMPLE_RATE - - - - - - - - - - - - - - - - - + assert probe_audio_processor.sr == SAMPLE_RATE diff --git a/test/test_autotranscript.py b/test/test_autotranscript.py index edbe0f7..78442b3 100644 --- a/test/test_autotranscript.py +++ b/test/test_autotranscript.py @@ -1,20 +1,14 @@ import pytest from scraibe import Scraibe, Diariser, Transcriber, Transcript -from unittest.mock import MagicMock, patch import os - - - @pytest.fixture def create_scraibe_instance(): if "HF_TOKEN" in os.environ: - return Scraibe(use_auth_token=os.environ["HF_TOKEN"] ) + return Scraibe(use_auth_token=os.environ["HF_TOKEN"]) else: return Scraibe() - - def test_scraibe_init(create_scraibe_instance): @@ -47,7 +41,7 @@ def test_scraibe_transcribe(create_scraibe_instance): model.remove_audio_file("non_existing_audio_file") model.remove_audio_file("audio_test_2.mp4") - assert not os.path.exists("audio_test_2.mp4") """ + assert not os.path.exists("audio_test_2.mp4") """ """ def test_get_audio_file(create_scraibe_instance): diff --git a/test/test_diarisation.py b/test/test_diarisation.py index d1d26f3..01431be 100644 --- a/test/test_diarisation.py +++ b/test/test_diarisation.py @@ -1,8 +1,5 @@ import pytest -import os -from unittest import mock -from scraibe import diarisation, Diariser - +from scraibe import Diariser @pytest.fixture @@ -15,11 +12,10 @@ def diariser_instance(): Returns: Diariser(Obj): An instance of the Diariser class with a mocked token. """ - #with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): + # with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): return Diariser('pyannote') - def test_Diariser_init(diariser_instance): """Test the initialization of the Diariser class. @@ -30,18 +26,7 @@ def test_Diariser_init(diariser_instance): Args: diariser_instance (obj): instance of the Diariser class - Returns: + Returns: None - """ + """ assert diariser_instance.model == 'pyannote' - - - - - - - - - - - diff --git a/test/test_transcriber.py b/test/test_transcriber.py index 3a4a0dc..31765f6 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -1,27 +1,26 @@ import pytest -from unittest.mock import patch -from scraibe import Transcriber +from scraibe import (Transcriber, WhisperTranscriber, + WhisperXTranscriber, load_transcriber) import torch - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") TEST_WAVEFORM = "Hello World" -""" +""" @pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) @patch("scraibe.Transcriber.load_model") def test_transcriber(mock_load_model, audio_file, expected_transcription): - + Args: mock_load_model (_type_): _description_ audio_file (_type_): _description_ expected_transcription (_type_): _description_ - + mock_model = mock_load_model.return_value - mock_model.transcribe.return_value ={"text": expected_transcription} + mock_model.transcribe.return_value ={"text": expected_transcription} transcriber = Transcriber.load_model(model="medium") @@ -29,24 +28,53 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): assert transcription_result == expected_transcription """ -@pytest.fixture -def transcriber_instance(): - return Transcriber.load_model('medium') -def test_transcriber_initialization(transcriber_instance): - assert isinstance(transcriber_instance, Transcriber) +@pytest.fixture +def whisper_instance(): + return load_transcriber('medium', whisper_type='whisper') + + +@pytest.fixture +def whisperx_instance(): + return load_transcriber('medium', whisper_type='whisperx') + + +def test_whisper_base_initialization(whisper_instance): + assert isinstance(whisper_instance, Transcriber) + + +def test_whisperx_base_initialization(whisperx_instance): + assert isinstance(whisperx_instance, Transcriber) + + +def test_whisper_transcriber_initialization(whisper_instance): + assert isinstance(whisper_instance, WhisperTranscriber) + + +def test_whisperx_transcriber_initialization(whisperx_instance): + assert isinstance(whisperx_instance, WhisperXTranscriber) + + +def test_wrong_transcriber_initialization(): + with pytest.raises(ValueError): + load_transcriber('medium', whisper_type='wrong_whisper') + def test_get_whisper_kwargs(): - kwargs = {"arg1": 1, "arg3": 3} + kwargs = {"arg1": 1, "arg3": 3} valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs) - assert not valid_kwargs == {"arg1": 1, "arg3": 3} + assert not valid_kwargs == {"arg1": 1, "arg3": 3} -def test_transcribe(transcriber_instance): - model = transcriber_instance - #mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) +def test_whisper_transcribe(whisper_instance): + model = whisper_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) transcript = model.transcribe('test/audio_test_2.mp4') assert isinstance(transcript, str) - +def test_whisperx_transcribe(whisperx_instance): + model = whisperx_instance + # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) + transcript = model.transcribe('test/audio_test_2.mp4') + assert isinstance(transcript, str)