From 82e26771e0e4d32501feba442fec83bf3cb957a6 Mon Sep 17 00:00:00 2001 From: Marko Henning Date: Wed, 8 May 2024 15:49:05 +0200 Subject: [PATCH] Added WhisperX as possible whisper model. --- requirements.txt | 1 + scraibe/autotranscript.py | 5 +- scraibe/transcriber.py | 252 +++++++++++++++++++++++++++++++++----- 3 files changed, 224 insertions(+), 34 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5872774..d1bdccc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ tqdm>=4.65.0 numpy>=1.26.4 openai-whisper==20231117 +whisperx~=3.1.3 pyannote.audio~=3.1.1 pyannote.core~=5.0.0 diff --git a/scraibe/autotranscript.py b/scraibe/autotranscript.py index 7d54ba8..4081638 100644 --- a/scraibe/autotranscript.py +++ b/scraibe/autotranscript.py @@ -64,6 +64,7 @@ class Scraibe: """ def __init__(self, whisper_model: Union[bool, str, whisper] = None, + whisper_type: str = "whisper", dia_model : Union[bool, str, DiarisationType] = None, **kwargs) -> None: """Initializes the Scraibe class. @@ -84,9 +85,9 @@ class Scraibe: if whisper_model is None: - self.transcriber = Transcriber.load_model("medium", **kwargs) + self.transcriber = Transcriber.load_model("medium", whisper_type, **kwargs) elif isinstance(whisper_model, str): - self.transcriber = Transcriber.load_model(whisper_model, **kwargs) + self.transcriber = Transcriber.load_model(whisper_model, whisper_type, **kwargs) else: self.transcriber = whisper_model diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index 910ea59..365d321 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -24,16 +24,18 @@ Usage: >>> transcriber.save_transcript(transcript, "path/to/save.txt") """ -from whisper import Whisper, load_model +from whisper import Whisper +from whisper import load_model as whisper_load_model +from whisperx.asr import WhisperModel +from whisperx import load_model as whisperx_load_model from typing import TypeVar , Union , Optional from torch import Tensor, device from numpy import ndarray - +from inspect import getfullargspec +from abc import ABC, abstractmethod from .misc import WHISPER_DEFAULT_PATH -whisper = TypeVar('whisper') - - +whisper = TypeVar('whisper') class Transcriber: @@ -64,7 +66,7 @@ class Transcriber: The class supports various sizes and versions of Whisper models. Please refer to the load_model method for available options. """ - def __init__(self, model: whisper , model_name: str ) -> None: + def __init__(self, model: whisper, model_name: str) -> None: """ Initialize the Transcriber class with a Whisper model. @@ -77,7 +79,8 @@ class Transcriber: self.model_name = model_name - def transcribe(self, audio : Union[str, Tensor, ndarray] , + @abstractmethod + def transcribe(self, audio: Union[str, Tensor, ndarray] , *args, **kwargs) -> str: """ Transcribe an audio file. @@ -91,14 +94,7 @@ class Transcriber: Returns: str: The transcript as a string. """ - - kwargs = self._get_whisper_kwargs(**kwargs) - - if not kwargs.get("verbose"): - kwargs["verbose"] = None - - result = self.model.transcribe(audio, *args, **kwargs) - return result["text"] + pass @staticmethod def save_transcript(transcript : str , save_path : str) -> None: @@ -120,12 +116,106 @@ class Transcriber: @classmethod def load_model(cls, - model: str = "medium", - download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, device]] = None, - in_memory: bool = False, - *args, **kwargs - ) -> 'Transcriber': + model: str = "medium", + whisper_type: str = 'whisper', + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + in_memory: bool = False, + *args, **kwargs + ) -> 'Transcriber': + """ + Load whisper model. + + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large-v3' + - 'large' + + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. + + Returns: + Transcriber: A Transcriber object initialized with the specified model. + """ + if whisper_type.lower() == 'whisper': + _model = WhisperTranscriber.load_model( + model, download_root, device, in_memory, *args, **kwargs) + return _model + elif whisper_type.lower() == 'whisperx': + _model = WhisperXTranscriber.load_model( + model, download_root, device, *args, **kwargs) + return _model + else: + raise ValueError(f'Model type not recognized, exptected "whisper" ' + f'or "whisperx", got {whisper_type}.') + pass + + @staticmethod + def _get_whisper_kwargs(**kwargs) -> dict: + """ + Get kwargs for whisper model. Ensure that kwargs are valid. + + Returns: + dict: Keyword arguments for whisper model. + """ + pass + + def __repr__(self) -> str: + return f"Transcriber(model_name={self.model_name}, model={self.model})" + + +class WhisperTranscriber(Transcriber): + def __init__(self, model: whisper, model_name: str) -> None: + super().__init__(model, model_name) + + def transcribe(self, audio: Union[str, Tensor, ndarray], + *args, **kwargs) -> str: + """ + Transcribe an audio file. + + Args: + audio (Union[str, Tensor, nparray]): The audio file to transcribe. + *args: Additional arguments. + **kwargs: Additional keyword arguments, + such as the language of the audio file. + + Returns: + str: The transcript as a string. + """ + + kwargs = self._get_whisper_kwargs(**kwargs) + + if not kwargs.get("verbose"): + kwargs["verbose"] = None + + result = self.model.transcribe(audio, *args, **kwargs) + return result["text"] + + @classmethod + def load_model(cls, + model: str = "medium", + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + in_memory: bool = False, + *args, **kwargs + ) -> 'Transcriber': """ Load whisper model. @@ -158,8 +248,8 @@ class Transcriber: Transcriber: A Transcriber object initialized with the specified model. """ - _model = load_model(model, download_root=download_root, - device=device, in_memory=in_memory) + _model = whisper_load_model(model, download_root=download_root, + device=device, in_memory=in_memory) return cls(_model, model_name=model) @@ -171,17 +261,115 @@ class Transcriber: Returns: dict: Keyword arguments for whisper model. """ - _possible_kwargs = Whisper.transcribe.__code__.co_varnames - + # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames + _args = getfullargspec(Whisper.transcribe).args + _kwargs = getfullargspec(Whisper.transcribe).kwonlyargs + _possible_kwargs = _args + _kwargs + whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} - + if (task := kwargs.get("task")): whisper_kwargs["task"] = task - + if (language := kwargs.get("language")): - whisper_kwargs["language"] = language - + whisper_kwargs["language"] = language + + return whisper_kwargs + + +class WhisperXTranscriber(Transcriber): + def __init__(self, model: whisper, model_name: str) -> None: + super().__init__(model, model_name) + + def transcribe(self, audio: Union[str, Tensor, ndarray], + *args, **kwargs) -> str: + """ + Transcribe an audio file. + + Args: + audio (Union[str, Tensor, nparray]): The audio file to transcribe. + *args: Additional arguments. + **kwargs: Additional keyword arguments, + such as the language of the audio file. + + Returns: + str: The transcript as a string. + """ + kwargs = self._get_whisper_kwargs(**kwargs) + + if isinstance(audio, Tensor): + audio = audio.cpu().numpy() + result = self.model.transcribe(audio, *args, **kwargs) + text = "" + for seg in result['segments']: + text += seg['text'] + return text + + + @classmethod + def load_model(cls, + model: str = "medium", + download_root: str = WHISPER_DEFAULT_PATH, + device: Optional[Union[str, device]] = None, + *args, **kwargs + ) -> 'Transcriber': + """ + Load whisper model. + + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large-v3' + - 'large' + + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. + + Returns: + Transcriber: A Transcriber object initialized with the specified model. + """ + if not isinstance(device, str): + device = str(device) + _model = whisperx_load_model(model, download_root=download_root, + device=device) + + return cls(_model, model_name=model) + + @staticmethod + def _get_whisper_kwargs(**kwargs) -> dict: + """ + Get kwargs for whisper model. Ensure that kwargs are valid. + + Returns: + dict: Keyword arguments for whisper model. + """ + # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames + _args = getfullargspec(WhisperModel.transcribe).args + _kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs + _possible_kwargs = _args + _kwargs + + whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} + + if (task := kwargs.get("task")): + whisper_kwargs["task"] = task + + if (language := kwargs.get("language")): + whisper_kwargs["language"] = language + return whisper_kwargs - - def __repr__(self) -> str: - return f"Transcriber(model_name={self.model_name}, model={self.model})" \ No newline at end of file