Merge branch 'develop' into pyproject.toml

This commit is contained in:
Schmieder, Jacob
2024-05-21 11:05:55 +00:00
15 changed files with 688 additions and 441 deletions
+1
View File
@@ -2,6 +2,7 @@ tqdm>=4.65.0
numpy>=1.26.4 numpy>=1.26.4
openai-whisper==20231117 openai-whisper==20231117
whisperx~=3.1.3
pyannote.audio~=3.1.1 pyannote.audio~=3.1.1
pyannote.core~=5.0.0 pyannote.core~=5.0.0
-1
View File
@@ -9,4 +9,3 @@ from .misc import *
from .cli import * from .cli import *
from ._version import __version__ from ._version import __version__
+8 -6
View File
@@ -28,6 +28,7 @@ import torch
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768.0 NORMALIZATION_FACTOR = 32768.0
class AudioProcessor: class AudioProcessor:
""" """
Audio Processor class that leverages PyTorchaudio to provide functionalities Audio Processor class that leverages PyTorchaudio to provide functionalities
@@ -42,7 +43,6 @@ class AudioProcessor:
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
*args, **kwargs) -> None: *args, **kwargs) -> None:
""" """
Initialize the AudioProcessor object. Initialize the AudioProcessor object.
@@ -57,13 +57,14 @@ class AudioProcessor:
ValueError: If the provided sample rate is not of type int. ValueError: If the provided sample rate is not of type int.
""" """
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") device = kwargs.get(
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device) self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
if not isinstance(self.sr, int): if not isinstance(self.sr, int):
raise ValueError("Sample rate should be a single value of type int," \ raise ValueError("Sample rate should be a single value of type int,"
f"not {len(self.sr)} and type {type(self.sr)}") f"not {len(self.sr)} and type {type(self.sr)}")
@classmethod @classmethod
@@ -84,7 +85,6 @@ class AudioProcessor:
return cls(audio, sr) return cls(audio, sr)
def cut(self, start: float, end: float) -> torch.Tensor: def cut(self, start: float, end: float) -> torch.Tensor:
""" """
Cut a segment from the audio waveform between the specified start and end times. Cut a segment from the audio waveform between the specified start and end times.
@@ -140,9 +140,11 @@ class AudioProcessor:
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e: except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e raise RuntimeError(
f"Failed to load audio: {e.stderr.decode()}") from e
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR out = np.frombuffer(out, np.int16).flatten().astype(
np.float32) / NORMALIZATION_FACTOR
return out, sr return out, sr
+18 -12
View File
@@ -38,7 +38,7 @@ from tqdm import trange
# Application-Specific Imports # Application-Specific Imports
from .audio import AudioProcessor from .audio import AudioProcessor
from .diarisation import Diariser from .diarisation import Diariser
from .transcriber import Transcriber, whisper from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript from .transcript_exporter import Transcript
@@ -62,8 +62,10 @@ class Scraibe:
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
get_audio_file: Gets an audio file as an AudioProcessor object. get_audio_file: Gets an audio file as an AudioProcessor object.
""" """
def __init__(self, def __init__(self,
whisper_model: Union[bool, str, whisper] = None, whisper_model: Union[bool, str, whisper] = None,
whisper_type: str = "whisper",
dia_model: Union[bool, str, DiarisationType] = None, dia_model: Union[bool, str, DiarisationType] = None,
**kwargs) -> None: **kwargs) -> None:
"""Initializes the Scraibe class. """Initializes the Scraibe class.
@@ -71,6 +73,8 @@ class Scraibe:
Args: Args:
whisper_model (Union[bool, str, whisper], optional): whisper_model (Union[bool, str, whisper], optional):
Path to whisper model or whisper model itself. Path to whisper model or whisper model itself.
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
diarisation_model (Union[bool, str, DiarisationType], optional): diarisation_model (Union[bool, str, DiarisationType], optional):
Path to pyannote diarization model or model itself. Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper **kwargs: Additional keyword arguments for whisper
@@ -82,11 +86,12 @@ class Scraibe:
for autotranscribe. So you can unload the class and reload it again. for autotranscribe. So you can unload the class and reload it again.
""" """
if whisper_model is None: if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", **kwargs) self.transcriber = load_transcriber(
"medium", whisper_type, **kwargs)
elif isinstance(whisper_model, str): elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = load_transcriber(
whisper_model, whisper_type, **kwargs)
else: else:
self.transcriber = whisper_model self.transcriber = whisper_model
@@ -111,7 +116,6 @@ class Scraibe:
else: else:
self.params = {} self.params = {}
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray], def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
remove_original: bool = False, remove_original: bool = False,
**kwargs) -> Transcript: **kwargs) -> Transcript:
@@ -149,7 +153,8 @@ class Scraibe:
if not diarisation["segments"]: if not diarisation["segments"]:
print("No segments found. Try to run transcription without diarisation.") print("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) transcript = self.transcriber.transcribe(
audio_file.waveform, **kwargs)
final_transcript = {0: {"speakers": 'SPEAKER_01', final_transcript = {0: {"speakers": 'SPEAKER_01',
"segments": [0, len(audio_file.waveform)], "segments": [0, len(audio_file.waveform)],
@@ -160,7 +165,8 @@ class Scraibe:
if self.verbose: if self.verbose:
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results # Transcribe each segment and store the results
final_transcript = dict() final_transcript = dict()
@@ -252,11 +258,12 @@ class Scraibe:
_old_model = self.transcriber.model_name _old_model = self.transcriber.model_name
if isinstance(whisper_model, str): if isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = load_transcriber(whisper_model, **kwargs)
elif isinstance(whisper_model, Transcriber): elif isinstance(whisper_model, Transcriber):
self.transcriber = whisper_model self.transcriber = whisper_model
else: else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) warn(
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
return None return None
@@ -278,7 +285,7 @@ class Scraibe:
elif isinstance(dia_model, Diariser): elif isinstance(dia_model, Diariser):
self.diariser = dia_model self.diariser = dia_model
else: else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
return None return None
@@ -315,7 +322,6 @@ class Scraibe:
os.remove(audio_file) os.remove(audio_file)
print(f"Audiofile {audio_file} removed.") print(f"Audiofile {audio_file} removed.")
@staticmethod @staticmethod
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray], def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor: *args, **kwargs) -> AudioProcessor:
@@ -342,7 +348,7 @@ class Scraibe:
audio_file[1]) audio_file[1])
if not isinstance(audio_file, AudioProcessor): if not isinstance(audio_file, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \ raise ValueError(f'Audiofile must be of type AudioProcessor,'
f'not {type(audio_file)}') f'not {type(audio_file)}')
return audio_file return audio_file
+20 -12
View File
@@ -32,7 +32,8 @@ def cli():
if string in str2val: if string in str2val:
return str2val[string] return str2val[string]
else: else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") raise ValueError(
f"Expected one of {set(str2val.keys())}, got {string}")
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
@@ -42,8 +43,8 @@ def cli():
help="List of audio files to transcribe.") help="List of audio files to transcribe.")
group.add_argument('--start-server', action='store_true', group.add_argument('--start-server', action='store_true',
help='Start the Gradio app.' \ help='Start the Gradio app.'
'If set, all other arguments are ignored' \ 'If set, all other arguments are ignored'
'besides --server-config or --server-kwargs.') 'besides --server-config or --server-kwargs.')
parser.add_argument("--server-config", type=str, default=None, parser.add_argument("--server-config", type=str, default=None,
@@ -89,7 +90,8 @@ def cli():
If set to translate, the output will be translated to English.") If set to translate, the output will be translated to English.")
parser.add_argument("--language", type=str, default=None, parser.add_argument("--language", type=str, default=None,
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), choices=sorted(
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.") help="Language spoken in the audio. Specify None to perform language detection.")
args = parser.parse_args() args = parser.parse_args()
@@ -131,15 +133,17 @@ def cli():
else: else:
task = "transcribe" task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
"language"), verbose=arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}') print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}")) out.save(os.path.join(
out_folder, f"{basename}.{out_format}"))
elif task == "diarization": elif task == "diarization":
for audio in audio_files: for audio in audio_files:
if arg_dict.pop("verbose_output"): if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.") print("Verbose not implemented for diarization.")
out = model.diarization(audio) out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
@@ -162,25 +166,29 @@ def cli():
with open(path, "w") as f: with open(path, "w") as f:
f.write(out) f.write(out)
else: # unfinished code else: # unfinished code
raise NotImplementedError("Currently not Working") raise NotImplementedError("Currently not Working")
import subprocess import subprocess
import sys import sys
execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py") execute_path = os.path.join(
os.path.dirname(__file__), "app/app_starter.py")
config = arg_dict.pop("server_config") config = arg_dict.pop("server_config")
server_kwargs = arg_dict.pop("server_kwargs") server_kwargs = arg_dict.pop("server_kwargs")
if not config: if not config:
subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"]) subprocess.run([sys.executable, execute_path,
f"--server-kwargs={server_kwargs}"])
elif not server_kwargs: elif not server_kwargs:
subprocess.run([sys.executable, execute_path, f"--server-config={config}"]) subprocess.run([sys.executable, execute_path,
f"--server-config={config}"])
elif not config and not server_kwargs: elif not config and not server_kwargs:
subprocess.run([sys.executable, execute_path]) subprocess.run([sys.executable, execute_path])
else: else:
subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) subprocess.run([sys.executable, execute_path,
f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()
+18 -14
View File
@@ -37,7 +37,7 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from torch import device as torch_device from torch import device as torch_device
from torch.cuda import is_available, current_device from torch.cuda import is_available
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
@@ -47,6 +47,7 @@ Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
os.path.realpath(__file__)), '.pyannotetoken') os.path.realpath(__file__)), '.pyannotetoken')
class Diariser: class Diariser:
""" """
Handles the diarization process of an audio file using a pretrained model Handles the diarization process of an audio file using a pretrained model
@@ -132,7 +133,6 @@ class Diariser:
index_start_speaker = i index_start_speaker = i
current_speaker = speaker current_speaker = speaker
if i == len(dia_list) - 1: if i == len(dia_list) - 1:
index_end_speaker = i index_end_speaker = i
@@ -166,8 +166,8 @@ class Diariser:
with open(TOKEN_PATH, 'r', encoding="utf-8") as file: with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
token = file.read() token = file.read()
else: else:
raise ValueError('No token found.' \ raise ValueError('No token found.'
'Please create a token at https://huggingface.co/settings/token' \ 'Please create a token at https://huggingface.co/settings/token'
f'and save it in a file called {TOKEN_PATH}') f'and save it in a file called {TOKEN_PATH}')
return token return token
@@ -193,7 +193,6 @@ class Diariser:
device: str = None, device: str = None,
*args, **kwargs *args, **kwargs
) -> Pipeline: ) -> Pipeline:
""" """
Loads a pretrained model from pyannote.audio, Loads a pretrained model from pyannote.audio,
either from a local cache or some online repository. either from a local cache or some online repository.
@@ -237,16 +236,18 @@ class Diariser:
'deprecated and will be removed in future versions.', 'deprecated and will be removed in future versions.',
category=DeprecationWarning) category=DeprecationWarning)
# list elementes with the ending .bin # list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] bin_files = [f for f in os.listdir(
pwd) if f.endswith(".bin")]
if len(bin_files) == 1: if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0]) path_to_model = os.path.join(pwd, bin_files[0])
else: else:
warnings.warn("Found more than one .bin file. "\ warnings.warn("Found more than one .bin file. "
"or none. Please specify the path to the model " \ "or none. Please specify the path to the model "
"or setup a huggingface token.") "or setup a huggingface token.")
raise FileNotFoundError raise FileNotFoundError
warnings.warn(f"Found model at {path_to_model} overwriting config file.") warnings.warn(
f"Found model at {path_to_model} overwriting config file.")
config['pipeline']['params']['segmentation'] = path_to_model config['pipeline']['params']['segmentation'] = path_to_model
@@ -270,22 +271,24 @@ class Diariser:
if use_auth_token is None: if use_auth_token is None:
use_auth_token = cls._get_token() use_auth_token = cls._get_token()
else: else:
raise FileNotFoundError(f'No local model or directory found at {model}.') raise FileNotFoundError(
f'No local model or directory found at {model}.')
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
cache_dir=cache_dir, cache_dir=cache_dir,
hparams_file=hparams_file,) hparams_file=hparams_file,)
if _model is None: if _model is None:
raise ValueError('Unable to load model either from local cache' \ raise ValueError('Unable to load model either from local cache'
'or from huggingface.co models. Please check your token' \ 'or from huggingface.co models. Please check your token'
'or your local model path') 'or your local model path')
# try to move the model to the device # try to move the model to the device
if device is None: if device is None:
device = "cuda" if is_available() else "cpu" device = "cuda" if is_available() else "cpu"
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict # torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device))
return cls(_model) return cls(_model)
@@ -302,7 +305,8 @@ class Diariser:
""" """
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} diarisation_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
return diarisation_kwargs return diarisation_kwargs
+9 -3
View File
@@ -2,6 +2,7 @@ import os
import yaml import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action from argparse import Action
from ast import literal_eval
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
@@ -17,6 +18,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.
@@ -33,25 +35,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
with open(file_path, "r") as stream: with open(file_path, "r") as stream:
yml = yaml.safe_load(stream) yml = yaml.safe_load(stream)
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") segmentation_path = path_to_segmentation or os.path.join(
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
yml["pipeline"]["params"]["segmentation"] = segmentation_path yml["pipeline"]["params"]["segmentation"] = segmentation_path
if not os.path.exists(segmentation_path): if not os.path.exists(segmentation_path):
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}") raise FileNotFoundError(
f"Segmentation model not found at {segmentation_path}")
with open(file_path, "w") as stream: with open(file_path, "w") as stream:
yaml.dump(yml, stream) yaml.dump(yml, stream)
class ParseKwargs(Action): class ParseKwargs(Action):
""" """
Custom argparse action to parse keyword arguments. Custom argparse action to parse keyword arguments.
""" """
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, dict()) setattr(namespace, self.dest, dict())
for value in values: for value in values:
key, value = value.split('=') key, value = value.split('=')
try: try:
value = eval(value) value = literal_eval(value)
except: except:
pass pass
getattr(namespace, self.dest)[key] = value getattr(namespace, self.dest)[key] = value
+270 -27
View File
@@ -24,18 +24,22 @@ Usage:
>>> transcriber.save_transcript(transcript, "path/to/save.txt") >>> transcriber.save_transcript(transcript, "path/to/save.txt")
""" """
from whisper import Whisper, load_model from whisper import Whisper
from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from typing import TypeVar, Union, Optional from typing import TypeVar, Union, Optional
from torch import Tensor, device from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray from numpy import ndarray
from inspect import signature
from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
class Transcriber: class Transcriber:
""" """
Transcriber Class Transcriber Class
@@ -64,6 +68,7 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options. the load_model method for available options.
""" """
def __init__(self, model: whisper, model_name: str) -> None: def __init__(self, model: whisper, model_name: str) -> None:
""" """
Initialize the Transcriber class with a Whisper model. Initialize the Transcriber class with a Whisper model.
@@ -77,6 +82,102 @@ class Transcriber:
self.model_name = model_name self.model_name = model_name
@abstractmethod
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
pass
@staticmethod
def save_transcript(transcript: str, save_path: str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod
@abstractmethod
def load_model(cls,
model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
in_memory: bool = False,
*args, **kwargs
) -> None:
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large-v3'
- 'large'
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
None: abscract method.
"""
pass
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
pass
def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})"
class WhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray], def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str: *args, **kwargs) -> str:
""" """
@@ -100,24 +201,6 @@ class Transcriber:
result = self.model.transcribe(audio, *args, **kwargs) result = self.model.transcribe(audio, *args, **kwargs)
return result["text"] return result["text"]
@staticmethod
def save_transcript(transcript : str , save_path : str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
@@ -125,7 +208,7 @@ class Transcriber:
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = None,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> 'Transcriber': ) -> 'WhisperTranscriber':
""" """
Load whisper model. Load whisper model.
@@ -158,7 +241,7 @@ class Transcriber:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.
""" """
_model = load_model(model, download_root=download_root, _model = whisper_load_model(model, download_root=download_root,
device=device, in_memory=in_memory) device=device, in_memory=in_memory)
return cls(_model, model_name=model) return cls(_model, model_name=model)
@@ -171,9 +254,11 @@ class Transcriber:
Returns: Returns:
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
_possible_kwargs = Whisper.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")): if (task := kwargs.get("task")):
whisper_kwargs["task"] = task whisper_kwargs["task"] = task
@@ -184,4 +269,162 @@ class Transcriber:
return whisper_kwargs return whisper_kwargs
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})" return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
class WhisperXTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
kwargs = self._get_whisper_kwargs(**kwargs)
if isinstance(audio, Tensor):
audio = audio.cpu().numpy()
result = self.model.transcribe(audio, *args, **kwargs)
text = ""
for seg in result['segments']:
text += seg['text']
return text
@classmethod
def load_model(cls,
model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
*args, **kwargs
) -> 'WhisperXTranscriber':
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large-v3'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Transcriber: A Transcriber object initialized with the specified model.
"""
if device is None:
device = "cuda" if cuda_is_available() else "cpu"
if not isinstance(device, str):
device = str(device)
compute_type = kwargs.get('compute_type', 'float16')
if device == 'cpu' and compute_type == 'float16':
warnings.warn(f'Compute type {compute_type} not compatible with '
f'device {device}! Changing compute type to int8.')
compute_type = 'int8'
_model = whisperx_load_model(model, download_root=download_root,
device=device, compute_type=compute_type)
return cls(_model, model_name=model)
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_possible_kwargs = signature(WhisperModel.transcribe).parameters.keys()
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")):
whisper_kwargs["task"] = task
if (language := kwargs.get("language")):
whisper_kwargs["language"] = language
return whisper_kwargs
def __repr__(self) -> str:
return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})"
def load_transcriber(model: str = "medium",
whisper_type: str = 'whisper',
download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None,
in_memory: bool = False,
*args, **kwargs
) -> Union[WhisperTranscriber, WhisperXTranscriber]:
"""
Load whisper model.
Args:
model (str): Whisper model. Available models include:
- 'tiny.en'
- 'tiny'
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large-v3'
- 'large'
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns:
Union[WhisperTranscriber, WhisperXTranscriber]:
One of the Whisper variants as Transcrbier object initialized with the specified model.
"""
if whisper_type.lower() == 'whisper':
_model = WhisperTranscriber.load_model(
model, download_root, device, in_memory, *args, **kwargs)
return _model
elif whisper_type.lower() == 'whisperx':
_model = WhisperXTranscriber.load_model(
model, download_root, device, *args, **kwargs)
return _model
else:
raise ValueError(f'Model type not recognized, exptected "whisper" '
f'or "whisperx", got {whisper_type}.')
+12 -10
View File
@@ -1,5 +1,6 @@
import json import json
import time import time
from json.decoder import JSONDecodeError
from typing import Union from typing import Union
@@ -8,7 +9,6 @@ from .hallucinations import KNOWN_HALLUCINATIONS
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
class Transcript: class Transcript:
""" """
Class for storing transcript data, including speaker information and text segments, Class for storing transcript data, including speaker information and text segments,
@@ -49,7 +49,8 @@ class Transcript:
annotations = {} annotations = {}
if args and len(args) != len(self.speakers): if args and len(args) != len(self.speakers):
raise ValueError("Number of speaker names does not match number of speakers") raise ValueError(
"Number of speaker names does not match number of speakers")
if args: if args:
for arg, speaker in zip(args, sorted(self.speakers)): for arg, speaker in zip(args, sorted(self.speakers)):
@@ -58,9 +59,11 @@ class Transcript:
invalid_speakers = set(kwargs.keys()) - set(self.speakers) invalid_speakers = set(kwargs.keys()) - set(self.speakers)
if invalid_speakers: if invalid_speakers:
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}") raise ValueError(
f"These keys are not speakers: {', '.join(invalid_speakers)}")
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) annotations.update({key: kwargs[key]
for key in self.speakers if key in kwargs})
self.annotation = annotations self.annotation = annotations
@@ -74,8 +77,10 @@ class Transcript:
segments_to_drop = [] segments_to_drop = []
for id in self.transcript: for id in self.transcript:
for snippet in KNOWN_HALLUCINATIONS: for snippet in KNOWN_HALLUCINATIONS:
self.transcript[id]['text']=self.transcript[id]['text'].replace(snippet,'') self.transcript[id]['text'] = self.transcript[id]['text'].replace(
if self.transcript[id]['text'] == '': segments_to_drop.append(id) snippet, '')
if self.transcript[id]['text'] == '':
segments_to_drop.append(id)
for id in segments_to_drop: for id in segments_to_drop:
del self.transcript[id] del self.transcript[id]
@@ -209,7 +214,6 @@ class Transcript:
return fstring return fstring
def to_json(self, path, *args, **kwargs) -> None: def to_json(self, path, *args, **kwargs) -> None:
"""Save transcript as json file """Save transcript as json file
@@ -310,10 +314,8 @@ class Transcript:
else: else:
try: try:
transcript = json.loads(json) transcript = json.loads(json)
except: except (TypeError, JSONDecodeError):
with open(json, "r") as f: with open(json, "r") as f:
transcript = json.load(f) transcript = json.load(f)
return cls(transcript) return cls(transcript)
-31
View File
@@ -3,7 +3,6 @@ from scraibe.audio import AudioProcessor
import torch import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
TEST_SR = 16000 TEST_SR = 16000
@@ -25,10 +24,6 @@ def probe_audio_processor():
return AudioProcessor(TEST_WAVEFORM, TEST_SR) return AudioProcessor(TEST_WAVEFORM, TEST_SR)
def test_AudioProcessor_init(probe_audio_processor): def test_AudioProcessor_init(probe_audio_processor):
""" """
Test the initialization of the AudioProcessor class. Test the initialization of the AudioProcessor class.
@@ -53,7 +48,6 @@ def test_AudioProcessor_init(probe_audio_processor):
assert probe_audio_processor.sr == TEST_SR assert probe_audio_processor.sr == TEST_SR
def test_cut(probe_audio_processor): def test_cut(probe_audio_processor):
"""Test the cut function of the AudioProcessor class. """Test the cut function of the AudioProcessor class.
@@ -76,14 +70,6 @@ def test_cut(probe_audio_processor):
# assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
def test_audio_processor_invalid_sr(): def test_audio_processor_invalid_sr():
"""Test the behavior of AudioProcessor when an invalid smaple rate is provided. """Test the behavior of AudioProcessor when an invalid smaple rate is provided.
@@ -108,20 +94,3 @@ def test_audio_processor_SAMPLE_RATE():
""" """
probe_audio_processor = AudioProcessor(TEST_WAVEFORM) probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
assert probe_audio_processor.sr == SAMPLE_RATE assert probe_audio_processor.sr == SAMPLE_RATE
-6
View File
@@ -1,12 +1,8 @@
import pytest import pytest
from scraibe import Scraibe, Diariser, Transcriber, Transcript from scraibe import Scraibe, Diariser, Transcriber, Transcript
from unittest.mock import MagicMock, patch
import os import os
@pytest.fixture @pytest.fixture
def create_scraibe_instance(): def create_scraibe_instance():
if "HF_TOKEN" in os.environ: if "HF_TOKEN" in os.environ:
@@ -15,8 +11,6 @@ def create_scraibe_instance():
return Scraibe() return Scraibe()
def test_scraibe_init(create_scraibe_instance): def test_scraibe_init(create_scraibe_instance):
model = create_scraibe_instance model = create_scraibe_instance
assert isinstance(model.transcriber, Transcriber) assert isinstance(model.transcriber, Transcriber)
+1 -16
View File
@@ -1,8 +1,5 @@
import pytest import pytest
import os from scraibe import Diariser
from unittest import mock
from scraibe import diarisation, Diariser
@pytest.fixture @pytest.fixture
@@ -19,7 +16,6 @@ def diariser_instance():
return Diariser('pyannote') return Diariser('pyannote')
def test_Diariser_init(diariser_instance): def test_Diariser_init(diariser_instance):
"""Test the initialization of the Diariser class. """Test the initialization of the Diariser class.
@@ -34,14 +30,3 @@ def test_Diariser_init(diariser_instance):
None None
""" """
assert diariser_instance.model == 'pyannote' assert diariser_instance.model == 'pyannote'
+39 -11
View File
@@ -1,10 +1,9 @@
import pytest import pytest
from unittest.mock import patch from scraibe import (Transcriber, WhisperTranscriber,
from scraibe import Transcriber WhisperXTranscriber, load_transcriber)
import torch import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = "Hello World" TEST_WAVEFORM = "Hello World"
@@ -29,12 +28,37 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
assert transcription_result == expected_transcription """ assert transcription_result == expected_transcription """
@pytest.fixture
def transcriber_instance():
return Transcriber.load_model('medium')
def test_transcriber_initialization(transcriber_instance): @pytest.fixture
assert isinstance(transcriber_instance, Transcriber) def whisper_instance():
return load_transcriber('medium', whisper_type='whisper')
@pytest.fixture
def whisperx_instance():
return load_transcriber('medium', whisper_type='whisperx')
def test_whisper_base_initialization(whisper_instance):
assert isinstance(whisper_instance, Transcriber)
def test_whisperx_base_initialization(whisperx_instance):
assert isinstance(whisperx_instance, Transcriber)
def test_whisper_transcriber_initialization(whisper_instance):
assert isinstance(whisper_instance, WhisperTranscriber)
def test_whisperx_transcriber_initialization(whisperx_instance):
assert isinstance(whisperx_instance, WhisperXTranscriber)
def test_wrong_transcriber_initialization():
with pytest.raises(ValueError):
load_transcriber('medium', whisper_type='wrong_whisper')
def test_get_whisper_kwargs(): def test_get_whisper_kwargs():
kwargs = {"arg1": 1, "arg3": 3} kwargs = {"arg1": 1, "arg3": 3}
@@ -42,11 +66,15 @@ def test_get_whisper_kwargs():
assert not valid_kwargs == {"arg1": 1, "arg3": 3} assert not valid_kwargs == {"arg1": 1, "arg3": 3}
def test_transcribe(transcriber_instance): def test_whisper_transcribe(whisper_instance):
model = transcriber_instance model = whisper_instance
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4') transcript = model.transcribe('test/audio_test_2.mp4')
assert isinstance(transcript, str) assert isinstance(transcript, str)
def test_whisperx_transcribe(whisperx_instance):
model = whisperx_instance
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4')
assert isinstance(transcript, str)