diff --git a/pyproject.toml b/pyproject.toml index 098b025..e82881d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ python = "^3.9" tqdm = "^4.66.5" numpy = "^1.26.4" openai-whisper = "^20231117" -whisperx = "^3.1.5" -"pyannote.audio" = "^3.1.1" +faster-whisper = "^1.0.3" +"pyannote.audio" = "^3.3.1" torch = "^2.3.0" [tool.poetry.group.dev.dependencies] diff --git a/requirements.txt b/requirements.txt index f43514f..6e95c81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,13 @@ tqdm>=4.66.5 numpy>=1.26.4 openai-whisper==20231117 -whisperx~=3.1.5 +faster-whisper~=1.0.3 -pyannote.audio~=3.1.1 +pyannote.audio~=3.3.1 +pyannote.core~=5.0.0 +pyannote.database~=5.0.1 +pyannote.metrics~=3.2.1 +pyannote.pipeline~=3.0.1 torch>=2.0.0 diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 7391f1a..43dedc2 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -74,7 +74,7 @@ class Scraibe: whisper_model (Union[bool, str, whisper], optional): Path to whisper model or whisper model itself. whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". diarisation_model (Union[bool, str, DiarisationType], optional): Path to pyannote diarization model or model itself. 
**kwargs: Additional keyword arguments for whisper diff --git a/scraibe/cli.py b/scraibe/cli.py index c85e985..df73d1b 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -36,8 +36,8 @@ def cli(): help="List of audio files to transcribe.") parser.add_argument("--whisper-type", type=str, default="whisper", - choices=["whisper", "whisperx"], - help="Type of Whisper model to use ('whisper' or 'whisperx').") + choices=["whisper", "faster-whisper"], + help="Type of Whisper model to use ('whisper' or 'faster-whisper').") parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") diff --git a/scraibe/misc.py b/scraibe/misc.py index 21099fb..106b9e1 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -16,7 +16,7 @@ WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 0301955..abf1ace 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -26,8 +26,9 @@ Usage: from whisper import Whisper from whisper import load_model as whisper_load_model -from whisperx.asr import WhisperModel -from whisperx import load_model as whisperx_load_model +from whisper.tokenizer import TO_LANGUAGE_CODE +from faster_whisper import WhisperModel as FasterWhisperModel +from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available @@ -145,7 +146,7 @@ class Transcriber: - 'large-v3' - 'large' 
whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): @@ -272,7 +273,7 @@ class WhisperTranscriber(Transcriber): return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" -class WhisperXTranscriber(Transcriber): +class FasterWhisperTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: super().__init__(model, model_name) @@ -294,10 +295,10 @@ class WhisperXTranscriber(Transcriber): if isinstance(audio, Tensor): audio = audio.cpu().numpy() - result = self.model.transcribe(audio, *args, **kwargs) + result, _ = self.model.transcribe(audio, *args, **kwargs) text = "" - for seg in result['segments']: - text += seg['text'] + for seg in result: + text += seg.text return text @classmethod @@ -306,7 +307,7 @@ class WhisperXTranscriber(Transcriber): download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, *args, **kwargs - ) -> 'WhisperXTranscriber': + ) -> 'FasterWhisperTranscriber': """ Load whisper model. @@ -347,8 +348,8 @@ class WhisperXTranscriber(Transcriber): warnings.warn(f'Compute type {compute_type} not compatible with ' f'device {device}! Changing compute type to int8.') compute_type = 'int8' - _model = whisperx_load_model(model, download_root=download_root, - device=device, compute_type=compute_type) + _model = FasterWhisperModel(model, download_root=download_root, + device=device, compute_type=compute_type) return cls(_model, model_name=model) @@ -361,7 +362,7 @@ class WhisperXTranscriber(Transcriber): dict: Keyword arguments for whisper model. 
""" # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() + _possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} @@ -370,12 +371,42 @@ class WhisperXTranscriber(Transcriber): whisper_kwargs["task"] = task if (language := kwargs.get("language")): + language = FasterWhisperTranscriber.convert_to_language_code(language) whisper_kwargs["language"] = language return whisper_kwargs + @staticmethod + def convert_to_language_code(lang : str) -> str: + """ + Convert a language name or code to its language code. + + Args: + lang (str): language as code or language name + + Returns: + str: the corresponding language code + """ + + # If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly + if lang in FASTER_WHISPER_LANGUAGE_CODES: + return lang + + # Normalize the input to lowercase + lang = lang.lower() + + # Check if the language name is in the TO_LANGUAGE_CODE mapping + if lang in TO_LANGUAGE_CODE: + return TO_LANGUAGE_CODE[lang] + + # If the language is not recognized, raise a ValueError with the available options + available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES) + raise ValueError(f"Language '{lang}' is not a valid language code or name. " + f"Available language codes are: {available_codes}.") + def __repr__(self) -> str: - return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" + return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})" + def load_transcriber(model: str = "medium", @@ -384,7 +415,7 @@ def load_transcriber(model: str = "medium", device: Optional[Union[str, device]] = None, in_memory: bool = False, *args, **kwargs - ) -> Union[WhisperTranscriber, WhisperXTranscriber]: + ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: """ Load whisper model. 
@@ -403,28 +434,28 @@ def load_transcriber(model: str = "medium", - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - device (Optional[Union[str, torch.device]], optional): + device (Optional[Union[str, torch.device]], optional): Device to load model on. Defaults to None. - in_memory (bool, optional): Whether to load model in memory. + in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. Returns: - Union[WhisperTranscriber, WhisperXTranscriber]: + Union[WhisperTranscriber, FasterWhisperTranscriber]: One of the Whisper variants as Transcrbier object initialized with the specified model. """ if whisper_type.lower() == 'whisper': _model = WhisperTranscriber.load_model( model, download_root, device, in_memory, *args, **kwargs) return _model - elif whisper_type.lower() == 'whisperx': - _model = WhisperXTranscriber.load_model( + elif whisper_type.lower() == 'faster-whisper': + _model = FasterWhisperTranscriber.load_model( model, download_root, device, *args, **kwargs) return _model else: raise ValueError(f'Model type not recognized, exptected "whisper" ' - f'or "whisperx", got {whisper_type}.') + f'or "faster-whisper", got {whisper_type}.') diff --git a/test/test_transcriber.py b/test/test_transcriber.py index 31765f6..5bfe3cf 100644 --- a/test/test_transcriber.py +++ b/test/test_transcriber.py @@ -1,6 +1,6 @@ import pytest from scraibe import (Transcriber, WhisperTranscriber, - WhisperXTranscriber, load_transcriber) + FasterWhisperTranscriber, load_transcriber) import torch @@ -31,33 +31,33 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription): @pytest.fixture def whisper_instance(): - return 
load_transcriber('medium', whisper_type='whisper') + return load_transcriber('tiny', whisper_type='whisper') @pytest.fixture -def whisperx_instance(): - return load_transcriber('medium', whisper_type='whisperx') +def faster_whisper_instance(): + return load_transcriber('tiny', whisper_type='faster-whisper') def test_whisper_base_initialization(whisper_instance): assert isinstance(whisper_instance, Transcriber) -def test_whisperx_base_initialization(whisperx_instance): - assert isinstance(whisperx_instance, Transcriber) +def test_faster_whisper_base_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, Transcriber) def test_whisper_transcriber_initialization(whisper_instance): assert isinstance(whisper_instance, WhisperTranscriber) -def test_whisperx_transcriber_initialization(whisperx_instance): - assert isinstance(whisperx_instance, WhisperXTranscriber) +def test_faster_whisper_transcriber_initialization(faster_whisper_instance): + assert isinstance(faster_whisper_instance, FasterWhisperTranscriber) def test_wrong_transcriber_initialization(): with pytest.raises(ValueError): - load_transcriber('medium', whisper_type='wrong_whisper') + load_transcriber('tiny', whisper_type='wrong_whisper') def test_get_whisper_kwargs(): @@ -73,8 +73,8 @@ def test_whisper_transcribe(whisper_instance): assert isinstance(transcript, str) -def test_whisperx_transcribe(whisperx_instance): - model = whisperx_instance +def test_faster_whisper_transcribe(faster_whisper_instance): + model = faster_whisper_instance # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) transcript = model.transcribe('test/audio_test_2.mp4') assert isinstance(transcript, str) diff --git a/tests/test_diarization.py b/tests/test_diarization.py new file mode 100644 index 0000000..f9e81a5 --- /dev/null +++ b/tests/test_diarization.py @@ -0,0 +1,10 @@ +from os import environ + +environ["AUTOT_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests" +# 
environ["PYANNOTE_CACHE"] = "/mnt/disk1/Projekte/ScrAIbe/tests/pyannote" +# environ["TORCH_HOME"] = "/mnt/disk1/Projekte/ScrAIbe/tests/torch" + +from scraibe import Scraibe + +scraibe = Scraibe(whisper_type = "faster-whisper", whisper_model = "tiny") +print(scraibe.autotranscribe('/mnt/disk1/Projekte/ScrAIbe/test/audio_test_1.mp4')) \ No newline at end of file