Added WhisperX as a possible Whisper model backend.

This commit is contained in:
Marko Henning
2024-05-08 15:49:05 +02:00
parent fee9f0b468
commit 82e26771e0
3 changed files with 224 additions and 34 deletions
+1
View File
@@ -2,6 +2,7 @@ tqdm>=4.65.0
numpy>=1.26.4 numpy>=1.26.4
openai-whisper==20231117 openai-whisper==20231117
whisperx~=3.1.3
pyannote.audio~=3.1.1 pyannote.audio~=3.1.1
pyannote.core~=5.0.0 pyannote.core~=5.0.0
+3 -2
View File
@@ -64,6 +64,7 @@ class Scraibe:
""" """
def __init__(self, def __init__(self,
whisper_model: Union[bool, str, whisper] = None, whisper_model: Union[bool, str, whisper] = None,
whisper_type: str = "whisper",
dia_model : Union[bool, str, DiarisationType] = None, dia_model : Union[bool, str, DiarisationType] = None,
**kwargs) -> None: **kwargs) -> None:
"""Initializes the Scraibe class. """Initializes the Scraibe class.
@@ -84,9 +85,9 @@ class Scraibe:
if whisper_model is None: if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", **kwargs) self.transcriber = Transcriber.load_model("medium", whisper_type, **kwargs)
elif isinstance(whisper_model, str): elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = Transcriber.load_model(whisper_model, whisper_type, **kwargs)
else: else:
self.transcriber = whisper_model self.transcriber = whisper_model
+220 -32
View File
@@ -24,16 +24,18 @@ Usage:
>>> transcriber.save_transcript(transcript, "path/to/save.txt") >>> transcriber.save_transcript(transcript, "path/to/save.txt")
""" """
from whisper import Whisper, load_model from whisper import Whisper
from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from typing import TypeVar , Union , Optional from typing import TypeVar , Union , Optional
from torch import Tensor, device from torch import Tensor, device
from numpy import ndarray from numpy import ndarray
from inspect import getfullargspec
from abc import ABC, abstractmethod
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
class Transcriber: class Transcriber:
@@ -64,7 +66,7 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options. the load_model method for available options.
""" """
def __init__(self, model: whisper , model_name: str ) -> None: def __init__(self, model: whisper, model_name: str) -> None:
""" """
Initialize the Transcriber class with a Whisper model. Initialize the Transcriber class with a Whisper model.
@@ -77,7 +79,8 @@ class Transcriber:
self.model_name = model_name self.model_name = model_name
def transcribe(self, audio : Union[str, Tensor, ndarray] , @abstractmethod
def transcribe(self, audio: Union[str, Tensor, ndarray] ,
*args, **kwargs) -> str: *args, **kwargs) -> str:
""" """
Transcribe an audio file. Transcribe an audio file.
@@ -91,14 +94,7 @@ class Transcriber:
Returns: Returns:
str: The transcript as a string. str: The transcript as a string.
""" """
pass
kwargs = self._get_whisper_kwargs(**kwargs)
if not kwargs.get("verbose"):
kwargs["verbose"] = None
result = self.model.transcribe(audio, *args, **kwargs)
return result["text"]
@staticmethod @staticmethod
def save_transcript(transcript : str , save_path : str) -> None: def save_transcript(transcript : str , save_path : str) -> None:
@@ -120,12 +116,106 @@ class Transcriber:
@classmethod
def load_model(cls,
               model: str = "medium",
               whisper_type: str = 'whisper',
               download_root: str = WHISPER_DEFAULT_PATH,
               device: Optional[Union[str, device]] = None,
               in_memory: bool = False,
               *args, **kwargs
               ) -> 'Transcriber':
    """
    Load a whisper model through the backend selected by ``whisper_type``.

    Args:
        model (str): Whisper model. Available models include:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large-v3'
            - 'large'
        whisper_type (str, optional): Backend to use, either 'whisper' or
            'whisperx' (matched case-insensitively). Defaults to 'whisper'.
        download_root (str, optional): Path to download the model.
            Defaults to WHISPER_DEFAULT_PATH.
        device (Optional[Union[str, torch.device]], optional):
            Device to load model on. Defaults to None.
        in_memory (bool, optional): Whether to load model in memory.
            Only forwarded to the 'whisper' backend. Defaults to False.
        args: Additional positional arguments, forwarded to the backend loader.
        kwargs: Additional keyword arguments, forwarded to the backend loader.

    Returns:
        Transcriber: A Transcriber object initialized with the specified model.

    Raises:
        ValueError: If ``whisper_type`` is neither 'whisper' nor 'whisperx'.
    """
    # Normalise once so both comparisons see the same value.
    backend = whisper_type.lower()

    if backend == 'whisper':
        return WhisperTranscriber.load_model(
            model, download_root, device, in_memory, *args, **kwargs)

    if backend == 'whisperx':
        # The whisperx loader has no in_memory parameter, so it is not passed.
        return WhisperXTranscriber.load_model(
            model, download_root, device, *args, **kwargs)

    raise ValueError(f'Model type not recognized, expected "whisper" '
                     f'or "whisperx", got {whisper_type}.')
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
    """
    Get kwargs for whisper model. Ensure that kwargs are valid.

    This base implementation is a placeholder: its body is a bare ``pass``,
    so it returns None rather than a dict. WhisperXTranscriber overrides it
    with a real filter.

    Returns:
        dict: Keyword arguments for whisper model.
    """
    pass
def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})"
class WhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
    """
    Initialize the transcriber by delegating to Transcriber.__init__.

    Args:
        model (whisper): The loaded model used for transcription.
        model_name (str): The name of the model.
    """
    super().__init__(model=model, model_name=model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
               *args, **kwargs) -> str:
    """
    Transcribe an audio file and return its text.

    Args:
        audio (Union[str, Tensor, ndarray]): The audio to transcribe.
        *args: Additional positional arguments passed to the model.
        **kwargs: Additional keyword arguments, such as the language of
            the audio file. They are filtered by ``_get_whisper_kwargs``
            before being passed to the model.

    Returns:
        str: The transcript as a string.
    """
    options = self._get_whisper_kwargs(**kwargs)

    # A missing or falsy "verbose" is always passed on as None.
    options["verbose"] = options.get("verbose") or None

    return self.model.transcribe(audio, *args, **options)["text"]
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
in_memory: bool = False,
*args, **kwargs
) -> 'Transcriber':
""" """
Load whisper model. Load whisper model.
@@ -158,8 +248,8 @@ class Transcriber:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.
""" """
_model = load_model(model, download_root=download_root, _model = whisper_load_model(model, download_root=download_root,
device=device, in_memory=in_memory) device=device, in_memory=in_memory)
return cls(_model, model_name=model) return cls(_model, model_name=model)
@@ -171,17 +261,115 @@ class Transcriber:
Returns: Returns:
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
_possible_kwargs = Whisper.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(Whisper.transcribe).args
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")): if (task := kwargs.get("task")):
whisper_kwargs["task"] = task whisper_kwargs["task"] = task
if (language := kwargs.get("language")): if (language := kwargs.get("language")):
whisper_kwargs["language"] = language whisper_kwargs["language"] = language
return whisper_kwargs
class WhisperXTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
    """
    Initialize the transcriber by delegating to Transcriber.__init__.

    Args:
        model (whisper): The loaded model used for transcription.
        model_name (str): The name of the model.
    """
    super().__init__(model=model, model_name=model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
               *args, **kwargs) -> str:
    """
    Transcribe an audio file.

    Args:
        audio (Union[str, Tensor, ndarray]): The audio file to transcribe.
            Tensors are converted to numpy arrays before being handed to
            the model.
        *args: Additional arguments.
        **kwargs: Additional keyword arguments,
            such as the language of the audio file.

    Returns:
        str: The transcript as a string, i.e. the texts of all returned
            segments concatenated in order.
    """
    kwargs = self._get_whisper_kwargs(**kwargs)

    if isinstance(audio, Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        audio = audio.detach().cpu().numpy()

    result = self.model.transcribe(audio, *args, **kwargs)

    return "".join(seg['text'] for seg in result['segments'])
@classmethod
def load_model(cls,
               model: str = "medium",
               download_root: str = WHISPER_DEFAULT_PATH,
               device: Optional[Union[str, device]] = None,
               *args, **kwargs
               ) -> 'Transcriber':
    """
    Load a whisperx model.

    Args:
        model (str): Whisper model. Available models include:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large-v3'
            - 'large'
        download_root (str, optional): Path to download the model.
            Defaults to WHISPER_DEFAULT_PATH.
        device (Optional[Union[str, torch.device]], optional):
            Device to load model on. When None, "cuda" is used if
            available, otherwise "cpu". Defaults to None.
        args: Additional arguments only to avoid errors (ignored).
        kwargs: Additional keyword arguments only to avoid errors (ignored).

    Returns:
        Transcriber: A Transcriber object initialized with the specified model.
    """
    if device is None:
        # str(None) would hand the literal "None" to whisperx, so pick a
        # real device. Imported locally because the `device` parameter
        # shadows the module-level `torch.device` name in this scope.
        from torch import cuda
        device = "cuda" if cuda.is_available() else "cpu"

    # Split e.g. "cuda:1" (or torch.device("cuda:1")) into type and index.
    # NOTE(review): assumes the installed whisperx `load_model` takes the
    # index separately as `device_index` -- confirm for the pinned version.
    device_type, _, device_index = str(device).partition(":")
    load_kwargs = {}
    if device_index:
        load_kwargs["device_index"] = int(device_index)

    _model = whisperx_load_model(model, download_root=download_root,
                                 device=device_type, **load_kwargs)

    return cls(_model, model_name=model)
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
    """
    Get kwargs for the whisperx model. Ensure that kwargs are valid.

    Only keyword arguments accepted by the whisperx pipeline's
    ``transcribe`` method are kept; "task" and "language" are kept
    whenever they are set to a truthy value.

    Returns:
        dict: Keyword arguments for whisper model.
    """
    # The object returned by whisperx.load_model is a pipeline whose
    # transcribe() signature differs from WhisperModel.transcribe, so the
    # filter must be built from the pipeline class. Imported locally to
    # keep this method self-contained.
    from whisperx.asr import FasterWhisperPipeline

    spec = getfullargspec(FasterWhisperPipeline.transcribe)
    # "self" and "audio" are bound positionally by transcribe(); letting
    # them through as keywords would collide with those arguments.
    accepted = set(spec.args + spec.kwonlyargs) - {"self", "audio"}

    whisper_kwargs = {k: v for k, v in kwargs.items() if k in accepted}

    for key in ("task", "language"):
        if (value := kwargs.get(key)):
            whisper_kwargs[key] = value

    return whisper_kwargs
def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})"