Merge branch 'develop' into pyproject.toml

This commit is contained in:
Schmieder, Jacob
2024-05-21 11:05:55 +00:00
15 changed files with 688 additions and 441 deletions
+1
View File
@@ -2,6 +2,7 @@ tqdm>=4.65.0
numpy>=1.26.4 numpy>=1.26.4
openai-whisper==20231117 openai-whisper==20231117
whisperx~=3.1.3
pyannote.audio~=3.1.1 pyannote.audio~=3.1.1
pyannote.core~=5.0.0 pyannote.core~=5.0.0
-1
View File
@@ -9,4 +9,3 @@ from .misc import *
from .cli import * from .cli import *
from ._version import __version__ from ._version import __version__
+11 -9
View File
@@ -28,6 +28,7 @@ import torch
SAMPLE_RATE = 16000 SAMPLE_RATE = 16000
NORMALIZATION_FACTOR = 32768.0 NORMALIZATION_FACTOR = 32768.0
class AudioProcessor: class AudioProcessor:
""" """
Audio Processor class that leverages PyTorchaudio to provide functionalities Audio Processor class that leverages PyTorchaudio to provide functionalities
@@ -40,9 +41,8 @@ class AudioProcessor:
The sample rate of the audio. The sample rate of the audio.
""" """
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE, def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
*args, **kwargs) -> None: *args, **kwargs) -> None:
""" """
Initialize the AudioProcessor object. Initialize the AudioProcessor object.
@@ -57,13 +57,14 @@ class AudioProcessor:
ValueError: If the provided sample rate is not of type int. ValueError: If the provided sample rate is not of type int.
""" """
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") device = kwargs.get(
"device", "cuda" if torch.cuda.is_available() else "cpu")
self.waveform = waveform.to(device) self.waveform = waveform.to(device)
self.sr = sr self.sr = sr
if not isinstance(self.sr, int): if not isinstance(self.sr, int):
raise ValueError("Sample rate should be a single value of type int," \ raise ValueError("Sample rate should be a single value of type int,"
f"not {len(self.sr)} and type {type(self.sr)}") f"not {len(self.sr)} and type {type(self.sr)}")
@classmethod @classmethod
@@ -78,13 +79,12 @@ class AudioProcessor:
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio. AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
""" """
audio, sr = cls.load_audio(file , *args, **kwargs) audio, sr = cls.load_audio(file, *args, **kwargs)
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio)
return cls(audio, sr) return cls(audio, sr)
def cut(self, start: float, end: float) -> torch.Tensor: def cut(self, start: float, end: float) -> torch.Tensor:
""" """
Cut a segment from the audio waveform between the specified start and end times. Cut a segment from the audio waveform between the specified start and end times.
@@ -140,11 +140,13 @@ class AudioProcessor:
try: try:
out = run(cmd, capture_output=True, check=True).stdout out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e: except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e raise RuntimeError(
f"Failed to load audio: {e.stderr.decode()}") from e
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR out = np.frombuffer(out, np.int16).flatten().astype(
np.float32) / NORMALIZATION_FACTOR
return out , sr return out, sr
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})' return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
+52 -46
View File
@@ -38,7 +38,7 @@ from tqdm import trange
# Application-Specific Imports # Application-Specific Imports
from .audio import AudioProcessor from .audio import AudioProcessor
from .diarisation import Diariser from .diarisation import Diariser
from .transcriber import Transcriber, whisper from .transcriber import Transcriber, load_transcriber, whisper
from .transcript_exporter import Transcript from .transcript_exporter import Transcript
@@ -62,15 +62,19 @@ class Scraibe:
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy. remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
get_audio_file: Gets an audio file as an AudioProcessor object. get_audio_file: Gets an audio file as an AudioProcessor object.
""" """
def __init__(self, def __init__(self,
whisper_model: Union[bool, str, whisper] = None, whisper_model: Union[bool, str, whisper] = None,
dia_model : Union[bool, str, DiarisationType] = None, whisper_type: str = "whisper",
**kwargs) -> None: dia_model: Union[bool, str, DiarisationType] = None,
**kwargs) -> None:
"""Initializes the Scraibe class. """Initializes the Scraibe class.
Args: Args:
whisper_model (Union[bool, str, whisper], optional): whisper_model (Union[bool, str, whisper], optional):
Path to whisper model or whisper model itself. Path to whisper model or whisper model itself.
whisper_type (str):
Type of whisper model to load. "whisper" or "whisperx".
diarisation_model (Union[bool, str, DiarisationType], optional): diarisation_model (Union[bool, str, DiarisationType], optional):
Path to pyannote diarization model or model itself. Path to pyannote diarization model or model itself.
**kwargs: Additional keyword arguments for whisper **kwargs: Additional keyword arguments for whisper
@@ -82,11 +86,12 @@ class Scraibe:
for autotranscribe. So you can unload the class and reload it again. for autotranscribe. So you can unload the class and reload it again.
""" """
if whisper_model is None: if whisper_model is None:
self.transcriber = Transcriber.load_model("medium", **kwargs) self.transcriber = load_transcriber(
"medium", whisper_type, **kwargs)
elif isinstance(whisper_model, str): elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = load_transcriber(
whisper_model, whisper_type, **kwargs)
else: else:
self.transcriber = whisper_model self.transcriber = whisper_model
@@ -95,7 +100,7 @@ class Scraibe:
elif isinstance(dia_model, str): elif isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs) self.diariser = Diariser.load_model(dia_model, **kwargs)
else: else:
self.diariser : Diariser = dia_model self.diariser: Diariser = dia_model
if kwargs.get("verbose"): if kwargs.get("verbose"):
print("Scraibe initialized all models successfully loaded.") print("Scraibe initialized all models successfully loaded.")
@@ -105,16 +110,15 @@ class Scraibe:
# Save kwargs for autotranscribe if you want to unload the class and load it again. # Save kwargs for autotranscribe if you want to unload the class and load it again.
if kwargs.get('save_setup'): if kwargs.get('save_setup'):
self.params = dict(whisper_model = whisper_model, self.params = dict(whisper_model=whisper_model,
dia_model = dia_model, dia_model=dia_model,
**kwargs) **kwargs)
else: else:
self.params = {} self.params = {}
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray], remove_original: bool = False,
remove_original : bool = False, **kwargs) -> Transcript:
**kwargs) -> Transcript:
""" """
Transcribes an audio file using the whisper model and pyannote diarization model. Transcribes an audio file using the whisper model and pyannote diarization model.
@@ -133,13 +137,13 @@ class Scraibe:
if kwargs.get("verbose"): if kwargs.get("verbose"):
self.verbose = kwargs.get("verbose") self.verbose = kwargs.get("verbose")
# Get audio file as an AudioProcessor object # Get audio file as an AudioProcessor object
audio_file : AudioProcessor = self.get_audio_file(audio_file) audio_file: AudioProcessor = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
if self.verbose: if self.verbose:
print("Starting diarisation.") print("Starting diarisation.")
@@ -149,23 +153,25 @@ class Scraibe:
if not diarisation["segments"]: if not diarisation["segments"]:
print("No segments found. Try to run transcription without diarisation.") print("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) transcript = self.transcriber.transcribe(
audio_file.waveform, **kwargs)
final_transcript= {0 : {"speakers" : 'SPEAKER_01', final_transcript = {0: {"speakers": 'SPEAKER_01',
"segments" : [0, len(audio_file.waveform)], "segments": [0, len(audio_file.waveform)],
"text" : transcript}} "text": transcript}}
return Transcript(final_transcript) return Transcript(final_transcript)
if self.verbose: if self.verbose:
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) audio_file.sr = torch.Tensor([audio_file.sr]).to(
audio_file.waveform.device)
# Transcribe each segment and store the results # Transcribe each segment and store the results
final_transcript = dict() final_transcript = dict()
for i in trange(len(diarisation["segments"]), desc= "Transcribing", disable = not self.verbose): for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose):
seg = diarisation["segments"][i] seg = diarisation["segments"][i]
@@ -173,9 +179,9 @@ class Scraibe:
transcript = self.transcriber.transcribe(audio, **kwargs) transcript = self.transcriber.transcribe(audio, **kwargs)
final_transcript[i] = {"speakers" : diarisation["speakers"][i], final_transcript[i] = {"speakers": diarisation["speakers"][i],
"segments" : seg, "segments": seg,
"text" : transcript} "text": transcript}
# Remove original file if needed # Remove original file if needed
if remove_original: if remove_original:
@@ -186,7 +192,7 @@ class Scraibe:
return Transcript(final_transcript) return Transcript(final_transcript)
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray], def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
**kwargs) -> dict: **kwargs) -> dict:
""" """
Perform diarization on an audio file using the pyannote diarization model. Perform diarization on an audio file using the pyannote diarization model.
@@ -203,13 +209,13 @@ class Scraibe:
""" """
# Get audio file as an AudioProcessor object # Get audio file as an AudioProcessor object
audio_file : AudioProcessor = self.get_audio_file(audio_file) audio_file: AudioProcessor = self.get_audio_file(audio_file)
# Prepare waveform and sample rate for diarization # Prepare waveform and sample rate for diarization
dia_audio = { dia_audio = {
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)), "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
"sample_rate": audio_file.sr "sample_rate": audio_file.sr
} }
print("Starting diarisation.") print("Starting diarisation.")
@@ -217,8 +223,8 @@ class Scraibe:
return diarisation return diarisation
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray], def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
**kwargs): **kwargs):
""" """
Transcribe the provided audio file. Transcribe the provided audio file.
@@ -232,11 +238,11 @@ class Scraibe:
str: str:
The transcribed text from the audio source. The transcribed text from the audio source.
""" """
audio_file : AudioProcessor = self.get_audio_file(audio_file) audio_file: AudioProcessor = self.get_audio_file(audio_file)
return self.transcriber.transcribe(audio_file.waveform, **kwargs) return self.transcriber.transcribe(audio_file.waveform, **kwargs)
def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None: def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None:
""" """
Update the transcriber model. Update the transcriber model.
@@ -252,15 +258,16 @@ class Scraibe:
_old_model = self.transcriber.model_name _old_model = self.transcriber.model_name
if isinstance(whisper_model, str): if isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = load_transcriber(whisper_model, **kwargs)
elif isinstance(whisper_model, Transcriber): elif isinstance(whisper_model, Transcriber):
self.transcriber = whisper_model self.transcriber = whisper_model
else: else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning) warn(
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
return None return None
def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None: def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None:
""" """
Update the diariser model. Update the diariser model.
@@ -278,13 +285,13 @@ class Scraibe:
elif isinstance(dia_model, Diariser): elif isinstance(dia_model, Diariser):
self.diariser = dia_model self.diariser = dia_model
else: else:
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning) warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
return None return None
@staticmethod @staticmethod
def remove_audio_file(audio_file : str, def remove_audio_file(audio_file: str,
shred : bool = False) -> None: shred: bool = False) -> None:
""" """
Removes the original audio file to avoid disk space issues or ensure data privacy. Removes the original audio file to avoid disk space issues or ensure data privacy.
@@ -309,16 +316,15 @@ class Scraibe:
for file in gen: for file in gen:
print(f'shredding {file} now\n') print(f'shredding {file} now\n')
run(cmd , check=True) run(cmd, check=True)
else: else:
os.remove(audio_file) os.remove(audio_file)
print(f"Audiofile {audio_file} removed.") print(f"Audiofile {audio_file} removed.")
@staticmethod @staticmethod
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray], def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
*args, **kwargs) -> AudioProcessor: *args, **kwargs) -> AudioProcessor:
"""Gets an audio file as TorchAudioProcessor. """Gets an audio file as TorchAudioProcessor.
Args: Args:
@@ -339,10 +345,10 @@ class Scraibe:
audio_file = AudioProcessor(audio_file[0], audio_file[1]) audio_file = AudioProcessor(audio_file[0], audio_file[1])
elif isinstance(audio_file, ndarray): elif isinstance(audio_file, ndarray):
audio_file = AudioProcessor(torch.Tensor(audio_file[0]), audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
audio_file[1]) audio_file[1])
if not isinstance(audio_file, AudioProcessor): if not isinstance(audio_file, AudioProcessor):
raise ValueError(f'Audiofile must be of type AudioProcessor,' \ raise ValueError(f'Audiofile must be of type AudioProcessor,'
f'not {type(audio_file)}') f'not {type(audio_file)}')
return audio_file return audio_file
+36 -28
View File
@@ -12,7 +12,7 @@ from .autotranscript import Scraibe
from .misc import ParseKwargs from .misc import ParseKwargs
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from torch.cuda import is_available from torch.cuda import is_available
from torch import set_num_threads from torch import set_num_threads
@@ -32,21 +32,22 @@ def cli():
if string in str2val: if string in str2val:
return str2val[string] return str2val[string]
else: else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") raise ValueError(
f"Expected one of {set(str2val.keys())}, got {string}")
parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter) parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
group = parser.add_mutually_exclusive_group() group = parser.add_mutually_exclusive_group()
parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None, parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None,
help="List of audio files to transcribe.") help="List of audio files to transcribe.")
group.add_argument('--start-server', action='store_true', group.add_argument('--start-server', action='store_true',
help='Start the Gradio app.' \ help='Start the Gradio app.'
'If set, all other arguments are ignored' \ 'If set, all other arguments are ignored'
'besides --server-config or --server-kwargs.') 'besides --server-config or --server-kwargs.')
parser.add_argument("--server-config", type=str, default= None, parser.add_argument("--server-config", type=str, default=None,
help="Path to the configy.yml file.") help="Path to the configy.yml file.")
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={}, parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
@@ -55,13 +56,13 @@ def cli():
parser.add_argument("--whisper-model-name", default="medium", parser.add_argument("--whisper-model-name", default="medium",
help="Name of the Whisper model to use.") help="Name of the Whisper model to use.")
parser.add_argument("--whisper-model-directory", type=str, default= None, parser.add_argument("--whisper-model-directory", type=str, default=None,
help="Path to save Whisper model files; defaults to ./models/whisper.") help="Path to save Whisper model files; defaults to ./models/whisper.")
parser.add_argument("--diarization-directory", type=str, default= None, parser.add_argument("--diarization-directory", type=str, default=None,
help="Path to the diarization model directory.") help="Path to the diarization model directory.")
parser.add_argument("--hf-token", default= None, type=str, parser.add_argument("--hf-token", default=None, type=str,
help="HuggingFace token for private model download.") help="HuggingFace token for private model download.")
parser.add_argument("--inference-device", parser.add_argument("--inference-device",
@@ -82,14 +83,15 @@ def cli():
parser.add_argument("--verbose-output", type=str2bool, default=True, parser.add_argument("--verbose-output", type=str2bool, default=True,
help="Enable or disable progress and debug messages.") help="Enable or disable progress and debug messages.")
parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code parser.add_argument("--task", type=str, default='autotranscribe', # unifinished code
choices=["autotranscribe", "diarization", choices=["autotranscribe", "diarization",
"autotranscribe+translate", "translate", 'transcribe'], "autotranscribe+translate", "translate", 'transcribe'],
help="Choose to perform transcription, diarization, or translation. \ help="Choose to perform transcription, diarization, or translation. \
If set to translate, the output will be translated to English.") If set to translate, the output will be translated to English.")
parser.add_argument("--language", type=str, default=None, parser.add_argument("--language", type=str, default=None,
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), choices=sorted(
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
help="Language spoken in the audio. Specify None to perform language detection.") help="Language spoken in the audio. Specify None to perform language detection.")
args = parser.parse_args() args = parser.parse_args()
@@ -110,9 +112,9 @@ def cli():
if args.num_threads > 0: if args.num_threads > 0:
set_num_threads(arg_dict.pop("num_threads")) set_num_threads(arg_dict.pop("num_threads"))
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"), class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
'dia_model': arg_dict.pop("diarization_directory"), 'dia_model': arg_dict.pop("diarization_directory"),
'use_auth_token' : arg_dict.pop("hf_token")} 'use_auth_token': arg_dict.pop("hf_token")}
if arg_dict["whisper_model_directory"]: if arg_dict["whisper_model_directory"]:
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory") class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
@@ -131,15 +133,17 @@ def cli():
else: else:
task = "transcribe" task = "transcribe"
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output")) out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
"language"), verbose=arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
print(f'Saving {basename}.{out_format} to {out_folder}') print(f'Saving {basename}.{out_format} to {out_folder}')
out.save(os.path.join(out_folder, f"{basename}.{out_format}")) out.save(os.path.join(
out_folder, f"{basename}.{out_format}"))
elif task == "diarization": elif task == "diarization":
for audio in audio_files: for audio in audio_files:
if arg_dict.pop("verbose_output"): if arg_dict.pop("verbose_output"):
print(f"Verbose not implemented for diarization.") print("Verbose not implemented for diarization.")
out = model.diarization(audio) out = model.diarization(audio)
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
@@ -148,39 +152,43 @@ def cli():
print(f'Saving {basename}.{out_format} to {out_folder}') print(f'Saving {basename}.{out_format} to {out_folder}')
with open(path, "w") as f: with open(path, "w") as f:
json.dump(json.dumps(out, indent= 1), f) json.dump(json.dumps(out, indent=1), f)
elif task == "transcribe" or task == "translate": elif task == "transcribe" or task == "translate":
for audio in audio_files: for audio in audio_files:
out = model.transcribe(audio, task = task, out = model.transcribe(audio, task=task,
language= arg_dict.pop("language"), language=arg_dict.pop("language"),
verbose = arg_dict.pop("verbose_output")) verbose=arg_dict.pop("verbose_output"))
basename = audio.split("/")[-1].split(".")[0] basename = audio.split("/")[-1].split(".")[0]
path = os.path.join(out_folder, f"{basename}.{out_format}") path = os.path.join(out_folder, f"{basename}.{out_format}")
with open(path, "w") as f: with open(path, "w") as f:
f.write(out) f.write(out)
else: # unfinished code
else: # unfinished code
raise NotImplementedError("Currently not Working") raise NotImplementedError("Currently not Working")
import subprocess import subprocess
import sys import sys
execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py") execute_path = os.path.join(
os.path.dirname(__file__), "app/app_starter.py")
config = arg_dict.pop("server_config") config = arg_dict.pop("server_config")
server_kwargs = arg_dict.pop("server_kwargs") server_kwargs = arg_dict.pop("server_kwargs")
if not config: if not config:
subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"]) subprocess.run([sys.executable, execute_path,
f"--server-kwargs={server_kwargs}"])
elif not server_kwargs: elif not server_kwargs:
subprocess.run([sys.executable, execute_path, f"--server-config={config}"]) subprocess.run([sys.executable, execute_path,
f"--server-config={config}"])
elif not config and not server_kwargs: elif not config and not server_kwargs:
subprocess.run([sys.executable, execute_path]) subprocess.run([sys.executable, execute_path])
else: else:
subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"]) subprocess.run([sys.executable, execute_path,
f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()
+40 -36
View File
@@ -37,7 +37,7 @@ from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor from torch import Tensor
from torch import device as torch_device from torch import device as torch_device
from torch.cuda import is_available, current_device from torch.cuda import is_available
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError from huggingface_hub.utils import RepositoryNotFoundError
@@ -45,7 +45,8 @@ from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
Annotation = TypeVar('Annotation') Annotation = TypeVar('Annotation')
TOKEN_PATH = os.path.join(os.path.dirname( TOKEN_PATH = os.path.join(os.path.dirname(
os.path.realpath(__file__)), '.pyannotetoken') os.path.realpath(__file__)), '.pyannotetoken')
class Diariser: class Diariser:
""" """
@@ -60,7 +61,7 @@ class Diariser:
self.model = model self.model = model
def diarization(self, audiofile : Union[str, Tensor, dict] , def diarization(self, audiofile: Union[str, Tensor, dict],
*args, **kwargs) -> Annotation: *args, **kwargs) -> Annotation:
""" """
Perform speaker diarization on the provided audio file, Perform speaker diarization on the provided audio file,
@@ -80,14 +81,14 @@ class Diariser:
""" """
kwargs = self._get_diarisation_kwargs(**kwargs) kwargs = self._get_diarisation_kwargs(**kwargs)
diarization = self.model(audiofile,*args, **kwargs) diarization = self.model(audiofile, *args, **kwargs)
out = self.format_diarization_output(diarization) out = self.format_diarization_output(diarization)
return out return out
@staticmethod @staticmethod
def format_diarization_output(dia : Annotation) -> dict: def format_diarization_output(dia: Annotation) -> dict:
""" """
Formats the raw diarization output into a more usable structure for this project. Formats the raw diarization output into a more usable structure for this project.
@@ -99,7 +100,7 @@ class Diariser:
as keys and a list of tuples representing segments as values. as keys and a list of tuples representing segments as values.
""" """
dia_list = list(dia.itertracks(yield_label=True)) dia_list = list(dia.itertracks(yield_label=True))
diarization_output = {"speakers": [], "segments": []} diarization_output = {"speakers": [], "segments": []}
normalized_output = [] normalized_output = []
@@ -126,24 +127,23 @@ class Diariser:
index_end_speaker = i - 1 index_end_speaker = i - 1
normalized_output.append([index_start_speaker, normalized_output.append([index_start_speaker,
index_end_speaker, index_end_speaker,
current_speaker]) current_speaker])
index_start_speaker = i index_start_speaker = i
current_speaker = speaker current_speaker = speaker
if i == len(dia_list) - 1: if i == len(dia_list) - 1:
index_end_speaker = i index_end_speaker = i
normalized_output.append([index_start_speaker, normalized_output.append([index_start_speaker,
index_end_speaker, index_end_speaker,
current_speaker]) current_speaker])
for outp in normalized_output: for outp in normalized_output:
start = dia_list[outp[0]][0].start start = dia_list[outp[0]][0].start
end = dia_list[outp[1]][0].end end = dia_list[outp[1]][0].end
diarization_output["segments"].append([start, end]) diarization_output["segments"].append([start, end])
diarization_output["speakers"].append(outp[2]) diarization_output["speakers"].append(outp[2])
@@ -166,9 +166,9 @@ class Diariser:
with open(TOKEN_PATH, 'r', encoding="utf-8") as file: with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
token = file.read() token = file.read()
else: else:
raise ValueError('No token found.' \ raise ValueError('No token found.'
'Please create a token at https://huggingface.co/settings/token' \ 'Please create a token at https://huggingface.co/settings/token'
f'and save it in a file called {TOKEN_PATH}') f'and save it in a file called {TOKEN_PATH}')
return token return token
@staticmethod @staticmethod
@@ -185,15 +185,14 @@ class Diariser:
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG, model: str = PYANNOTE_DEFAULT_CONFIG,
use_auth_token: str = None, use_auth_token: str = None,
cache_token: bool = False, cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None, hparams_file: Union[str, Path] = None,
device: str = None, device: str = None,
*args, **kwargs *args, **kwargs
) -> Pipeline: ) -> Pipeline:
""" """
Loads a pretrained model from pyannote.audio, Loads a pretrained model from pyannote.audio,
either from a local cache or some online repository. either from a local cache or some online repository.
@@ -237,16 +236,18 @@ class Diariser:
'deprecated and will be removed in future versions.', 'deprecated and will be removed in future versions.',
category=DeprecationWarning) category=DeprecationWarning)
# list elementes with the ending .bin # list elementes with the ending .bin
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")] bin_files = [f for f in os.listdir(
pwd) if f.endswith(".bin")]
if len(bin_files) == 1: if len(bin_files) == 1:
path_to_model = os.path.join(pwd, bin_files[0]) path_to_model = os.path.join(pwd, bin_files[0])
else: else:
warnings.warn("Found more than one .bin file. "\ warnings.warn("Found more than one .bin file. "
"or none. Please specify the path to the model " \ "or none. Please specify the path to the model "
"or setup a huggingface token.") "or setup a huggingface token.")
raise FileNotFoundError raise FileNotFoundError
warnings.warn(f"Found model at {path_to_model} overwriting config file.") warnings.warn(
f"Found model at {path_to_model} overwriting config file.")
config['pipeline']['params']['segmentation'] = path_to_model config['pipeline']['params']['segmentation'] = path_to_model
@@ -270,22 +271,24 @@ class Diariser:
if use_auth_token is None: if use_auth_token is None:
use_auth_token = cls._get_token() use_auth_token = cls._get_token()
else: else:
raise FileNotFoundError(f'No local model or directory found at {model}.') raise FileNotFoundError(
f'No local model or directory found at {model}.')
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
cache_dir=cache_dir, cache_dir=cache_dir,
hparams_file=hparams_file,) hparams_file=hparams_file,)
if _model is None: if _model is None:
raise ValueError('Unable to load model either from local cache' \ raise ValueError('Unable to load model either from local cache'
'or from huggingface.co models. Please check your token' \ 'or from huggingface.co models. Please check your token'
'or your local model path') 'or your local model path')
# try to move the model to the device # try to move the model to the device
if device is None: if device is None:
device = "cuda" if is_available() else "cpu" device = "cuda" if is_available() else "cpu"
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict # torch_device is renamed from torch.device to avoid name conflict
_model = _model.to(torch_device(device))
return cls(_model) return cls(_model)
@@ -302,7 +305,8 @@ class Diariser:
""" """
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} diarisation_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
return diarisation_kwargs return diarisation_kwargs
+1 -1
View File
@@ -1,6 +1,6 @@
# List of known hallucinations - adapted from: # List of known hallucinations - adapted from:
# https://github.com/openai/whisper/discussions/928 # https://github.com/openai/whisper/discussions/928
KNOWN_HALLUCINATIONS=[ KNOWN_HALLUCINATIONS = [
# en # en
" www.mooji.org" " www.mooji.org"
# nl # nl
+11 -5
View File
@@ -2,6 +2,7 @@ import os
import yaml import yaml
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
from argparse import Action from argparse import Action
from ast import literal_eval
CACHE_DIR = os.getenv( CACHE_DIR = os.getenv(
"AUTOT_CACHE", "AUTOT_CACHE",
@@ -14,8 +15,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper") WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote") PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \ if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1') else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None: def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file. """Configure diarization pipeline from a YAML file.
@@ -33,25 +35,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
with open(file_path, "r") as stream: with open(file_path, "r") as stream:
yml = yaml.safe_load(stream) yml = yaml.safe_load(stream)
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin") segmentation_path = path_to_segmentation or os.path.join(
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
yml["pipeline"]["params"]["segmentation"] = segmentation_path yml["pipeline"]["params"]["segmentation"] = segmentation_path
if not os.path.exists(segmentation_path): if not os.path.exists(segmentation_path):
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}") raise FileNotFoundError(
f"Segmentation model not found at {segmentation_path}")
with open(file_path, "w") as stream: with open(file_path, "w") as stream:
yaml.dump(yml, stream) yaml.dump(yml, stream)
class ParseKwargs(Action):
    """
    Custom argparse action that collects ``KEY=VALUE`` tokens into a dict.

    The dict is stored on the namespace under ``self.dest``. Every value is
    passed through ``ast.literal_eval`` so that Python literals (numbers,
    booleans, lists, ...) get their native type; anything that is not a
    valid literal is kept as the raw string.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parsed = {}
        for item in values:
            # Split on the first '=' only, so the value itself may contain '='.
            key, sep, raw = item.partition('=')
            if not sep:
                # Report malformed tokens through argparse instead of an
                # unpacking traceback.
                parser.error(f"expected KEY=VALUE, got {item!r}")
            try:
                parsed[key] = literal_eval(raw)
            except (ValueError, TypeError, SyntaxError,
                    MemoryError, RecursionError):
                # Not a Python literal: keep the text as given (best effort).
                parsed[key] = raw
        setattr(namespace, self.dest, parsed)
+279 -36
View File
@@ -24,18 +24,22 @@ Usage:
>>> transcriber.save_transcript(transcript, "path/to/save.txt") >>> transcriber.save_transcript(transcript, "path/to/save.txt")
""" """
from whisper import Whisper, load_model from whisper import Whisper
from typing import TypeVar , Union , Optional from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from typing import TypeVar, Union, Optional
from torch import Tensor, device from torch import Tensor, device
from torch.cuda import is_available as cuda_is_available
from numpy import ndarray from numpy import ndarray
from inspect import signature
from abc import abstractmethod
import warnings
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
class Transcriber: class Transcriber:
""" """
Transcriber Class Transcriber Class
@@ -64,7 +68,8 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options. the load_model method for available options.
""" """
def __init__(self, model: whisper , model_name: str ) -> None:
def __init__(self, model: whisper, model_name: str) -> None:
""" """
Initialize the Transcriber class with a Whisper model. Initialize the Transcriber class with a Whisper model.
@@ -77,7 +82,103 @@ class Transcriber:
self.model_name = model_name self.model_name = model_name
def transcribe(self, audio : Union[str, Tensor, ndarray] , @abstractmethod
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
Args:
audio (Union[str, Tensor, nparray]): The audio file to transcribe.
*args: Additional arguments.
**kwargs: Additional keyword arguments,
such as the language of the audio file.
Returns:
str: The transcript as a string.
"""
pass
@staticmethod
def save_transcript(transcript: str, save_path: str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod
@abstractmethod
def load_model(cls,
               model: str = "medium",
               whisper_type: str = 'whisper',
               download_root: str = WHISPER_DEFAULT_PATH,
               device: Optional[Union[str, device]] = None,
               in_memory: bool = False,
               *args, **kwargs
               ) -> None:
    """
    Interface stub for loading a whisper model.

    This base implementation does nothing and returns None; the
    backend-specific transcriber classes provide the real loader.

    Args:
        model (str): Whisper model. Available models include:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large-v3'
            - 'large'
        whisper_type (str):
            Type of whisper model to load. "whisper" or "whisperx".
        download_root (str, optional): Path to download the model.
            Defaults to WHISPER_DEFAULT_PATH.
        device (Optional[Union[str, torch.device]], optional):
            Device to load model on. Defaults to None.
        in_memory (bool, optional): Whether to load model in memory.
            Defaults to False.
        args: Additional arguments only to avoid errors.
        kwargs: Additional keyword arguments only to avoid errors.

    Returns:
        None: abstract method.
    """
    return None
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
"""
Get kwargs for whisper model. Ensure that kwargs are valid.
Returns:
dict: Keyword arguments for whisper model.
"""
pass
def __repr__(self) -> str:
    """Return a short description showing the model name and the model."""
    name, loaded = self.model_name, self.model
    return f"Transcriber(model_name={name}, model={loaded})"
class WhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str: *args, **kwargs) -> str:
""" """
Transcribe an audio file. Transcribe an audio file.
@@ -100,32 +201,14 @@ class Transcriber:
result = self.model.transcribe(audio, *args, **kwargs) result = self.model.transcribe(audio, *args, **kwargs)
return result["text"] return result["text"]
@staticmethod
def save_transcript(transcript : str , save_path : str) -> None:
"""
Save a transcript to a file.
Args:
transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
"""
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = None,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs *args, **kwargs
) -> 'Transcriber': ) -> 'WhisperTranscriber':
""" """
Load whisper model. Load whisper model.
@@ -158,8 +241,8 @@ class Transcriber:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.
""" """
_model = load_model(model, download_root=download_root, _model = whisper_load_model(model, download_root=download_root,
device=device, in_memory=in_memory) device=device, in_memory=in_memory)
return cls(_model, model_name=model) return cls(_model, model_name=model)
@@ -171,9 +254,11 @@ class Transcriber:
Returns: Returns:
dict: Keyword arguments for whisper model. dict: Keyword arguments for whisper model.
""" """
_possible_kwargs = Whisper.transcribe.__code__.co_varnames # _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")): if (task := kwargs.get("task")):
whisper_kwargs["task"] = task whisper_kwargs["task"] = task
@@ -184,4 +269,162 @@ class Transcriber:
return whisper_kwargs return whisper_kwargs
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Transcriber(model_name={self.model_name}, model={self.model})" return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
class WhisperXTranscriber(Transcriber):
    """Transcriber backend that wraps a model loaded through WhisperX."""

    def __init__(self, model: whisper, model_name: str) -> None:
        super().__init__(model, model_name)

    def transcribe(self, audio: Union[str, Tensor, ndarray],
                   *args, **kwargs) -> str:
        """
        Transcribe an audio file.

        Args:
            audio (Union[str, Tensor, nparray]): The audio file to transcribe.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments,
                        such as the language of the audio file.

        Returns:
            str: The transcript as a string.
        """
        accepted_kwargs = self._get_whisper_kwargs(**kwargs)

        # Torch tensors are handed to the model as CPU numpy arrays.
        audio_input = audio.cpu().numpy() if isinstance(audio, Tensor) else audio

        output = self.model.transcribe(audio_input, *args, **accepted_kwargs)

        # The transcript is the concatenation of all segment texts, in order.
        return "".join(segment['text'] for segment in output['segments'])

    @classmethod
    def load_model(cls,
                   model: str = "medium",
                   download_root: str = WHISPER_DEFAULT_PATH,
                   device: Optional[Union[str, device]] = None,
                   *args, **kwargs
                   ) -> 'WhisperXTranscriber':
        """
        Load a WhisperX model and wrap it in a WhisperXTranscriber.

        Args:
            model (str): Whisper model. Available models include:
                - 'tiny.en'
                - 'tiny'
                - 'base.en'
                - 'base'
                - 'small.en'
                - 'small'
                - 'medium.en'
                - 'medium'
                - 'large-v1'
                - 'large-v2'
                - 'large-v3'
                - 'large'
            download_root (str, optional): Path to download the model.
                Defaults to WHISPER_DEFAULT_PATH.
            device (Optional[Union[str, torch.device]], optional):
                Device to load model on. When None, "cuda" is used if
                available, otherwise "cpu".
            args: Additional arguments only to avoid errors.
            kwargs: Additional keyword arguments; only 'compute_type'
                (default 'float16') is read here.

        Returns:
            WhisperXTranscriber: A transcriber initialized with the
            specified model.
        """
        if device is None:
            device = "cuda" if cuda_is_available() else "cpu"
        # The loader is given the device as a plain string.
        device = device if isinstance(device, str) else str(device)

        compute_type = kwargs.get('compute_type', 'float16')
        needs_fallback = device == 'cpu' and compute_type == 'float16'
        if needs_fallback:
            # float16 is replaced by int8 on CPU, with a warning.
            warnings.warn(f'Compute type {compute_type} not compatible with '
                          f'device {device}! Changing compute type to int8.')
            compute_type = 'int8'

        loaded = whisperx_load_model(model, download_root=download_root,
                                     device=device, compute_type=compute_type)
        return cls(loaded, model_name=model)

    @staticmethod
    def _get_whisper_kwargs(**kwargs) -> dict:
        """
        Get kwargs for whisper model. Ensure that kwargs are valid.

        Only names that appear in the signature of WhisperModel.transcribe
        are kept; truthy 'task' and 'language' values are always passed on.

        Returns:
            dict: Keyword arguments for whisper model.
        """
        accepted = signature(WhisperModel.transcribe).parameters.keys()
        filtered = {name: val for name, val in kwargs.items()
                    if name in accepted}

        for forced in ("task", "language"):
            if (val := kwargs.get(forced)):
                filtered[forced] = val

        return filtered

    def __repr__(self) -> str:
        name, loaded = self.model_name, self.model
        return f"WhisperXTranscriber(model_name={name}, model={loaded})"
def load_transcriber(model: str = "medium",
                     whisper_type: str = 'whisper',
                     download_root: str = WHISPER_DEFAULT_PATH,
                     device: Optional[Union[str, device]] = None,
                     in_memory: bool = False,
                     *args, **kwargs
                     ) -> Union[WhisperTranscriber, WhisperXTranscriber]:
    """
    Load a whisper model wrapped in the matching transcriber class.

    Args:
        model (str): Whisper model. Available models include:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large-v3'
            - 'large'
        whisper_type (str):
            Type of whisper model to load. "whisper" or "whisperx"
            (compared case-insensitively).
        download_root (str, optional): Path to download the model.
            Defaults to WHISPER_DEFAULT_PATH.
        device (Optional[Union[str, torch.device]], optional):
            Device to load model on. Defaults to None.
        in_memory (bool, optional): Whether to load model in memory.
            Defaults to False. Only forwarded for "whisper".
        args: Additional arguments forwarded to the backend loader.
        kwargs: Additional keyword arguments forwarded to the backend loader.

    Returns:
        Union[WhisperTranscriber, WhisperXTranscriber]:
            One of the Whisper variants as Transcriber object initialized
            with the specified model.

    Raises:
        ValueError: If whisper_type is neither "whisper" nor "whisperx".
    """
    # Normalize once instead of lowering the string in every branch.
    backend = whisper_type.lower()

    if backend == 'whisper':
        return WhisperTranscriber.load_model(
            model, download_root, device, in_memory, *args, **kwargs)

    if backend == 'whisperx':
        # WhisperXTranscriber.load_model takes no in_memory parameter,
        # so that argument is intentionally not passed on.
        return WhisperXTranscriber.load_model(
            model, download_root, device, *args, **kwargs)

    raise ValueError('Model type not recognized, expected "whisper" '
                     f'or "whisperx", got {whisper_type}.')
+21 -19
View File
@@ -1,5 +1,6 @@
import json import json
import time import time
from json.decoder import JSONDecodeError
from typing import Union from typing import Union
@@ -8,7 +9,6 @@ from .hallucinations import KNOWN_HALLUCINATIONS
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
class Transcript: class Transcript:
""" """
Class for storing transcript data, including speaker information and text segments, Class for storing transcript data, including speaker information and text segments,
@@ -49,7 +49,8 @@ class Transcript:
annotations = {} annotations = {}
if args and len(args) != len(self.speakers): if args and len(args) != len(self.speakers):
raise ValueError("Number of speaker names does not match number of speakers") raise ValueError(
"Number of speaker names does not match number of speakers")
if args: if args:
for arg, speaker in zip(args, sorted(self.speakers)): for arg, speaker in zip(args, sorted(self.speakers)):
@@ -58,9 +59,11 @@ class Transcript:
invalid_speakers = set(kwargs.keys()) - set(self.speakers) invalid_speakers = set(kwargs.keys()) - set(self.speakers)
if invalid_speakers: if invalid_speakers:
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}") raise ValueError(
f"These keys are not speakers: {', '.join(invalid_speakers)}")
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) annotations.update({key: kwargs[key]
for key in self.speakers if key in kwargs})
self.annotation = annotations self.annotation = annotations
@@ -71,11 +74,13 @@ class Transcript:
Removes all occurances of known hallucinations from all segments of the transcript. Removes all occurances of known hallucinations from all segments of the transcript.
Segments that are identical to empty strings afterwards are removed from the transcript. Segments that are identical to empty strings afterwards are removed from the transcript.
""" """
segments_to_drop=[] segments_to_drop = []
for id in self.transcript: for id in self.transcript:
for snippet in KNOWN_HALLUCINATIONS: for snippet in KNOWN_HALLUCINATIONS:
self.transcript[id]['text']=self.transcript[id]['text'].replace(snippet,'') self.transcript[id]['text'] = self.transcript[id]['text'].replace(
if self.transcript[id]['text'] == '': segments_to_drop.append(id) snippet, '')
if self.transcript[id]['text'] == '':
segments_to_drop.append(id)
for id in segments_to_drop: for id in segments_to_drop:
del self.transcript[id] del self.transcript[id]
@@ -119,8 +124,8 @@ class Transcript:
speaker = seq["speakers"] speaker = seq["speakers"]
segm = seq["segments"] segm = seq["segments"]
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0])) sseg = time.strftime("%H:%M:%S", time.gmtime(segm[0]))
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) eseg = time.strftime("%H:%M:%S", time.gmtime(segm[1]))
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n" fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
@@ -133,7 +138,7 @@ class Transcript:
str: A string that provides an informative description of the object. str: A string that provides an informative description of the object.
""" """
return f"Transcript(speakers = {self.speakers},"\ return f"Transcript(speakers = {self.speakers},"\
f"segments = {self.segments}, annotation = {self.annotation})" f"segments = {self.segments}, annotation = {self.annotation})"
def get_dict(self) -> dict: def get_dict(self) -> dict:
""" """
@@ -145,7 +150,7 @@ class Transcript:
return self.transcript return self.transcript
def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str: def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str:
""" """
Get transcript as json string Get transcript as json string
:return: transcript as json string :return: transcript as json string
@@ -193,12 +198,12 @@ class Transcript:
self.annotate(*ALPHABET[:len(self.speakers)]) self.annotate(*ALPHABET[:len(self.speakers)])
fstring ="\\begin{drama}" fstring = "\\begin{drama}"
for speaker in self.speakers: for speaker in self.speakers:
fstring += "\n\t\\Character{"+ str(self.annotation[speaker]) + "}" \ fstring += "\n\t\\Character{" + str(self.annotation[speaker]) + "}" \
"{"+ str(self.annotation[speaker]) + "}" "{" + str(self.annotation[speaker]) + "}"
for id in self.transcript: for id in self.transcript:
seq = self.transcript[id] seq = self.transcript[id]
@@ -209,8 +214,7 @@ class Transcript:
return fstring return fstring
def to_json(self, path, *args, **kwargs) -> None:
def to_json(self,path, *args, **kwargs) -> None:
"""Save transcript as json file """Save transcript as json file
Args: Args:
@@ -310,10 +314,8 @@ class Transcript:
else: else:
try: try:
transcript = json.loads(json) transcript = json.loads(json)
except: except (TypeError, JSONDecodeError):
with open(json, "r") as f: with open(json, "r") as f:
transcript = json.load(f) transcript = json.load(f)
return cls(transcript) return cls(transcript)
+10 -10
View File
@@ -31,16 +31,16 @@ release = '0.1.1'
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = ['sphinx.ext.autodoc', extensions = ['sphinx.ext.autodoc',
'sphinx.ext.doctest', 'sphinx.ext.doctest',
'sphinx.ext.intersphinx', 'sphinx.ext.intersphinx',
'sphinx.ext.todo', 'sphinx.ext.todo',
'sphinx.ext.coverage', 'sphinx.ext.coverage',
'sphinx.ext.mathjax', 'sphinx.ext.mathjax',
'sphinx.ext.ifconfig', 'sphinx.ext.ifconfig',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
'sphinx.ext.githubpages', 'sphinx.ext.githubpages',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'myst_parser'] 'myst_parser']
# Napoleon settings # Napoleon settings
napoleon_google_docstring = True napoleon_google_docstring = True
+2 -33
View File
@@ -3,7 +3,6 @@ from scraibe.audio import AudioProcessor
import torch import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE) TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
TEST_SR = 16000 TEST_SR = 16000
@@ -25,10 +24,6 @@ def probe_audio_processor():
return AudioProcessor(TEST_WAVEFORM, TEST_SR) return AudioProcessor(TEST_WAVEFORM, TEST_SR)
def test_AudioProcessor_init(probe_audio_processor): def test_AudioProcessor_init(probe_audio_processor):
""" """
Test the initialization of the AudioProcessor class. Test the initialization of the AudioProcessor class.
@@ -53,7 +48,6 @@ def test_AudioProcessor_init(probe_audio_processor):
assert probe_audio_processor.sr == TEST_SR assert probe_audio_processor.sr == TEST_SR
def test_cut(probe_audio_processor): def test_cut(probe_audio_processor):
"""Test the cut function of the AudioProcessor class. """Test the cut function of the AudioProcessor class.
@@ -73,15 +67,7 @@ def test_cut(probe_audio_processor):
expected_size = int((end - start) * TEST_SR) expected_size = int((end - start) * TEST_SR)
real_size = trimmed_waveform.size(0) real_size = trimmed_waveform.size(0)
assert real_size == expected_size assert real_size == expected_size
#assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR) # assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
def test_audio_processor_invalid_sr(): def test_audio_processor_invalid_sr():
@@ -94,7 +80,7 @@ def test_audio_processor_invalid_sr():
None None
""" """
with pytest.raises(ValueError): with pytest.raises(ValueError):
AudioProcessor(TEST_WAVEFORM, [44100,48000]) AudioProcessor(TEST_WAVEFORM, [44100, 48000])
def test_audio_processor_SAMPLE_RATE(): def test_audio_processor_SAMPLE_RATE():
@@ -108,20 +94,3 @@ def test_audio_processor_SAMPLE_RATE():
""" """
probe_audio_processor = AudioProcessor(TEST_WAVEFORM) probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
assert probe_audio_processor.sr == SAMPLE_RATE assert probe_audio_processor.sr == SAMPLE_RATE
+1 -7
View File
@@ -1,22 +1,16 @@
import pytest import pytest
from scraibe import Scraibe, Diariser, Transcriber, Transcript from scraibe import Scraibe, Diariser, Transcriber, Transcript
from unittest.mock import MagicMock, patch
import os import os
@pytest.fixture @pytest.fixture
def create_scraibe_instance(): def create_scraibe_instance():
if "HF_TOKEN" in os.environ: if "HF_TOKEN" in os.environ:
return Scraibe(use_auth_token=os.environ["HF_TOKEN"] ) return Scraibe(use_auth_token=os.environ["HF_TOKEN"])
else: else:
return Scraibe() return Scraibe()
def test_scraibe_init(create_scraibe_instance): def test_scraibe_init(create_scraibe_instance):
model = create_scraibe_instance model = create_scraibe_instance
assert isinstance(model.transcriber, Transcriber) assert isinstance(model.transcriber, Transcriber)
+2 -17
View File
@@ -1,8 +1,5 @@
import pytest import pytest
import os from scraibe import Diariser
from unittest import mock
from scraibe import diarisation, Diariser
@pytest.fixture @pytest.fixture
@@ -15,11 +12,10 @@ def diariser_instance():
Returns: Returns:
Diariser(Obj): An instance of the Diariser class with a mocked token. Diariser(Obj): An instance of the Diariser class with a mocked token.
""" """
#with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ): # with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
return Diariser('pyannote') return Diariser('pyannote')
def test_Diariser_init(diariser_instance): def test_Diariser_init(diariser_instance):
"""Test the initialization of the Diariser class. """Test the initialization of the Diariser class.
@@ -34,14 +30,3 @@ def test_Diariser_init(diariser_instance):
None None
""" """
assert diariser_instance.model == 'pyannote' assert diariser_instance.model == 'pyannote'
+40 -12
View File
@@ -1,10 +1,9 @@
import pytest import pytest
from unittest.mock import patch from scraibe import (Transcriber, WhisperTranscriber,
from scraibe import Transcriber WhisperXTranscriber, load_transcriber)
import torch import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEST_WAVEFORM = "Hello World" TEST_WAVEFORM = "Hello World"
@@ -29,12 +28,37 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
assert transcription_result == expected_transcription """ assert transcription_result == expected_transcription """
@pytest.fixture
def transcriber_instance():
return Transcriber.load_model('medium')
def test_transcriber_initialization(transcriber_instance): @pytest.fixture
assert isinstance(transcriber_instance, Transcriber) def whisper_instance():
return load_transcriber('medium', whisper_type='whisper')
@pytest.fixture
def whisperx_instance():
    """Provide a 'medium' transcriber loaded with the whisperx backend."""
    transcriber = load_transcriber('medium', whisper_type='whisperx')
    return transcriber
def test_whisper_base_initialization(whisper_instance):
    """The whisper fixture must be an instance of the Transcriber base."""
    is_transcriber = isinstance(whisper_instance, Transcriber)
    assert is_transcriber
def test_whisperx_base_initialization(whisperx_instance):
    """The whisperx fixture must be an instance of the Transcriber base."""
    is_transcriber = isinstance(whisperx_instance, Transcriber)
    assert is_transcriber
def test_whisper_transcriber_initialization(whisper_instance):
    """The whisper fixture must be a WhisperTranscriber."""
    is_whisper = isinstance(whisper_instance, WhisperTranscriber)
    assert is_whisper
def test_whisperx_transcriber_initialization(whisperx_instance):
    """The whisperx fixture must be a WhisperXTranscriber."""
    is_whisperx = isinstance(whisperx_instance, WhisperXTranscriber)
    assert is_whisperx
def test_wrong_transcriber_initialization():
    """An unknown whisper_type must be rejected with a ValueError."""
    unknown_backend = 'wrong_whisper'
    with pytest.raises(ValueError):
        load_transcriber('medium', whisper_type=unknown_backend)
def test_get_whisper_kwargs(): def test_get_whisper_kwargs():
kwargs = {"arg1": 1, "arg3": 3} kwargs = {"arg1": 1, "arg3": 3}
@@ -42,11 +66,15 @@ def test_get_whisper_kwargs():
assert not valid_kwargs == {"arg1": 1, "arg3": 3} assert not valid_kwargs == {"arg1": 1, "arg3": 3}
def test_transcribe(transcriber_instance): def test_whisper_transcribe(whisper_instance):
model = transcriber_instance model = whisper_instance
#mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} ) # mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
transcript = model.transcribe('test/audio_test_2.mp4') transcript = model.transcribe('test/audio_test_2.mp4')
assert isinstance(transcript, str) assert isinstance(transcript, str)
def test_whisperx_transcribe(whisperx_instance):
    """Transcribing the sample file with the whisperx backend yields a str."""
    result = whisperx_instance.transcribe('test/audio_test_2.mp4')
    assert isinstance(result, str)