Auto fixes from PEP8, fixes from flake8.
This commit is contained in:
+8
-6
@@ -28,6 +28,7 @@ import torch
|
|||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
NORMALIZATION_FACTOR = 32768.0
|
NORMALIZATION_FACTOR = 32768.0
|
||||||
|
|
||||||
|
|
||||||
class AudioProcessor:
|
class AudioProcessor:
|
||||||
"""
|
"""
|
||||||
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
||||||
@@ -42,7 +43,6 @@ class AudioProcessor:
|
|||||||
|
|
||||||
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
||||||
*args, **kwargs) -> None:
|
*args, **kwargs) -> None:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Initialize the AudioProcessor object.
|
Initialize the AudioProcessor object.
|
||||||
|
|
||||||
@@ -57,13 +57,14 @@ class AudioProcessor:
|
|||||||
ValueError: If the provided sample rate is not of type int.
|
ValueError: If the provided sample rate is not of type int.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
device = kwargs.get(
|
||||||
|
"device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
self.waveform = waveform.to(device)
|
self.waveform = waveform.to(device)
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
|
|
||||||
if not isinstance(self.sr, int):
|
if not isinstance(self.sr, int):
|
||||||
raise ValueError("Sample rate should be a single value of type int," \
|
raise ValueError("Sample rate should be a single value of type int,"
|
||||||
f"not {len(self.sr)} and type {type(self.sr)}")
|
f"not {len(self.sr)} and type {type(self.sr)}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -84,7 +85,6 @@ class AudioProcessor:
|
|||||||
|
|
||||||
return cls(audio, sr)
|
return cls(audio, sr)
|
||||||
|
|
||||||
|
|
||||||
def cut(self, start: float, end: float) -> torch.Tensor:
|
def cut(self, start: float, end: float) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Cut a segment from the audio waveform between the specified start and end times.
|
Cut a segment from the audio waveform between the specified start and end times.
|
||||||
@@ -140,9 +140,11 @@ class AudioProcessor:
|
|||||||
try:
|
try:
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
except CalledProcessError as e:
|
except CalledProcessError as e:
|
||||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
raise RuntimeError(
|
||||||
|
f"Failed to load audio: {e.stderr.decode()}") from e
|
||||||
|
|
||||||
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
|
out = np.frombuffer(out, np.int16).flatten().astype(
|
||||||
|
np.float32) / NORMALIZATION_FACTOR
|
||||||
|
|
||||||
return out, sr
|
return out, sr
|
||||||
|
|
||||||
|
|||||||
+13
-10
@@ -62,6 +62,7 @@ class Scraibe:
|
|||||||
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
|
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
|
||||||
get_audio_file: Gets an audio file as an AudioProcessor object.
|
get_audio_file: Gets an audio file as an AudioProcessor object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
whisper_model: Union[bool, str, whisper] = None,
|
whisper_model: Union[bool, str, whisper] = None,
|
||||||
whisper_type: str = "whisper",
|
whisper_type: str = "whisper",
|
||||||
@@ -85,11 +86,12 @@ class Scraibe:
|
|||||||
for autotranscribe. So you can unload the class and reload it again.
|
for autotranscribe. So you can unload the class and reload it again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
self.transcriber = Transcriber.load_model("medium", whisper_type, **kwargs)
|
self.transcriber = Transcriber.load_model(
|
||||||
|
"medium", whisper_type, **kwargs)
|
||||||
elif isinstance(whisper_model, str):
|
elif isinstance(whisper_model, str):
|
||||||
self.transcriber = Transcriber.load_model(whisper_model, whisper_type, **kwargs)
|
self.transcriber = Transcriber.load_model(
|
||||||
|
whisper_model, whisper_type, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.transcriber = whisper_model
|
self.transcriber = whisper_model
|
||||||
|
|
||||||
@@ -114,7 +116,6 @@ class Scraibe:
|
|||||||
else:
|
else:
|
||||||
self.params = {}
|
self.params = {}
|
||||||
|
|
||||||
|
|
||||||
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
remove_original: bool = False,
|
remove_original: bool = False,
|
||||||
**kwargs) -> Transcript:
|
**kwargs) -> Transcript:
|
||||||
@@ -152,7 +153,8 @@ class Scraibe:
|
|||||||
if not diarisation["segments"]:
|
if not diarisation["segments"]:
|
||||||
print("No segments found. Try to run transcription without diarisation.")
|
print("No segments found. Try to run transcription without diarisation.")
|
||||||
|
|
||||||
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
transcript = self.transcriber.transcribe(
|
||||||
|
audio_file.waveform, **kwargs)
|
||||||
|
|
||||||
final_transcript = {0: {"speakers": 'SPEAKER_01',
|
final_transcript = {0: {"speakers": 'SPEAKER_01',
|
||||||
"segments": [0, len(audio_file.waveform)],
|
"segments": [0, len(audio_file.waveform)],
|
||||||
@@ -163,7 +165,8 @@ class Scraibe:
|
|||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Diarisation finished. Starting transcription.")
|
print("Diarisation finished. Starting transcription.")
|
||||||
|
|
||||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
audio_file.sr = torch.Tensor([audio_file.sr]).to(
|
||||||
|
audio_file.waveform.device)
|
||||||
|
|
||||||
# Transcribe each segment and store the results
|
# Transcribe each segment and store the results
|
||||||
final_transcript = dict()
|
final_transcript = dict()
|
||||||
@@ -259,7 +262,8 @@ class Scraibe:
|
|||||||
elif isinstance(whisper_model, Transcriber):
|
elif isinstance(whisper_model, Transcriber):
|
||||||
self.transcriber = whisper_model
|
self.transcriber = whisper_model
|
||||||
else:
|
else:
|
||||||
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
warn(
|
||||||
|
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -281,7 +285,7 @@ class Scraibe:
|
|||||||
elif isinstance(dia_model, Diariser):
|
elif isinstance(dia_model, Diariser):
|
||||||
self.diariser = dia_model
|
self.diariser = dia_model
|
||||||
else:
|
else:
|
||||||
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -318,7 +322,6 @@ class Scraibe:
|
|||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
print(f"Audiofile {audio_file} removed.")
|
print(f"Audiofile {audio_file} removed.")
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
|
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
*args, **kwargs) -> AudioProcessor:
|
*args, **kwargs) -> AudioProcessor:
|
||||||
@@ -345,7 +348,7 @@ class Scraibe:
|
|||||||
audio_file[1])
|
audio_file[1])
|
||||||
|
|
||||||
if not isinstance(audio_file, AudioProcessor):
|
if not isinstance(audio_file, AudioProcessor):
|
||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
raise ValueError(f'Audiofile must be of type AudioProcessor,'
|
||||||
f'not {type(audio_file)}')
|
f'not {type(audio_file)}')
|
||||||
|
|
||||||
return audio_file
|
return audio_file
|
||||||
|
|||||||
+20
-12
@@ -32,7 +32,8 @@ def cli():
|
|||||||
if string in str2val:
|
if string in str2val:
|
||||||
return str2val[string]
|
return str2val[string]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
raise ValueError(
|
||||||
|
f"Expected one of {set(str2val.keys())}, got {string}")
|
||||||
|
|
||||||
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
@@ -42,8 +43,8 @@ def cli():
|
|||||||
help="List of audio files to transcribe.")
|
help="List of audio files to transcribe.")
|
||||||
|
|
||||||
group.add_argument('--start-server', action='store_true',
|
group.add_argument('--start-server', action='store_true',
|
||||||
help='Start the Gradio app.' \
|
help='Start the Gradio app.'
|
||||||
'If set, all other arguments are ignored' \
|
'If set, all other arguments are ignored'
|
||||||
'besides --server-config or --server-kwargs.')
|
'besides --server-config or --server-kwargs.')
|
||||||
|
|
||||||
parser.add_argument("--server-config", type=str, default=None,
|
parser.add_argument("--server-config", type=str, default=None,
|
||||||
@@ -89,7 +90,8 @@ def cli():
|
|||||||
If set to translate, the output will be translated to English.")
|
If set to translate, the output will be translated to English.")
|
||||||
|
|
||||||
parser.add_argument("--language", type=str, default=None,
|
parser.add_argument("--language", type=str, default=None,
|
||||||
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
choices=sorted(
|
||||||
|
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||||
help="Language spoken in the audio. Specify None to perform language detection.")
|
help="Language spoken in the audio. Specify None to perform language detection.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -131,15 +133,17 @@ def cli():
|
|||||||
else:
|
else:
|
||||||
task = "transcribe"
|
task = "transcribe"
|
||||||
|
|
||||||
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
|
||||||
|
"language"), verbose=arg_dict.pop("verbose_output"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
out.save(os.path.join(
|
||||||
|
out_folder, f"{basename}.{out_format}"))
|
||||||
|
|
||||||
elif task == "diarization":
|
elif task == "diarization":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
if arg_dict.pop("verbose_output"):
|
if arg_dict.pop("verbose_output"):
|
||||||
print(f"Verbose not implemented for diarization.")
|
print("Verbose not implemented for diarization.")
|
||||||
|
|
||||||
out = model.diarization(audio)
|
out = model.diarization(audio)
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
@@ -162,25 +166,29 @@ def cli():
|
|||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(out)
|
f.write(out)
|
||||||
|
|
||||||
|
|
||||||
else: # unfinished code
|
else: # unfinished code
|
||||||
raise NotImplementedError("Currently not Working")
|
raise NotImplementedError("Currently not Working")
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py")
|
execute_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "app/app_starter.py")
|
||||||
|
|
||||||
config = arg_dict.pop("server_config")
|
config = arg_dict.pop("server_config")
|
||||||
server_kwargs = arg_dict.pop("server_kwargs")
|
server_kwargs = arg_dict.pop("server_kwargs")
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"])
|
subprocess.run([sys.executable, execute_path,
|
||||||
|
f"--server-kwargs={server_kwargs}"])
|
||||||
elif not server_kwargs:
|
elif not server_kwargs:
|
||||||
subprocess.run([sys.executable, execute_path, f"--server-config={config}"])
|
subprocess.run([sys.executable, execute_path,
|
||||||
|
f"--server-config={config}"])
|
||||||
elif not config and not server_kwargs:
|
elif not config and not server_kwargs:
|
||||||
subprocess.run([sys.executable, execute_path])
|
subprocess.run([sys.executable, execute_path])
|
||||||
else:
|
else:
|
||||||
subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
|
subprocess.run([sys.executable, execute_path,
|
||||||
|
f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
+18
-14
@@ -37,7 +37,7 @@ from pyannote.audio import Pipeline
|
|||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch import device as torch_device
|
from torch import device as torch_device
|
||||||
from torch.cuda import is_available, current_device
|
from torch.cuda import is_available
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
@@ -47,6 +47,7 @@ Annotation = TypeVar('Annotation')
|
|||||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
os.path.realpath(__file__)), '.pyannotetoken')
|
os.path.realpath(__file__)), '.pyannotetoken')
|
||||||
|
|
||||||
|
|
||||||
class Diariser:
|
class Diariser:
|
||||||
"""
|
"""
|
||||||
Handles the diarization process of an audio file using a pretrained model
|
Handles the diarization process of an audio file using a pretrained model
|
||||||
@@ -132,7 +133,6 @@ class Diariser:
|
|||||||
index_start_speaker = i
|
index_start_speaker = i
|
||||||
current_speaker = speaker
|
current_speaker = speaker
|
||||||
|
|
||||||
|
|
||||||
if i == len(dia_list) - 1:
|
if i == len(dia_list) - 1:
|
||||||
|
|
||||||
index_end_speaker = i
|
index_end_speaker = i
|
||||||
@@ -166,8 +166,8 @@ class Diariser:
|
|||||||
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
||||||
token = file.read()
|
token = file.read()
|
||||||
else:
|
else:
|
||||||
raise ValueError('No token found.' \
|
raise ValueError('No token found.'
|
||||||
'Please create a token at https://huggingface.co/settings/token' \
|
'Please create a token at https://huggingface.co/settings/token'
|
||||||
f'and save it in a file called {TOKEN_PATH}')
|
f'and save it in a file called {TOKEN_PATH}')
|
||||||
return token
|
return token
|
||||||
|
|
||||||
@@ -193,7 +193,6 @@ class Diariser:
|
|||||||
device: str = None,
|
device: str = None,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Loads a pretrained model from pyannote.audio,
|
Loads a pretrained model from pyannote.audio,
|
||||||
either from a local cache or some online repository.
|
either from a local cache or some online repository.
|
||||||
@@ -237,16 +236,18 @@ class Diariser:
|
|||||||
'deprecated and will be removed in future versions.',
|
'deprecated and will be removed in future versions.',
|
||||||
category=DeprecationWarning)
|
category=DeprecationWarning)
|
||||||
# list elementes with the ending .bin
|
# list elementes with the ending .bin
|
||||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
bin_files = [f for f in os.listdir(
|
||||||
|
pwd) if f.endswith(".bin")]
|
||||||
if len(bin_files) == 1:
|
if len(bin_files) == 1:
|
||||||
path_to_model = os.path.join(pwd, bin_files[0])
|
path_to_model = os.path.join(pwd, bin_files[0])
|
||||||
else:
|
else:
|
||||||
warnings.warn("Found more than one .bin file. "\
|
warnings.warn("Found more than one .bin file. "
|
||||||
"or none. Please specify the path to the model " \
|
"or none. Please specify the path to the model "
|
||||||
"or setup a huggingface token.")
|
"or setup a huggingface token.")
|
||||||
raise FileNotFoundError
|
raise FileNotFoundError
|
||||||
|
|
||||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
warnings.warn(
|
||||||
|
f"Found model at {path_to_model} overwriting config file.")
|
||||||
|
|
||||||
config['pipeline']['params']['segmentation'] = path_to_model
|
config['pipeline']['params']['segmentation'] = path_to_model
|
||||||
|
|
||||||
@@ -270,22 +271,24 @@ class Diariser:
|
|||||||
if use_auth_token is None:
|
if use_auth_token is None:
|
||||||
use_auth_token = cls._get_token()
|
use_auth_token = cls._get_token()
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
raise FileNotFoundError(
|
||||||
|
f'No local model or directory found at {model}.')
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
hparams_file=hparams_file,)
|
hparams_file=hparams_file,)
|
||||||
if _model is None:
|
if _model is None:
|
||||||
raise ValueError('Unable to load model either from local cache' \
|
raise ValueError('Unable to load model either from local cache'
|
||||||
'or from huggingface.co models. Please check your token' \
|
'or from huggingface.co models. Please check your token'
|
||||||
'or your local model path')
|
'or your local model path')
|
||||||
|
|
||||||
# try to move the model to the device
|
# try to move the model to the device
|
||||||
if device is None:
|
if device is None:
|
||||||
device = "cuda" if is_available() else "cpu"
|
device = "cuda" if is_available() else "cpu"
|
||||||
|
|
||||||
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
# torch_device is renamed from torch.device to avoid name conflict
|
||||||
|
_model = _model.to(torch_device(device))
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
@@ -302,7 +305,8 @@ class Diariser:
|
|||||||
"""
|
"""
|
||||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||||
|
|
||||||
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
diarisation_kwargs = {k: v for k,
|
||||||
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|
||||||
return diarisation_kwargs
|
return diarisation_kwargs
|
||||||
|
|
||||||
|
|||||||
+8
-3
@@ -17,6 +17,7 @@ PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
|||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
||||||
|
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
"""Configure diarization pipeline from a YAML file.
|
"""Configure diarization pipeline from a YAML file.
|
||||||
|
|
||||||
@@ -33,25 +34,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
|
|||||||
with open(file_path, "r") as stream:
|
with open(file_path, "r") as stream:
|
||||||
yml = yaml.safe_load(stream)
|
yml = yaml.safe_load(stream)
|
||||||
|
|
||||||
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
segmentation_path = path_to_segmentation or os.path.join(
|
||||||
|
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
||||||
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
||||||
|
|
||||||
if not os.path.exists(segmentation_path):
|
if not os.path.exists(segmentation_path):
|
||||||
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
|
raise FileNotFoundError(
|
||||||
|
f"Segmentation model not found at {segmentation_path}")
|
||||||
|
|
||||||
with open(file_path, "w") as stream:
|
with open(file_path, "w") as stream:
|
||||||
yaml.dump(yml, stream)
|
yaml.dump(yml, stream)
|
||||||
|
|
||||||
|
|
||||||
class ParseKwargs(Action):
|
class ParseKwargs(Action):
|
||||||
"""
|
"""
|
||||||
Custom argparse action to parse keyword arguments.
|
Custom argparse action to parse keyword arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
setattr(namespace, self.dest, dict())
|
setattr(namespace, self.dest, dict())
|
||||||
for value in values:
|
for value in values:
|
||||||
key, value = value.split('=')
|
key, value = value.split('=')
|
||||||
try:
|
try:
|
||||||
value = eval(value)
|
value = ast.literal_eval(value)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
getattr(namespace, self.dest)[key] = value
|
getattr(namespace, self.dest)[key] = value
|
||||||
@@ -32,7 +32,7 @@ from typing import TypeVar , Union , Optional
|
|||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
from inspect import getfullargspec
|
from inspect import getfullargspec
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
from .misc import WHISPER_DEFAULT_PATH
|
from .misc import WHISPER_DEFAULT_PATH
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
@@ -66,6 +66,7 @@ class Transcriber:
|
|||||||
The class supports various sizes and versions of Whisper models. Please refer to
|
The class supports various sizes and versions of Whisper models. Please refer to
|
||||||
the load_model method for available options.
|
the load_model method for available options.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: whisper, model_name: str) -> None:
|
def __init__(self, model: whisper, model_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the Transcriber class with a Whisper model.
|
Initialize the Transcriber class with a Whisper model.
|
||||||
@@ -266,7 +267,8 @@ class WhisperTranscriber(Transcriber):
|
|||||||
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
|
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
|
||||||
_possible_kwargs = _args + _kwargs
|
_possible_kwargs = _args + _kwargs
|
||||||
|
|
||||||
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
whisper_kwargs = {k: v for k,
|
||||||
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|
||||||
if (task := kwargs.get("task")):
|
if (task := kwargs.get("task")):
|
||||||
whisper_kwargs["task"] = task
|
whisper_kwargs["task"] = task
|
||||||
@@ -305,7 +307,6 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
text += seg['text']
|
text += seg['text']
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
@@ -364,7 +365,8 @@ class WhisperXTranscriber(Transcriber):
|
|||||||
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
|
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
|
||||||
_possible_kwargs = _args + _kwargs
|
_possible_kwargs = _args + _kwargs
|
||||||
|
|
||||||
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
whisper_kwargs = {k: v for k,
|
||||||
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|
||||||
if (task := kwargs.get("task")):
|
if (task := kwargs.get("task")):
|
||||||
whisper_kwargs["task"] = task
|
whisper_kwargs["task"] = task
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from json.decoder import JSONDecodeError
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -8,7 +9,6 @@ from .hallucinations import KNOWN_HALLUCINATIONS
|
|||||||
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
|
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Transcript:
|
class Transcript:
|
||||||
"""
|
"""
|
||||||
Class for storing transcript data, including speaker information and text segments,
|
Class for storing transcript data, including speaker information and text segments,
|
||||||
@@ -49,7 +49,8 @@ class Transcript:
|
|||||||
|
|
||||||
annotations = {}
|
annotations = {}
|
||||||
if args and len(args) != len(self.speakers):
|
if args and len(args) != len(self.speakers):
|
||||||
raise ValueError("Number of speaker names does not match number of speakers")
|
raise ValueError(
|
||||||
|
"Number of speaker names does not match number of speakers")
|
||||||
|
|
||||||
if args:
|
if args:
|
||||||
for arg, speaker in zip(args, sorted(self.speakers)):
|
for arg, speaker in zip(args, sorted(self.speakers)):
|
||||||
@@ -58,9 +59,11 @@ class Transcript:
|
|||||||
|
|
||||||
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
|
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
|
||||||
if invalid_speakers:
|
if invalid_speakers:
|
||||||
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}")
|
raise ValueError(
|
||||||
|
f"These keys are not speakers: {', '.join(invalid_speakers)}")
|
||||||
|
|
||||||
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
|
annotations.update({key: kwargs[key]
|
||||||
|
for key in self.speakers if key in kwargs})
|
||||||
|
|
||||||
self.annotation = annotations
|
self.annotation = annotations
|
||||||
|
|
||||||
@@ -74,8 +77,10 @@ class Transcript:
|
|||||||
segments_to_drop = []
|
segments_to_drop = []
|
||||||
for id in self.transcript:
|
for id in self.transcript:
|
||||||
for snippet in KNOWN_HALLUCINATIONS:
|
for snippet in KNOWN_HALLUCINATIONS:
|
||||||
self.transcript[id]['text']=self.transcript[id]['text'].replace(snippet,'')
|
self.transcript[id]['text'] = self.transcript[id]['text'].replace(
|
||||||
if self.transcript[id]['text'] == '': segments_to_drop.append(id)
|
snippet, '')
|
||||||
|
if self.transcript[id]['text'] == '':
|
||||||
|
segments_to_drop.append(id)
|
||||||
|
|
||||||
for id in segments_to_drop:
|
for id in segments_to_drop:
|
||||||
del self.transcript[id]
|
del self.transcript[id]
|
||||||
@@ -209,7 +214,6 @@ class Transcript:
|
|||||||
|
|
||||||
return fstring
|
return fstring
|
||||||
|
|
||||||
|
|
||||||
def to_json(self, path, *args, **kwargs) -> None:
|
def to_json(self, path, *args, **kwargs) -> None:
|
||||||
"""Save transcript as json file
|
"""Save transcript as json file
|
||||||
|
|
||||||
@@ -310,10 +314,8 @@ class Transcript:
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
transcript = json.loads(json)
|
transcript = json.loads(json)
|
||||||
except:
|
except (TypeError, JSONDecodeError):
|
||||||
with open(json, "r") as f:
|
with open(json, "r") as f:
|
||||||
transcript = json.load(f)
|
transcript = json.load(f)
|
||||||
|
|
||||||
return cls(transcript)
|
return cls(transcript)
|
||||||
|
|
||||||
|
|
||||||
+6
-4
@@ -10,6 +10,8 @@ VERSION = '%d.%d.%d.%d' % (MAJOR, MINOR, MICRO, NANO)
|
|||||||
|
|
||||||
# Return the git revision as a string
|
# Return the git revision as a string
|
||||||
# taken from numpy/numpy
|
# taken from numpy/numpy
|
||||||
|
|
||||||
|
|
||||||
def git_version():
|
def git_version():
|
||||||
def _minimal_ext_cmd(cmd):
|
def _minimal_ext_cmd(cmd):
|
||||||
# construct minimal environment
|
# construct minimal environment
|
||||||
@@ -24,7 +26,8 @@ def git_version():
|
|||||||
env['LANG'] = 'C'
|
env['LANG'] = 'C'
|
||||||
env['LC_ALL'] = 'C'
|
env['LC_ALL'] = 'C'
|
||||||
|
|
||||||
out = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE, env=env).communicate()[0]
|
out = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE,
|
||||||
|
env=env).communicate()[0]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -35,6 +38,7 @@ def git_version():
|
|||||||
|
|
||||||
return GIT_REVISION
|
return GIT_REVISION
|
||||||
|
|
||||||
|
|
||||||
def _get_git_version():
|
def _get_git_version():
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
|
|
||||||
@@ -51,6 +55,7 @@ def _get_git_version():
|
|||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_version(build_version=False):
|
def get_version(build_version=False):
|
||||||
if ISRELEASED:
|
if ISRELEASED:
|
||||||
return VERSION
|
return VERSION
|
||||||
@@ -64,6 +69,3 @@ def get_version(build_version=False):
|
|||||||
return VERSION + ".dev" + date
|
return VERSION + ".dev" + date
|
||||||
else:
|
else:
|
||||||
return VERSION + ".dev0+" + GIT_REVISION[:7]
|
return VERSION + ".dev0+" + GIT_REVISION[:7]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from scraibe.audio import AudioProcessor
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
|
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
|
||||||
TEST_SR = 16000
|
TEST_SR = 16000
|
||||||
@@ -25,10 +24,6 @@ def probe_audio_processor():
|
|||||||
return AudioProcessor(TEST_WAVEFORM, TEST_SR)
|
return AudioProcessor(TEST_WAVEFORM, TEST_SR)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_AudioProcessor_init(probe_audio_processor):
|
def test_AudioProcessor_init(probe_audio_processor):
|
||||||
"""
|
"""
|
||||||
Test the initialization of the AudioProcessor class.
|
Test the initialization of the AudioProcessor class.
|
||||||
@@ -53,7 +48,6 @@ def test_AudioProcessor_init(probe_audio_processor):
|
|||||||
assert probe_audio_processor.sr == TEST_SR
|
assert probe_audio_processor.sr == TEST_SR
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_cut(probe_audio_processor):
|
def test_cut(probe_audio_processor):
|
||||||
"""Test the cut function of the AudioProcessor class.
|
"""Test the cut function of the AudioProcessor class.
|
||||||
|
|
||||||
@@ -76,14 +70,6 @@ def test_cut(probe_audio_processor):
|
|||||||
# assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
|
# assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_audio_processor_invalid_sr():
|
def test_audio_processor_invalid_sr():
|
||||||
"""Test the behavior of AudioProcessor when an invalid smaple rate is provided.
|
"""Test the behavior of AudioProcessor when an invalid smaple rate is provided.
|
||||||
|
|
||||||
@@ -108,20 +94,3 @@ def test_audio_processor_SAMPLE_RATE():
|
|||||||
"""
|
"""
|
||||||
probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
|
probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
|
||||||
assert probe_audio_processor.sr == SAMPLE_RATE
|
assert probe_audio_processor.sr == SAMPLE_RATE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from scraibe import Scraibe, Diariser, Transcriber, Transcript
|
from scraibe import Scraibe, Diariser, Transcriber, Transcript
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def create_scraibe_instance():
|
def create_scraibe_instance():
|
||||||
if "HF_TOKEN" in os.environ:
|
if "HF_TOKEN" in os.environ:
|
||||||
@@ -15,8 +11,6 @@ def create_scraibe_instance():
|
|||||||
return Scraibe()
|
return Scraibe()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_scraibe_init(create_scraibe_instance):
|
def test_scraibe_init(create_scraibe_instance):
|
||||||
model = create_scraibe_instance
|
model = create_scraibe_instance
|
||||||
assert isinstance(model.transcriber, Transcriber)
|
assert isinstance(model.transcriber, Transcriber)
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import os
|
from scraibe import Diariser
|
||||||
from unittest import mock
|
|
||||||
from scraibe import diarisation, Diariser
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -19,7 +16,6 @@ def diariser_instance():
|
|||||||
return Diariser('pyannote')
|
return Diariser('pyannote')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_Diariser_init(diariser_instance):
|
def test_Diariser_init(diariser_instance):
|
||||||
"""Test the initialization of the Diariser class.
|
"""Test the initialization of the Diariser class.
|
||||||
|
|
||||||
@@ -34,14 +30,3 @@ def test_Diariser_init(diariser_instance):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
assert diariser_instance.model == 'pyannote'
|
assert diariser_instance.model == 'pyannote'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch
|
|
||||||
from scraibe import Transcriber
|
from scraibe import Transcriber
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
TEST_WAVEFORM = "Hello World"
|
TEST_WAVEFORM = "Hello World"
|
||||||
|
|
||||||
@@ -29,13 +27,16 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
|||||||
|
|
||||||
assert transcription_result == expected_transcription """
|
assert transcription_result == expected_transcription """
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def transcriber_instance():
|
def transcriber_instance():
|
||||||
return Transcriber.load_model('medium')
|
return Transcriber.load_model('medium')
|
||||||
|
|
||||||
|
|
||||||
def test_transcriber_initialization(transcriber_instance):
|
def test_transcriber_initialization(transcriber_instance):
|
||||||
assert isinstance(transcriber_instance, Transcriber)
|
assert isinstance(transcriber_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
def test_get_whisper_kwargs():
|
def test_get_whisper_kwargs():
|
||||||
kwargs = {"arg1": 1, "arg3": 3}
|
kwargs = {"arg1": 1, "arg3": 3}
|
||||||
valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs)
|
valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs)
|
||||||
@@ -47,6 +48,3 @@ def test_transcribe(transcriber_instance):
|
|||||||
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
||||||
transcript = model.transcribe('test/audio_test_2.mp4')
|
transcript = model.transcribe('test/audio_test_2.mp4')
|
||||||
assert isinstance(transcript, str)
|
assert isinstance(transcript, str)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user