Added WhisperX as possible whisper model.

This commit is contained in:
Marko Henning
2024-05-08 15:49:05 +02:00
parent fee9f0b468
commit 82e26771e0
3 changed files with 224 additions and 34 deletions
+1
View File
@@ -2,6 +2,7 @@ tqdm>=4.65.0
numpy>=1.26.4
openai-whisper==20231117
whisperx~=3.1.3
pyannote.audio~=3.1.1
pyannote.core~=5.0.0
+3 -2
View File
@@ -64,6 +64,7 @@ class Scraibe:
"""
def __init__(self,
whisper_model: Union[bool, str, whisper] = None,
whisper_type: str = "whisper",
dia_model : Union[bool, str, DiarisationType] = None,
**kwargs) -> None:
"""Initializes the Scraibe class.
@@ -84,9 +85,9 @@ class Scraibe:
if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", **kwargs)
self.transcriber = Transcriber.load_model("medium", whisper_type, **kwargs)
elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
self.transcriber = Transcriber.load_model(whisper_model, whisper_type, **kwargs)
else:
self.transcriber = whisper_model
+216 -28
View File
@@ -24,18 +24,20 @@ Usage:
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
"""
from whisper import Whisper, load_model
from whisper import Whisper
from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from typing import TypeVar , Union , Optional
from torch import Tensor, device
from numpy import ndarray
from inspect import getfullargspec
from abc import ABC, abstractmethod
from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper')
class Transcriber:
"""
Transcriber Class
@@ -64,7 +66,7 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options.
"""
def __init__(self, model: whisper , model_name: str ) -> None:
def __init__(self, model: whisper, model_name: str) -> None:
"""
Initialize the Transcriber class with a Whisper model.
@@ -77,7 +79,113 @@ class Transcriber:
self.model_name = model_name
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
@abstractmethod
def transcribe(self, audio: Union[str, Tensor, ndarray] ,
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
pass
@staticmethod
def save_transcript(transcript : str , save_path : str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod
def load_model(cls,
model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
in_memory: bool = False,
*args, **kwargs
) -> 'Transcriber':
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large-v3'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Transcriber: A Transcriber object initialized with the specified model.
"""
if whisper_type.lower() == 'whisper':
_model = WhisperTranscriber.load_model(
model, download_root, device, in_memory, *args, **kwargs)
return _model
elif whisper_type.lower() == 'whisperx':
_model = WhisperXTranscriber.load_model(
model, download_root, device, *args, **kwargs)
return _model
else:
raise ValueError(f'Model type not recognized, exptected "whisper" '
f'or "whisperx", got {whisper_type}.')
pass
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
pass
def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})"
class WhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
@@ -100,24 +208,6 @@ class Transcriber:
result = self.model.transcribe(audio, *args, **kwargs)
return result["text"]
@staticmethod
def save_transcript(transcript : str , save_path : str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod
def load_model(cls,
model: str = "medium",
@@ -158,7 +248,7 @@ class Transcriber:
Transcriber: A Transcriber object initialized with the specified model.
"""
_model = load_model(model, download_root=download_root,
_model = whisper_load_model(model, download_root=download_root,
device=device, in_memory=in_memory)
return cls(_model, model_name=model)
@@ -171,7 +261,10 @@ class Transcriber:
Returns:
dict: Keyword arguments for whisper model.
"""
_possible_kwargs = Whisper.transcribe.__code__.co_varnames
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(Whisper.transcribe).args
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
@@ -183,5 +276,100 @@ class Transcriber:
return whisper_kwargs
def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})"
class WhisperXTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
kwargs = self._get_whisper_kwargs(**kwargs)
if isinstance(audio, Tensor):
audio = audio.cpu().numpy()
result = self.model.transcribe(audio, *args, **kwargs)
text = ""
for seg in result['segments']:
text += seg['text']
return text
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
*args, **kwargs
) -> 'Transcriber':
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large-v3'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Transcriber: A Transcriber object initialized with the specified model.
"""
if not isinstance(device, str):
device = str(device)
_model = whisperx_load_model(model, download_root=download_root,
device=device)
return cls(_model, model_name=model)
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_args = getfullargspec(WhisperModel.transcribe).args
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")):
whisper_kwargs["task"] = task
if (language := kwargs.get("language")):
whisper_kwargs["language"] = language
return whisper_kwargs