diff --git a/pyproject.toml b/pyproject.toml index 8c46bdb..caf02a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ python = "^3.9" tqdm = "^4.66.4" numpy = "^1.26.4" openai-whisper = "^20231117" -whisperx = "^3.1.3" +faster-whisper = "^1.0.1" "pyannote.audio" = "^3.1.1" torch = "^2.3.0" diff --git a/requirements.txt b/requirements.txt index f08e2e6..94ee85a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ tqdm>=4.65.0 numpy>=1.26.4 openai-whisper==20231117 -whisperx~=3.1.3 +faster-whisper~=1.0.1 pyannote.audio~=3.1.1 pyannote.core~=5.0.0 diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 7391f1a..43dedc2 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -74,7 +74,7 @@ class Scraibe: whisper_model (Union[bool, str, whisper], optional): Path to whisper model or whisper model itself. whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". diarisation_model (Union[bool, str, DiarisationType], optional): Path to pyannote diarization model or model itself. **kwargs: Additional keyword arguments for whisper diff --git a/scraibe/cli.py b/scraibe/cli.py index ee40c8b..a234132 100644 --- a/scraibe/cli.py +++ b/scraibe/cli.py @@ -36,8 +36,8 @@ def cli(): help="List of audio files to transcribe.") parser.add_argument("--whisper-type", type=str, default="whisper", - choices=["whisper", "whisperx"], - help="Type of Whisper model to use ('whisper' or 'whisperx').") + choices=["whisper", "faster-whisper"], + help="Type of Whisper model to use ('whisper' or 'faster-whisper').") parser.add_argument("--whisper-model-name", default="medium", help="Name of the Whisper model to use.") diff --git a/scraibe/misc.py b/scraibe/misc.py index f12335f..56e9f3a 100644 --- a/scraibe/misc.py +++ b/scraibe/misc.py @@ -16,7 +16,7 @@ WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ - else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') + else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1') def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 0301955..cea7274 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -26,8 +26,7 @@ Usage: from whisper import Whisper from whisper import load_model as whisper_load_model -from whisperx.asr import WhisperModel -from whisperx import load_model as whisperx_load_model +from faster_whisper import WhisperModel as FasterWhisperModel from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available @@ -145,7 +144,7 @@ class Transcriber: - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. device (Optional[Union[str, torch.device]], optional): @@ -272,7 +271,7 @@ class WhisperTranscriber(Transcriber): return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})" -class WhisperXTranscriber(Transcriber): +class FasterWhisperTranscriber(Transcriber): def __init__(self, model: whisper, model_name: str) -> None: super().__init__(model, model_name) @@ -294,10 +293,10 @@ class WhisperXTranscriber(Transcriber): if isinstance(audio, Tensor): audio = audio.cpu().numpy() - result = self.model.transcribe(audio, *args, **kwargs) + result, _ = self.model.transcribe(audio, *args, **kwargs) text = "" - for seg in result['segments']: - text += seg['text'] + for seg in result: + text += seg.text return text @classmethod @@ -306,7 +305,7 @@ class WhisperXTranscriber(Transcriber): download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, *args, **kwargs - ) -> 'WhisperXTranscriber': + ) -> 'FasterWhisperModel': """ Load whisper model. @@ -347,8 +346,8 @@ class WhisperXTranscriber(Transcriber): warnings.warn(f'Compute type {compute_type} not compatible with ' f'device {device}! Changing compute type to int8.') compute_type = 'int8' - _model = whisperx_load_model(model, download_root=download_root, - device=device, compute_type=compute_type) + _model = FasterWhisperModel(model, download_root=download_root, + device=device, compute_type=compute_type) return cls(_model, model_name=model) @@ -361,7 +360,7 @@ class WhisperXTranscriber(Transcriber): dict: Keyword arguments for whisper model. """ # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames - _possible_kwargs = signature(WhisperModel.transcribe).parameters.keys() + _possible_kwargs = signature(FasterWhisperModel.transcribe).parameters.keys() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} @@ -375,7 +374,7 @@ class WhisperXTranscriber(Transcriber): return whisper_kwargs def __repr__(self) -> str: - return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})" + return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})" def load_transcriber(model: str = "medium", @@ -384,7 +383,7 @@ def load_transcriber(model: str = "medium", device: Optional[Union[str, device]] = None, in_memory: bool = False, *args, **kwargs - ) -> Union[WhisperTranscriber, WhisperXTranscriber]: + ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]: """ Load whisper model. @@ -403,28 +402,28 @@ def load_transcriber(model: str = "medium", - 'large-v3' - 'large' whisper_type (str): - Type of whisper model to load. "whisper" or "whisperx". + Type of whisper model to load. "whisper" or "faster-whisper". download_root (str, optional): Path to download the model. Defaults to WHISPER_DEFAULT_PATH. - device (Optional[Union[str, torch.device]], optional): + device (Optional[Union[str, torch.device]], optional): Device to load model on. Defaults to None. - in_memory (bool, optional): Whether to load model in memory. + in_memory (bool, optional): Whether to load model in memory. Defaults to False. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. Returns: - Union[WhisperTranscriber, WhisperXTranscriber]: + Union[WhisperTranscriber, FasterWhisperTranscriber]: One of the Whisper variants as Transcrbier object initialized with the specified model. """ if whisper_type.lower() == 'whisper': _model = WhisperTranscriber.load_model( model, download_root, device, in_memory, *args, **kwargs) return _model - elif whisper_type.lower() == 'whisperx': - _model = WhisperXTranscriber.load_model( + elif whisper_type.lower() == 'faster-whisper': + _model = FasterWhisperTranscriber.load_model( model, download_root, device, *args, **kwargs) return _model else: raise ValueError(f'Model type not recognized, exptected "whisper" ' - f'or "whisperx", got {whisper_type}.') + f'or "faster-whisper", got {whisper_type}.')