@@ -2,6 +2,7 @@ tqdm>=4.65.0
|
|||||||
numpy>=1.26.4
|
numpy>=1.26.4
|
||||||
|
|
||||||
openai-whisper==20231117
|
openai-whisper==20231117
|
||||||
|
whisperx~=3.1.3
|
||||||
|
|
||||||
pyannote.audio~=3.1.1
|
pyannote.audio~=3.1.1
|
||||||
pyannote.core~=5.0.0
|
pyannote.core~=5.0.0
|
||||||
|
|||||||
+1
-1
@@ -8,5 +8,5 @@ from .version import get_version as _get_version
|
|||||||
from .misc import *
|
from .misc import *
|
||||||
|
|
||||||
from .cli import *
|
from .cli import *
|
||||||
|
|
||||||
__version__ = _get_version()
|
__version__ = _get_version()
|
||||||
|
|||||||
+23
-21
@@ -28,6 +28,7 @@ import torch
|
|||||||
SAMPLE_RATE = 16000
|
SAMPLE_RATE = 16000
|
||||||
NORMALIZATION_FACTOR = 32768.0
|
NORMALIZATION_FACTOR = 32768.0
|
||||||
|
|
||||||
|
|
||||||
class AudioProcessor:
|
class AudioProcessor:
|
||||||
"""
|
"""
|
||||||
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
Audio Processor class that leverages PyTorchaudio to provide functionalities
|
||||||
@@ -39,10 +40,9 @@ class AudioProcessor:
|
|||||||
sr: int
|
sr: int
|
||||||
The sample rate of the audio.
|
The sample rate of the audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, waveform: torch.Tensor, sr : int = SAMPLE_RATE,
|
def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
|
||||||
*args, **kwargs) -> None:
|
*args, **kwargs) -> None:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Initialize the AudioProcessor object.
|
Initialize the AudioProcessor object.
|
||||||
|
|
||||||
@@ -56,16 +56,17 @@ class AudioProcessor:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the provided sample rate is not of type int.
|
ValueError: If the provided sample rate is not of type int.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
device = kwargs.get(
|
||||||
|
"device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
self.waveform = waveform.to(device)
|
self.waveform = waveform.to(device)
|
||||||
self.sr = sr
|
self.sr = sr
|
||||||
|
|
||||||
if not isinstance(self.sr, int):
|
if not isinstance(self.sr, int):
|
||||||
raise ValueError("Sample rate should be a single value of type int," \
|
raise ValueError("Sample rate should be a single value of type int,"
|
||||||
f"not {len(self.sr)} and type {type(self.sr)}")
|
f"not {len(self.sr)} and type {type(self.sr)}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
def from_file(cls, file: str, *args, **kwargs) -> 'AudioProcessor':
|
||||||
"""
|
"""
|
||||||
@@ -77,14 +78,13 @@ class AudioProcessor:
|
|||||||
Returns:
|
Returns:
|
||||||
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
|
AudioProcessor: An instance of the AudioProcessor class containing the loaded audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio, sr = cls.load_audio(file , *args, **kwargs)
|
audio, sr = cls.load_audio(file, *args, **kwargs)
|
||||||
|
|
||||||
audio = torch.from_numpy(audio)
|
audio = torch.from_numpy(audio)
|
||||||
|
|
||||||
return cls(audio, sr)
|
return cls(audio, sr)
|
||||||
|
|
||||||
|
|
||||||
def cut(self, start: float, end: float) -> torch.Tensor:
|
def cut(self, start: float, end: float) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Cut a segment from the audio waveform between the specified start and end times.
|
Cut a segment from the audio waveform between the specified start and end times.
|
||||||
@@ -96,7 +96,7 @@ class AudioProcessor:
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The cut waveform segment.
|
torch.Tensor: The cut waveform segment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start = int(start * self.sr)
|
start = int(start * self.sr)
|
||||||
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
|
if (isinstance(end, float) or isinstance(end, int)) and isinstance(self.sr, int):
|
||||||
end = int(np.ceil(end * self.sr))
|
end = int(np.ceil(end * self.sr))
|
||||||
@@ -140,11 +140,13 @@ class AudioProcessor:
|
|||||||
try:
|
try:
|
||||||
out = run(cmd, capture_output=True, check=True).stdout
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
except CalledProcessError as e:
|
except CalledProcessError as e:
|
||||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
raise RuntimeError(
|
||||||
|
f"Failed to load audio: {e.stderr.decode()}") from e
|
||||||
|
|
||||||
|
out = np.frombuffer(out, np.int16).flatten().astype(
|
||||||
|
np.float32) / NORMALIZATION_FACTOR
|
||||||
|
|
||||||
|
return out, sr
|
||||||
|
|
||||||
out = np.frombuffer(out, np.int16).flatten().astype(np.float32) / NORMALIZATION_FACTOR
|
|
||||||
|
|
||||||
return out , sr
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
|
||||||
|
|||||||
+101
-95
@@ -38,7 +38,7 @@ from tqdm import trange
|
|||||||
# Application-Specific Imports
|
# Application-Specific Imports
|
||||||
from .audio import AudioProcessor
|
from .audio import AudioProcessor
|
||||||
from .diarisation import Diariser
|
from .diarisation import Diariser
|
||||||
from .transcriber import Transcriber, whisper
|
from .transcriber import Transcriber, load_transcriber, whisper
|
||||||
from .transcript_exporter import Transcript
|
from .transcript_exporter import Transcript
|
||||||
|
|
||||||
|
|
||||||
@@ -55,22 +55,26 @@ class Scraibe:
|
|||||||
Attributes:
|
Attributes:
|
||||||
transcriber (Transcriber): The transcriber object to handle transcription.
|
transcriber (Transcriber): The transcriber object to handle transcription.
|
||||||
diariser (Diariser): The diariser object to handle diarization.
|
diariser (Diariser): The diariser object to handle diarization.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
__init__: Initializes the Scraibe class with appropriate models.
|
__init__: Initializes the Scraibe class with appropriate models.
|
||||||
transcribe: Transcribes an audio file using the whisper model and pyannote diarization model.
|
transcribe: Transcribes an audio file using the whisper model and pyannote diarization model.
|
||||||
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
|
remove_audio_file: Removes the original audio file to avoid disk space issues or ensure data privacy.
|
||||||
get_audio_file: Gets an audio file as an AudioProcessor object.
|
get_audio_file: Gets an audio file as an AudioProcessor object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
whisper_model: Union[bool, str, whisper] = None,
|
whisper_model: Union[bool, str, whisper] = None,
|
||||||
dia_model : Union[bool, str, DiarisationType] = None,
|
whisper_type: str = "whisper",
|
||||||
**kwargs) -> None:
|
dia_model: Union[bool, str, DiarisationType] = None,
|
||||||
|
**kwargs) -> None:
|
||||||
"""Initializes the Scraibe class.
|
"""Initializes the Scraibe class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
whisper_model (Union[bool, str, whisper], optional):
|
whisper_model (Union[bool, str, whisper], optional):
|
||||||
Path to whisper model or whisper model itself.
|
Path to whisper model or whisper model itself.
|
||||||
|
whisper_type (str):
|
||||||
|
Type of whisper model to load. "whisper" or "whisperx".
|
||||||
diarisation_model (Union[bool, str, DiarisationType], optional):
|
diarisation_model (Union[bool, str, DiarisationType], optional):
|
||||||
Path to pyannote diarization model or model itself.
|
Path to pyannote diarization model or model itself.
|
||||||
**kwargs: Additional keyword arguments for whisper
|
**kwargs: Additional keyword arguments for whisper
|
||||||
@@ -81,12 +85,13 @@ class Scraibe:
|
|||||||
- save_kwargs: If True, the keyword arguments will be saved
|
- save_kwargs: If True, the keyword arguments will be saved
|
||||||
for autotranscribe. So you can unload the class and reload it again.
|
for autotranscribe. So you can unload the class and reload it again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
self.transcriber = Transcriber.load_model("medium", **kwargs)
|
self.transcriber = load_transcriber(
|
||||||
|
"medium", whisper_type, **kwargs)
|
||||||
elif isinstance(whisper_model, str):
|
elif isinstance(whisper_model, str):
|
||||||
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
self.transcriber = load_transcriber(
|
||||||
|
whisper_model, whisper_type, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.transcriber = whisper_model
|
self.transcriber = whisper_model
|
||||||
|
|
||||||
@@ -95,26 +100,25 @@ class Scraibe:
|
|||||||
elif isinstance(dia_model, str):
|
elif isinstance(dia_model, str):
|
||||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.diariser : Diariser = dia_model
|
self.diariser: Diariser = dia_model
|
||||||
|
|
||||||
if kwargs.get("verbose"):
|
if kwargs.get("verbose"):
|
||||||
print("Scraibe initialized all models successfully loaded.")
|
print("Scraibe initialized all models successfully loaded.")
|
||||||
self.verbose = True
|
self.verbose = True
|
||||||
else:
|
else:
|
||||||
self.verbose = False
|
self.verbose = False
|
||||||
|
|
||||||
# Save kwargs for autotranscribe if you want to unload the class and load it again.
|
# Save kwargs for autotranscribe if you want to unload the class and load it again.
|
||||||
if kwargs.get('save_setup'):
|
if kwargs.get('save_setup'):
|
||||||
self.params = dict(whisper_model = whisper_model,
|
self.params = dict(whisper_model=whisper_model,
|
||||||
dia_model = dia_model,
|
dia_model=dia_model,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
else:
|
else:
|
||||||
self.params = {}
|
self.params = {}
|
||||||
|
|
||||||
|
def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
def autotranscribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
remove_original: bool = False,
|
||||||
remove_original : bool = False,
|
**kwargs) -> Transcript:
|
||||||
**kwargs) -> Transcript:
|
|
||||||
"""
|
"""
|
||||||
Transcribes an audio file using the whisper model and pyannote diarization model.
|
Transcribes an audio file using the whisper model and pyannote diarization model.
|
||||||
|
|
||||||
@@ -133,60 +137,62 @@ class Scraibe:
|
|||||||
if kwargs.get("verbose"):
|
if kwargs.get("verbose"):
|
||||||
self.verbose = kwargs.get("verbose")
|
self.verbose = kwargs.get("verbose")
|
||||||
# Get audio file as an AudioProcessor object
|
# Get audio file as an AudioProcessor object
|
||||||
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
audio_file: AudioProcessor = self.get_audio_file(audio_file)
|
||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
# Prepare waveform and sample rate for diarization
|
||||||
dia_audio = {
|
dia_audio = {
|
||||||
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
|
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
||||||
"sample_rate": audio_file.sr
|
"sample_rate": audio_file.sr
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Starting diarisation.")
|
print("Starting diarisation.")
|
||||||
|
|
||||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||||
|
|
||||||
if not diarisation["segments"]:
|
if not diarisation["segments"]:
|
||||||
print("No segments found. Try to run transcription without diarisation.")
|
print("No segments found. Try to run transcription without diarisation.")
|
||||||
|
|
||||||
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
transcript = self.transcriber.transcribe(
|
||||||
|
audio_file.waveform, **kwargs)
|
||||||
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
|
|
||||||
"segments" : [0, len(audio_file.waveform)],
|
final_transcript = {0: {"speakers": 'SPEAKER_01',
|
||||||
"text" : transcript}}
|
"segments": [0, len(audio_file.waveform)],
|
||||||
|
"text": transcript}}
|
||||||
|
|
||||||
return Transcript(final_transcript)
|
return Transcript(final_transcript)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Diarisation finished. Starting transcription.")
|
print("Diarisation finished. Starting transcription.")
|
||||||
|
|
||||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
audio_file.sr = torch.Tensor([audio_file.sr]).to(
|
||||||
|
audio_file.waveform.device)
|
||||||
|
|
||||||
# Transcribe each segment and store the results
|
# Transcribe each segment and store the results
|
||||||
final_transcript = dict()
|
final_transcript = dict()
|
||||||
|
|
||||||
for i in trange(len(diarisation["segments"]), desc= "Transcribing", disable = not self.verbose):
|
for i in trange(len(diarisation["segments"]), desc="Transcribing", disable=not self.verbose):
|
||||||
|
|
||||||
seg = diarisation["segments"][i]
|
seg = diarisation["segments"][i]
|
||||||
|
|
||||||
audio = audio_file.cut(seg[0], seg[1])
|
audio = audio_file.cut(seg[0], seg[1])
|
||||||
|
|
||||||
transcript = self.transcriber.transcribe(audio, **kwargs)
|
transcript = self.transcriber.transcribe(audio, **kwargs)
|
||||||
|
|
||||||
final_transcript[i] = {"speakers" : diarisation["speakers"][i],
|
final_transcript[i] = {"speakers": diarisation["speakers"][i],
|
||||||
"segments" : seg,
|
"segments": seg,
|
||||||
"text" : transcript}
|
"text": transcript}
|
||||||
|
|
||||||
# Remove original file if needed
|
# Remove original file if needed
|
||||||
if remove_original:
|
if remove_original:
|
||||||
if kwargs.get("shred") is True:
|
if kwargs.get("shred") is True:
|
||||||
self.remove_audio_file(audio_file, shred=True)
|
self.remove_audio_file(audio_file, shred=True)
|
||||||
else:
|
else:
|
||||||
self.remove_audio_file(audio_file, shred=False)
|
self.remove_audio_file(audio_file, shred=False)
|
||||||
|
|
||||||
return Transcript(final_transcript)
|
return Transcript(final_transcript)
|
||||||
|
|
||||||
def diarization(self, audio_file : Union[str, torch.Tensor, ndarray],
|
def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
**kwargs) -> dict:
|
**kwargs) -> dict:
|
||||||
"""
|
"""
|
||||||
Perform diarization on an audio file using the pyannote diarization model.
|
Perform diarization on an audio file using the pyannote diarization model.
|
||||||
@@ -201,24 +207,24 @@ class Scraibe:
|
|||||||
dict:
|
dict:
|
||||||
A dictionary containing the results of the diarization process.
|
A dictionary containing the results of the diarization process.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Get audio file as an AudioProcessor object
|
# Get audio file as an AudioProcessor object
|
||||||
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
audio_file: AudioProcessor = self.get_audio_file(audio_file)
|
||||||
|
|
||||||
# Prepare waveform and sample rate for diarization
|
# Prepare waveform and sample rate for diarization
|
||||||
dia_audio = {
|
dia_audio = {
|
||||||
"waveform" : audio_file.waveform.reshape(1,len(audio_file.waveform)),
|
"waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
|
||||||
"sample_rate": audio_file.sr
|
"sample_rate": audio_file.sr
|
||||||
}
|
}
|
||||||
|
|
||||||
print("Starting diarisation.")
|
print("Starting diarisation.")
|
||||||
|
|
||||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||||
|
|
||||||
return diarisation
|
return diarisation
|
||||||
|
|
||||||
def transcribe(self, audio_file : Union[str, torch.Tensor, ndarray],
|
def transcribe(self, audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Transcribe the provided audio file.
|
Transcribe the provided audio file.
|
||||||
|
|
||||||
@@ -232,11 +238,11 @@ class Scraibe:
|
|||||||
str:
|
str:
|
||||||
The transcribed text from the audio source.
|
The transcribed text from the audio source.
|
||||||
"""
|
"""
|
||||||
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
audio_file: AudioProcessor = self.get_audio_file(audio_file)
|
||||||
|
|
||||||
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||||
|
|
||||||
def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None:
|
def update_transcriber(self, whisper_model: Union[str, whisper], **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Update the transcriber model.
|
Update the transcriber model.
|
||||||
|
|
||||||
@@ -245,22 +251,23 @@ class Scraibe:
|
|||||||
The new whisper model to use for transcription.
|
The new whisper model to use for transcription.
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Additional keyword arguments for the transcriber model.
|
Additional keyword arguments for the transcriber model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
_old_model = self.transcriber.model_name
|
_old_model = self.transcriber.model_name
|
||||||
|
|
||||||
if isinstance(whisper_model, str):
|
if isinstance(whisper_model, str):
|
||||||
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
self.transcriber = load_transcriber(whisper_model, **kwargs)
|
||||||
elif isinstance(whisper_model, Transcriber):
|
elif isinstance(whisper_model, Transcriber):
|
||||||
self.transcriber = whisper_model
|
self.transcriber = whisper_model
|
||||||
else:
|
else:
|
||||||
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
warn(
|
||||||
|
f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None:
|
def update_diariser(self, dia_model: Union[str, DiarisationType], **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Update the diariser model.
|
Update the diariser model.
|
||||||
|
|
||||||
@@ -269,7 +276,7 @@ class Scraibe:
|
|||||||
The new diariser model to use for diarization.
|
The new diariser model to use for diarization.
|
||||||
**kwargs:
|
**kwargs:
|
||||||
Additional keyword arguments for the diariser model.
|
Additional keyword arguments for the diariser model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
@@ -278,13 +285,13 @@ class Scraibe:
|
|||||||
elif isinstance(dia_model, Diariser):
|
elif isinstance(dia_model, Diariser):
|
||||||
self.diariser = dia_model
|
self.diariser = dia_model
|
||||||
else:
|
else:
|
||||||
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
warn("Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_audio_file(audio_file : str,
|
def remove_audio_file(audio_file: str,
|
||||||
shred : bool = False) -> None:
|
shred: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
Removes the original audio file to avoid disk space issues or ensure data privacy.
|
Removes the original audio file to avoid disk space issues or ensure data privacy.
|
||||||
|
|
||||||
@@ -295,30 +302,29 @@ class Scraibe:
|
|||||||
"""
|
"""
|
||||||
if not os.path.exists(audio_file):
|
if not os.path.exists(audio_file):
|
||||||
raise ValueError(f"Audiofile {audio_file} does not exist.")
|
raise ValueError(f"Audiofile {audio_file} does not exist.")
|
||||||
|
|
||||||
if shred:
|
if shred:
|
||||||
|
|
||||||
warn("Shredding audiofile can take a long time.", RuntimeWarning)
|
warn("Shredding audiofile can take a long time.", RuntimeWarning)
|
||||||
|
|
||||||
gen = iglob(f'{audio_file}', recursive=True)
|
gen = iglob(f'{audio_file}', recursive=True)
|
||||||
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
|
cmd = ['shred', '-zvu', '-n', '10', f'{audio_file}']
|
||||||
|
|
||||||
if os.path.isdir(audio_file):
|
if os.path.isdir(audio_file):
|
||||||
raise ValueError(f"Audiofile {audio_file} is a directory.")
|
raise ValueError(f"Audiofile {audio_file} is a directory.")
|
||||||
|
|
||||||
for file in gen:
|
for file in gen:
|
||||||
print(f'shredding {file} now\n')
|
print(f'shredding {file} now\n')
|
||||||
|
|
||||||
run(cmd , check=True)
|
run(cmd, check=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
print(f"Audiofile {audio_file} removed.")
|
print(f"Audiofile {audio_file} removed.")
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
|
def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
|
||||||
*args, **kwargs) -> AudioProcessor:
|
*args, **kwargs) -> AudioProcessor:
|
||||||
"""Gets an audio file as TorchAudioProcessor.
|
"""Gets an audio file as TorchAudioProcessor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -331,20 +337,20 @@ class Scraibe:
|
|||||||
AudioProcessor: An object containing the waveform and sample rate in
|
AudioProcessor: An object containing the waveform and sample rate in
|
||||||
torch.Tensor format.
|
torch.Tensor format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(audio_file, str):
|
if isinstance(audio_file, str):
|
||||||
audio_file = AudioProcessor.from_file(audio_file)
|
audio_file = AudioProcessor.from_file(audio_file)
|
||||||
|
|
||||||
elif isinstance(audio_file, torch.Tensor):
|
elif isinstance(audio_file, torch.Tensor):
|
||||||
audio_file = AudioProcessor(audio_file[0], audio_file[1])
|
audio_file = AudioProcessor(audio_file[0], audio_file[1])
|
||||||
elif isinstance(audio_file, ndarray):
|
elif isinstance(audio_file, ndarray):
|
||||||
audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
|
audio_file = AudioProcessor(torch.Tensor(audio_file[0]),
|
||||||
audio_file[1])
|
audio_file[1])
|
||||||
|
|
||||||
if not isinstance(audio_file, AudioProcessor):
|
if not isinstance(audio_file, AudioProcessor):
|
||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
raise ValueError(f'Audiofile must be of type AudioProcessor,'
|
||||||
f'not {type(audio_file)}')
|
f'not {type(audio_file)}')
|
||||||
|
|
||||||
return audio_file
|
return audio_file
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|||||||
+66
-58
@@ -4,7 +4,7 @@ allowing for user interaction to transcribe and diarize audio files.
|
|||||||
The function includes arguments for specifying the audio files, model paths,
|
The function includes arguments for specifying the audio files, model paths,
|
||||||
output formats, and other options necessary for transcription.
|
output formats, and other options necessary for transcription.
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ from .autotranscript import Scraibe
|
|||||||
from .misc import ParseKwargs
|
from .misc import ParseKwargs
|
||||||
|
|
||||||
|
|
||||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
|
||||||
from torch.cuda import is_available
|
from torch.cuda import is_available
|
||||||
from torch import set_num_threads
|
from torch import set_num_threads
|
||||||
|
|
||||||
@@ -26,42 +26,43 @@ def cli():
|
|||||||
This function can be executed from the command line to perform transcription tasks, providing a
|
This function can be executed from the command line to perform transcription tasks, providing a
|
||||||
user-friendly way to access the Scraibe class functionalities.
|
user-friendly way to access the Scraibe class functionalities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def str2bool(string):
|
def str2bool(string):
|
||||||
str2val = {"True": True, "False": False}
|
str2val = {"True": True, "False": False}
|
||||||
if string in str2val:
|
if string in str2val:
|
||||||
return str2val[string]
|
return str2val[string]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
raise ValueError(
|
||||||
|
f"Expected one of {set(str2val.keys())}, got {string}")
|
||||||
|
|
||||||
parser = ArgumentParser(formatter_class = ArgumentDefaultsHelpFormatter)
|
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
group = parser.add_mutually_exclusive_group()
|
group = parser.add_mutually_exclusive_group()
|
||||||
|
|
||||||
parser.add_argument("-f","--audio-files", nargs="+", type=str, default=None,
|
parser.add_argument("-f", "--audio-files", nargs="+", type=str, default=None,
|
||||||
help="List of audio files to transcribe.")
|
help="List of audio files to transcribe.")
|
||||||
|
|
||||||
group.add_argument('--start-server', action='store_true',
|
group.add_argument('--start-server', action='store_true',
|
||||||
help='Start the Gradio app.' \
|
help='Start the Gradio app.'
|
||||||
'If set, all other arguments are ignored' \
|
'If set, all other arguments are ignored'
|
||||||
'besides --server-config or --server-kwargs.')
|
'besides --server-config or --server-kwargs.')
|
||||||
|
|
||||||
parser.add_argument("--server-config", type=str, default= None,
|
parser.add_argument("--server-config", type=str, default=None,
|
||||||
help="Path to the configy.yml file.")
|
help="Path to the configy.yml file.")
|
||||||
|
|
||||||
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
|
parser.add_argument('--server-kwargs', nargs='*', action=ParseKwargs, default={},
|
||||||
help='Keyword arguments for the Gradio app.')
|
help='Keyword arguments for the Gradio app.')
|
||||||
|
|
||||||
parser.add_argument("--whisper-model-name", default="medium",
|
parser.add_argument("--whisper-model-name", default="medium",
|
||||||
help="Name of the Whisper model to use.")
|
help="Name of the Whisper model to use.")
|
||||||
|
|
||||||
parser.add_argument("--whisper-model-directory", type=str, default= None,
|
parser.add_argument("--whisper-model-directory", type=str, default=None,
|
||||||
help="Path to save Whisper model files; defaults to ./models/whisper.")
|
help="Path to save Whisper model files; defaults to ./models/whisper.")
|
||||||
|
|
||||||
parser.add_argument("--diarization-directory", type=str, default= None,
|
parser.add_argument("--diarization-directory", type=str, default=None,
|
||||||
help="Path to the diarization model directory.")
|
help="Path to the diarization model directory.")
|
||||||
|
|
||||||
parser.add_argument("--hf-token", default= None, type=str,
|
parser.add_argument("--hf-token", default=None, type=str,
|
||||||
help="HuggingFace token for private model download.")
|
help="HuggingFace token for private model download.")
|
||||||
|
|
||||||
parser.add_argument("--inference-device",
|
parser.add_argument("--inference-device",
|
||||||
@@ -82,105 +83,112 @@ def cli():
|
|||||||
parser.add_argument("--verbose-output", type=str2bool, default=True,
|
parser.add_argument("--verbose-output", type=str2bool, default=True,
|
||||||
help="Enable or disable progress and debug messages.")
|
help="Enable or disable progress and debug messages.")
|
||||||
|
|
||||||
parser.add_argument("--task", type=str, default= 'autotranscribe', # unifinished code
|
parser.add_argument("--task", type=str, default='autotranscribe', # unifinished code
|
||||||
choices=["autotranscribe", "diarization",
|
choices=["autotranscribe", "diarization",
|
||||||
"autotranscribe+translate", "translate", 'transcribe'],
|
"autotranscribe+translate", "translate", 'transcribe'],
|
||||||
help="Choose to perform transcription, diarization, or translation. \
|
help="Choose to perform transcription, diarization, or translation. \
|
||||||
If set to translate, the output will be translated to English.")
|
If set to translate, the output will be translated to English.")
|
||||||
|
|
||||||
parser.add_argument("--language", type=str, default=None,
|
parser.add_argument("--language", type=str, default=None,
|
||||||
choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
choices=sorted(
|
||||||
|
LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
|
||||||
help="Language spoken in the audio. Specify None to perform language detection.")
|
help="Language spoken in the audio. Specify None to perform language detection.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
arg_dict = vars(args)
|
arg_dict = vars(args)
|
||||||
|
|
||||||
# configure output
|
# configure output
|
||||||
out_folder = arg_dict.pop("output_directory")
|
out_folder = arg_dict.pop("output_directory")
|
||||||
os.makedirs(out_folder, exist_ok=True)
|
os.makedirs(out_folder, exist_ok=True)
|
||||||
|
|
||||||
out_format = arg_dict.pop("output_format")
|
out_format = arg_dict.pop("output_format")
|
||||||
|
|
||||||
# seup server arg:
|
# seup server arg:
|
||||||
start_server = arg_dict.pop("start_server")
|
start_server = arg_dict.pop("start_server")
|
||||||
|
|
||||||
task = arg_dict.pop("task")
|
task = arg_dict.pop("task")
|
||||||
|
|
||||||
if args.num_threads > 0:
|
if args.num_threads > 0:
|
||||||
set_num_threads(arg_dict.pop("num_threads"))
|
set_num_threads(arg_dict.pop("num_threads"))
|
||||||
|
|
||||||
class_kwargs = {'whisper_model' : arg_dict.pop("whisper_model_name"),
|
class_kwargs = {'whisper_model': arg_dict.pop("whisper_model_name"),
|
||||||
'dia_model': arg_dict.pop("diarization_directory"),
|
'dia_model': arg_dict.pop("diarization_directory"),
|
||||||
'use_auth_token' : arg_dict.pop("hf_token")}
|
'use_auth_token': arg_dict.pop("hf_token")}
|
||||||
|
|
||||||
if arg_dict["whisper_model_directory"]:
|
if arg_dict["whisper_model_directory"]:
|
||||||
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
class_kwargs["download_root"] = arg_dict.pop("whisper_model_directory")
|
||||||
|
|
||||||
if not start_server:
|
if not start_server:
|
||||||
|
|
||||||
model = Scraibe(**class_kwargs)
|
model = Scraibe(**class_kwargs)
|
||||||
|
|
||||||
if arg_dict["audio_files"]:
|
if arg_dict["audio_files"]:
|
||||||
audio_files = arg_dict.pop("audio_files")
|
audio_files = arg_dict.pop("audio_files")
|
||||||
|
|
||||||
if task == "autotranscribe" or task == "autotranscribe+translate":
|
if task == "autotranscribe" or task == "autotranscribe+translate":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
if task == "autotranscribe+translate":
|
if task == "autotranscribe+translate":
|
||||||
task = "translate"
|
task = "translate"
|
||||||
else:
|
else:
|
||||||
task = "transcribe"
|
task = "transcribe"
|
||||||
|
|
||||||
out = model.autotranscribe(audio,task = task, language=arg_dict.pop("language"), verbose = arg_dict.pop("verbose_output"))
|
out = model.autotranscribe(audio, task=task, language=arg_dict.pop(
|
||||||
|
"language"), verbose=arg_dict.pop("verbose_output"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
out.save(os.path.join(out_folder, f"{basename}.{out_format}"))
|
out.save(os.path.join(
|
||||||
|
out_folder, f"{basename}.{out_format}"))
|
||||||
|
|
||||||
elif task == "diarization":
|
elif task == "diarization":
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
if arg_dict.pop("verbose_output"):
|
if arg_dict.pop("verbose_output"):
|
||||||
print(f"Verbose not implemented for diarization.")
|
print("Verbose not implemented for diarization.")
|
||||||
|
|
||||||
out = model.diarization(audio)
|
out = model.diarization(audio)
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
|
|
||||||
print(f'Saving {basename}.{out_format} to {out_folder}')
|
print(f'Saving {basename}.{out_format} to {out_folder}')
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
json.dump(json.dumps(out, indent= 1), f)
|
json.dump(json.dumps(out, indent=1), f)
|
||||||
|
|
||||||
elif task == "transcribe" or task == "translate":
|
elif task == "transcribe" or task == "translate":
|
||||||
|
|
||||||
for audio in audio_files:
|
for audio in audio_files:
|
||||||
|
|
||||||
out = model.transcribe(audio, task = task,
|
out = model.transcribe(audio, task=task,
|
||||||
language= arg_dict.pop("language"),
|
language=arg_dict.pop("language"),
|
||||||
verbose = arg_dict.pop("verbose_output"))
|
verbose=arg_dict.pop("verbose_output"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
path = os.path.join(out_folder, f"{basename}.{out_format}")
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(out)
|
f.write(out)
|
||||||
|
|
||||||
|
else: # unfinished code
|
||||||
else: # unfinished code
|
|
||||||
raise NotImplementedError("Currently not Working")
|
raise NotImplementedError("Currently not Working")
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
execute_path = os.path.join(os.path.dirname(__file__), "app/app_starter.py")
|
execute_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "app/app_starter.py")
|
||||||
|
|
||||||
config = arg_dict.pop("server_config")
|
config = arg_dict.pop("server_config")
|
||||||
server_kwargs = arg_dict.pop("server_kwargs")
|
server_kwargs = arg_dict.pop("server_kwargs")
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
subprocess.run([sys.executable, execute_path, f"--server-kwargs={server_kwargs}"])
|
subprocess.run([sys.executable, execute_path,
|
||||||
|
f"--server-kwargs={server_kwargs}"])
|
||||||
elif not server_kwargs:
|
elif not server_kwargs:
|
||||||
subprocess.run([sys.executable, execute_path, f"--server-config={config}"])
|
subprocess.run([sys.executable, execute_path,
|
||||||
|
f"--server-config={config}"])
|
||||||
elif not config and not server_kwargs:
|
elif not config and not server_kwargs:
|
||||||
subprocess.run([sys.executable, execute_path])
|
subprocess.run([sys.executable, execute_path])
|
||||||
else:
|
else:
|
||||||
subprocess.run([sys.executable, execute_path, f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
|
subprocess.run([sys.executable, execute_path,
|
||||||
|
f"--server-config={config}", f"--server-kwargs={server_kwargs}"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
+56
-52
@@ -37,15 +37,16 @@ from pyannote.audio import Pipeline
|
|||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch import device as torch_device
|
from torch import device as torch_device
|
||||||
from torch.cuda import is_available, current_device
|
from torch.cuda import is_available
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import RepositoryNotFoundError
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
TOKEN_PATH = os.path.join(os.path.dirname(
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
os.path.realpath(__file__)), '.pyannotetoken')
|
os.path.realpath(__file__)), '.pyannotetoken')
|
||||||
|
|
||||||
|
|
||||||
class Diariser:
|
class Diariser:
|
||||||
"""
|
"""
|
||||||
@@ -55,12 +56,12 @@ class Diariser:
|
|||||||
Args:
|
Args:
|
||||||
model: The pretrained model to use for diarization.
|
model: The pretrained model to use for diarization.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model) -> None:
    """
    Store the diarization model on the instance.

    Args:
        model: The pretrained model to use for diarization.
    """
    self.model = model
|
||||||
|
|
||||||
def diarization(self, audiofile : Union[str, Tensor, dict] ,
|
def diarization(self, audiofile: Union[str, Tensor, dict],
|
||||||
*args, **kwargs) -> Annotation:
|
*args, **kwargs) -> Annotation:
|
||||||
"""
|
"""
|
||||||
Perform speaker diarization on the provided audio file,
|
Perform speaker diarization on the provided audio file,
|
||||||
@@ -79,15 +80,15 @@ class Diariser:
|
|||||||
to the diarization process.
|
to the diarization process.
|
||||||
"""
|
"""
|
||||||
kwargs = self._get_diarisation_kwargs(**kwargs)
|
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||||
|
|
||||||
diarization = self.model(audiofile,*args, **kwargs)
|
diarization = self.model(audiofile, *args, **kwargs)
|
||||||
|
|
||||||
out = self.format_diarization_output(diarization)
|
out = self.format_diarization_output(diarization)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format_diarization_output(dia : Annotation) -> dict:
|
def format_diarization_output(dia: Annotation) -> dict:
|
||||||
"""
|
"""
|
||||||
Formats the raw diarization output into a more usable structure for this project.
|
Formats the raw diarization output into a more usable structure for this project.
|
||||||
|
|
||||||
@@ -99,14 +100,14 @@ class Diariser:
|
|||||||
as keys and a list of tuples representing segments as values.
|
as keys and a list of tuples representing segments as values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dia_list = list(dia.itertracks(yield_label=True))
|
dia_list = list(dia.itertracks(yield_label=True))
|
||||||
diarization_output = {"speakers": [], "segments": []}
|
diarization_output = {"speakers": [], "segments": []}
|
||||||
|
|
||||||
normalized_output = []
|
normalized_output = []
|
||||||
index_start_speaker = 0
|
index_start_speaker = 0
|
||||||
index_end_speaker = 0
|
index_end_speaker = 0
|
||||||
current_speaker = str()
|
current_speaker = str()
|
||||||
|
|
||||||
###
|
###
|
||||||
# Sometimes two consecutive speakers are the same
|
# Sometimes two consecutive speakers are the same
|
||||||
# This loop removes these duplicates
|
# This loop removes these duplicates
|
||||||
@@ -115,40 +116,39 @@ class Diariser:
|
|||||||
if len(dia_list) == 1:
|
if len(dia_list) == 1:
|
||||||
normalized_output.append([0, 0, dia_list[0][2]])
|
normalized_output.append([0, 0, dia_list[0][2]])
|
||||||
else:
|
else:
|
||||||
|
|
||||||
for i, (_, _, speaker) in enumerate(dia_list):
|
for i, (_, _, speaker) in enumerate(dia_list):
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
current_speaker = speaker
|
current_speaker = speaker
|
||||||
|
|
||||||
if speaker != current_speaker:
|
if speaker != current_speaker:
|
||||||
|
|
||||||
index_end_speaker = i - 1
|
index_end_speaker = i - 1
|
||||||
|
|
||||||
normalized_output.append([index_start_speaker,
|
normalized_output.append([index_start_speaker,
|
||||||
index_end_speaker,
|
index_end_speaker,
|
||||||
current_speaker])
|
current_speaker])
|
||||||
|
|
||||||
index_start_speaker = i
|
index_start_speaker = i
|
||||||
current_speaker = speaker
|
current_speaker = speaker
|
||||||
|
|
||||||
|
|
||||||
if i == len(dia_list) - 1:
|
if i == len(dia_list) - 1:
|
||||||
|
|
||||||
index_end_speaker = i
|
index_end_speaker = i
|
||||||
|
|
||||||
normalized_output.append([index_start_speaker,
|
normalized_output.append([index_start_speaker,
|
||||||
index_end_speaker,
|
index_end_speaker,
|
||||||
current_speaker])
|
current_speaker])
|
||||||
|
|
||||||
for outp in normalized_output:
|
for outp in normalized_output:
|
||||||
start = dia_list[outp[0]][0].start
|
start = dia_list[outp[0]][0].start
|
||||||
end = dia_list[outp[1]][0].end
|
end = dia_list[outp[1]][0].end
|
||||||
|
|
||||||
diarization_output["segments"].append([start, end])
|
diarization_output["segments"].append([start, end])
|
||||||
diarization_output["speakers"].append(outp[2])
|
diarization_output["speakers"].append(outp[2])
|
||||||
return diarization_output
|
return diarization_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_token():
|
def _get_token():
|
||||||
"""
|
"""
|
||||||
@@ -161,14 +161,14 @@ class Diariser:
|
|||||||
Returns:
|
Returns:
|
||||||
str: The Huggingface token.
|
str: The Huggingface token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if os.path.exists(TOKEN_PATH):
|
if os.path.exists(TOKEN_PATH):
|
||||||
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
|
||||||
token = file.read()
|
token = file.read()
|
||||||
else:
|
else:
|
||||||
raise ValueError('No token found.' \
|
raise ValueError('No token found.'
|
||||||
'Please create a token at https://huggingface.co/settings/token' \
|
'Please create a token at https://huggingface.co/settings/token'
|
||||||
f'and save it in a file called {TOKEN_PATH}')
|
f'and save it in a file called {TOKEN_PATH}')
|
||||||
return token
|
return token
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -182,18 +182,17 @@ class Diariser:
|
|||||||
"""
|
"""
|
||||||
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
|
with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
|
||||||
file.write(token)
|
file.write(token)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||||
use_auth_token: str = None,
|
use_auth_token: str = None,
|
||||||
cache_token: bool = False,
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None,
|
hparams_file: Union[str, Path] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
*args, **kwargs
|
*args, **kwargs
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Loads a pretrained model from pyannote.audio,
|
Loads a pretrained model from pyannote.audio,
|
||||||
either from a local cache or some online repository.
|
either from a local cache or some online repository.
|
||||||
@@ -237,16 +236,18 @@ class Diariser:
|
|||||||
'deprecated and will be removed in future versions.',
|
'deprecated and will be removed in future versions.',
|
||||||
category=DeprecationWarning)
|
category=DeprecationWarning)
|
||||||
# list elementes with the ending .bin
|
# list elementes with the ending .bin
|
||||||
bin_files = [f for f in os.listdir(pwd) if f.endswith(".bin")]
|
bin_files = [f for f in os.listdir(
|
||||||
|
pwd) if f.endswith(".bin")]
|
||||||
if len(bin_files) == 1:
|
if len(bin_files) == 1:
|
||||||
path_to_model = os.path.join(pwd, bin_files[0])
|
path_to_model = os.path.join(pwd, bin_files[0])
|
||||||
else:
|
else:
|
||||||
warnings.warn("Found more than one .bin file. "\
|
warnings.warn("Found more than one .bin file. "
|
||||||
"or none. Please specify the path to the model " \
|
"or none. Please specify the path to the model "
|
||||||
"or setup a huggingface token.")
|
"or setup a huggingface token.")
|
||||||
raise FileNotFoundError
|
raise FileNotFoundError
|
||||||
|
|
||||||
warnings.warn(f"Found model at {path_to_model} overwriting config file.")
|
warnings.warn(
|
||||||
|
f"Found model at {path_to_model} overwriting config file.")
|
||||||
|
|
||||||
config['pipeline']['params']['segmentation'] = path_to_model
|
config['pipeline']['params']['segmentation'] = path_to_model
|
||||||
|
|
||||||
@@ -270,22 +271,24 @@ class Diariser:
|
|||||||
if use_auth_token is None:
|
if use_auth_token is None:
|
||||||
use_auth_token = cls._get_token()
|
use_auth_token = cls._get_token()
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f'No local model or directory found at {model}.')
|
raise FileNotFoundError(
|
||||||
|
f'No local model or directory found at {model}.')
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
hparams_file=hparams_file,)
|
hparams_file=hparams_file,)
|
||||||
if _model is None:
|
if _model is None:
|
||||||
raise ValueError('Unable to load model either from local cache' \
|
raise ValueError('Unable to load model either from local cache'
|
||||||
'or from huggingface.co models. Please check your token' \
|
'or from huggingface.co models. Please check your token'
|
||||||
'or your local model path')
|
'or your local model path')
|
||||||
|
|
||||||
# try to move the model to the device
|
# try to move the model to the device
|
||||||
if device is None:
|
if device is None:
|
||||||
device = "cuda" if is_available() else "cpu"
|
device = "cuda" if is_available() else "cpu"
|
||||||
|
|
||||||
_model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict
|
# torch_device is renamed from torch.device to avoid name conflict
|
||||||
|
_model = _model.to(torch_device(device))
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
@@ -302,9 +305,10 @@ class Diariser:
|
|||||||
"""
|
"""
|
||||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||||
|
|
||||||
diarisation_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
diarisation_kwargs = {k: v for k,
|
||||||
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|
||||||
return diarisation_kwargs
|
return diarisation_kwargs
|
||||||
|
|
||||||
def __repr__(self) -> str:
    """Return a string of the form ``Diarisation(model=...)`` for this object."""
    return f"Diarisation(model={self.model})"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# List of known hallucinations - adapted from:
|
# List of known hallucinations - adapted from:
|
||||||
# https://github.com/openai/whisper/discussions/928
|
# https://github.com/openai/whisper/discussions/928
|
||||||
KNOWN_HALLUCINATIONS=[
|
KNOWN_HALLUCINATIONS = [
|
||||||
# en
|
# en
|
||||||
" www.mooji.org"
|
" www.mooji.org"
|
||||||
# nl
|
# nl
|
||||||
@@ -73,7 +73,7 @@ KNOWN_HALLUCINATIONS=[
|
|||||||
" Sous-titres réalisés para la communauté d'Amara.org"
|
" Sous-titres réalisés para la communauté d'Amara.org"
|
||||||
# ln
|
# ln
|
||||||
" Sous-titres réalisés para la communauté d'Amara.org"
|
" Sous-titres réalisés para la communauté d'Amara.org"
|
||||||
# pl
|
# pl
|
||||||
" Napisy stworzone przez społeczność Amara.org",
|
" Napisy stworzone przez społeczność Amara.org",
|
||||||
" Napisy wykonane przez społeczność Amara.org",
|
" Napisy wykonane przez społeczność Amara.org",
|
||||||
" Zdjęcia i napisy stworzone przez społeczność Amara.org",
|
" Zdjęcia i napisy stworzone przez społeczność Amara.org",
|
||||||
@@ -92,4 +92,4 @@ KNOWN_HALLUCINATIONS=[
|
|||||||
# zh
|
# zh
|
||||||
"字幕由Amara.org社区提供",
|
"字幕由Amara.org社区提供",
|
||||||
"小編字幕由Amara.org社區提供"
|
"小編字幕由Amara.org社區提供"
|
||||||
]
|
]
|
||||||
|
|||||||
+12
-6
@@ -2,6 +2,7 @@ import os
|
|||||||
import yaml
|
import yaml
|
||||||
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
|
from pyannote.audio.core.model import CACHE_DIR as PYANNOTE_CACHE_DIR
|
||||||
from argparse import Action
|
from argparse import Action
|
||||||
|
from ast import literal_eval
|
||||||
|
|
||||||
CACHE_DIR = os.getenv(
|
CACHE_DIR = os.getenv(
|
||||||
"AUTOT_CACHE",
|
"AUTOT_CACHE",
|
||||||
@@ -14,8 +15,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
|
|||||||
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
|
||||||
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
|
||||||
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
|
||||||
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
|
||||||
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
|
||||||
|
|
||||||
|
|
||||||
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
|
||||||
"""Configure diarization pipeline from a YAML file.
|
"""Configure diarization pipeline from a YAML file.
|
||||||
@@ -33,25 +35,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
|
|||||||
with open(file_path, "r") as stream:
|
with open(file_path, "r") as stream:
|
||||||
yml = yaml.safe_load(stream)
|
yml = yaml.safe_load(stream)
|
||||||
|
|
||||||
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
segmentation_path = path_to_segmentation or os.path.join(
|
||||||
|
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
|
||||||
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
yml["pipeline"]["params"]["segmentation"] = segmentation_path
|
||||||
|
|
||||||
if not os.path.exists(segmentation_path):
|
if not os.path.exists(segmentation_path):
|
||||||
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
|
raise FileNotFoundError(
|
||||||
|
f"Segmentation model not found at {segmentation_path}")
|
||||||
|
|
||||||
with open(file_path, "w") as stream:
|
with open(file_path, "w") as stream:
|
||||||
yaml.dump(yml, stream)
|
yaml.dump(yml, stream)
|
||||||
|
|
||||||
|
|
||||||
class ParseKwargs(Action):
    """
    Custom argparse action that collects ``key=value`` tokens into a dict.

    Each token is split on its first ``=``. The value part is passed through
    ``ast.literal_eval`` so that Python literals (numbers, lists, booleans,
    ...) arrive as real objects; anything that is not a valid literal is kept
    as the raw string. The resulting dict is stored on the namespace under
    ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Parse ``values`` and store the resulting dict on ``namespace``.

        Args:
            parser: The ArgumentParser invoking this action (unused).
            namespace: Namespace object that receives the parsed dict.
            values: One ``key=value`` string or a sequence of them.
            option_string: The option flag that triggered the action (unused).

        Raises:
            ValueError: If a token does not contain ``=``.
        """
        # Without ``nargs`` argparse hands over a bare string; iterating it
        # directly would walk over single characters.
        if isinstance(values, str):
            values = [values]

        parsed = {}
        for item in values:
            # Split on the first "=" only, so values may contain "=" themselves.
            key, sep, raw = item.partition('=')
            if not sep:
                raise ValueError(
                    f"Expected an argument of the form key=value, got {item!r}.")
            try:
                parsed[key] = literal_eval(raw)
            except (ValueError, SyntaxError, TypeError,
                    MemoryError, RecursionError):
                # Not a Python literal: keep the text exactly as given.
                parsed[key] = raw

        setattr(namespace, self.dest, parsed)
|
||||||
|
|||||||
+282
-39
@@ -24,16 +24,20 @@ Usage:
|
|||||||
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
|
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from whisper import Whisper, load_model
|
from whisper import Whisper
|
||||||
from typing import TypeVar , Union , Optional
|
from whisper import load_model as whisper_load_model
|
||||||
|
from whisperx.asr import WhisperModel
|
||||||
|
from whisperx import load_model as whisperx_load_model
|
||||||
|
from typing import TypeVar, Union, Optional
|
||||||
from torch import Tensor, device
|
from torch import Tensor, device
|
||||||
|
from torch.cuda import is_available as cuda_is_available
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
|
from inspect import signature
|
||||||
|
from abc import abstractmethod
|
||||||
|
import warnings
|
||||||
|
|
||||||
from .misc import WHISPER_DEFAULT_PATH
|
from .misc import WHISPER_DEFAULT_PATH
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Transcriber:
|
class Transcriber:
|
||||||
@@ -64,7 +68,8 @@ class Transcriber:
|
|||||||
The class supports various sizes and versions of Whisper models. Please refer to
|
The class supports various sizes and versions of Whisper models. Please refer to
|
||||||
the load_model method for available options.
|
the load_model method for available options.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model: whisper , model_name: str ) -> None:
|
|
||||||
|
def __init__(self, model: whisper, model_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the Transcriber class with a Whisper model.
|
Initialize the Transcriber class with a Whisper model.
|
||||||
|
|
||||||
@@ -72,12 +77,13 @@ class Transcriber:
|
|||||||
model (whisper): The Whisper model to use for transcription.
|
model (whisper): The Whisper model to use for transcription.
|
||||||
model_name (str): The name of the model.
|
model_name (str): The name of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
|
@abstractmethod
|
||||||
|
def transcribe(self, audio: Union[str, Tensor, ndarray],
|
||||||
*args, **kwargs) -> str:
|
*args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Transcribe an audio file.
|
Transcribe an audio file.
|
||||||
@@ -91,17 +97,10 @@ class Transcriber:
|
|||||||
Returns:
|
Returns:
|
||||||
str: The transcript as a string.
|
str: The transcript as a string.
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
kwargs = self._get_whisper_kwargs(**kwargs)
|
|
||||||
|
|
||||||
if not kwargs.get("verbose"):
|
|
||||||
kwargs["verbose"] = None
|
|
||||||
|
|
||||||
result = self.model.transcribe(audio, *args, **kwargs)
|
|
||||||
return result["text"]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_transcript(transcript : str , save_path : str) -> None:
|
def save_transcript(transcript: str, save_path: str) -> None:
|
||||||
"""
|
"""
|
||||||
Save a transcript to a file.
|
Save a transcript to a file.
|
||||||
|
|
||||||
@@ -115,17 +114,19 @@ class Transcriber:
|
|||||||
|
|
||||||
with open(save_path, 'w') as f:
|
with open(save_path, 'w') as f:
|
||||||
f.write(transcript)
|
f.write(transcript)
|
||||||
|
|
||||||
print(f'Transcript saved to {save_path}')
|
print(f'Transcript saved to {save_path}')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
whisper_type: str = 'whisper',
|
||||||
device: Optional[Union[str, device]] = None,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
in_memory: bool = False,
|
device: Optional[Union[str, device]] = None,
|
||||||
*args, **kwargs
|
in_memory: bool = False,
|
||||||
) -> 'Transcriber':
|
*args, **kwargs
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Load whisper model.
|
Load whisper model.
|
||||||
|
|
||||||
@@ -143,10 +144,92 @@ class Transcriber:
|
|||||||
- 'large-v2'
|
- 'large-v2'
|
||||||
- 'large-v3'
|
- 'large-v3'
|
||||||
- 'large'
|
- 'large'
|
||||||
|
whisper_type (str):
|
||||||
|
Type of whisper model to load. "whisper" or "whisperx".
|
||||||
download_root (str, optional): Path to download the model.
|
download_root (str, optional): Path to download the model.
|
||||||
Defaults to WHISPER_DEFAULT_PATH.
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
|
Device to load model on. Defaults to None.
|
||||||
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
|
Defaults to False.
|
||||||
|
args: Additional arguments only to avoid errors.
|
||||||
|
kwargs: Additional keyword arguments only to avoid errors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: abscract method.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
    """
    Get kwargs for whisper model. Ensure that kwargs are valid.

    This base-class version has an empty body: it ignores ``kwargs`` and
    falls through, so a direct call yields ``None`` rather than a dict.
    WhisperTranscriber and WhisperXTranscriber each define their own
    version of this method.

    Returns:
        dict: Keyword arguments for whisper model.
    """
    pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
    """Return a string showing the model name and the model object."""
    return f"Transcriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperTranscriber(Transcriber):
|
||||||
|
def __init__(self, model: whisper, model_name: str) -> None:
    """
    Initialise the transcriber by delegating to the base-class initialiser.

    Args:
        model (whisper): The Whisper model to use for transcription.
        model_name (str): The name of the model.
    """
    super().__init__(model, model_name)
|
||||||
|
|
||||||
|
def transcribe(self, audio: Union[str, Tensor, ndarray],
               *args, **kwargs) -> str:
    """
    Transcribe an audio input and return the recognised text.

    Args:
        audio (Union[str, Tensor, ndarray]): The audio to transcribe.
        *args: Positional arguments forwarded to the model's
            ``transcribe`` call.
        **kwargs: Keyword arguments; they are filtered through
            ``_get_whisper_kwargs`` before being forwarded.

    Returns:
        str: The ``"text"`` entry of the model's result.
    """
    options = self._get_whisper_kwargs(**kwargs)

    # A missing or falsy ``verbose`` is always forwarded as ``None``.
    options["verbose"] = options.get("verbose") or None

    return self.model.transcribe(audio, *args, **options)["text"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_model(cls,
|
||||||
|
model: str = "medium",
|
||||||
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
|
device: Optional[Union[str, device]] = None,
|
||||||
|
in_memory: bool = False,
|
||||||
|
*args, **kwargs
|
||||||
|
) -> 'WhisperTranscriber':
|
||||||
|
"""
|
||||||
|
Load whisper model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): Whisper model. Available models include:
|
||||||
|
- 'tiny.en'
|
||||||
|
- 'tiny'
|
||||||
|
- 'base.en'
|
||||||
|
- 'base'
|
||||||
|
- 'small.en'
|
||||||
|
- 'small'
|
||||||
|
- 'medium.en'
|
||||||
|
- 'medium'
|
||||||
|
- 'large-v1'
|
||||||
|
- 'large-v2'
|
||||||
|
- 'large-v3'
|
||||||
|
- 'large'
|
||||||
|
|
||||||
|
download_root (str, optional): Path to download the model.
|
||||||
|
Defaults to WHISPER_DEFAULT_PATH.
|
||||||
|
|
||||||
device (Optional[Union[str, torch.device]], optional):
|
device (Optional[Union[str, torch.device]], optional):
|
||||||
Device to load model on. Defaults to None.
|
Device to load model on. Defaults to None.
|
||||||
in_memory (bool, optional): Whether to load model in memory.
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
@@ -158,8 +241,8 @@ class Transcriber:
|
|||||||
Transcriber: A Transcriber object initialized with the specified model.
|
Transcriber: A Transcriber object initialized with the specified model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_model = load_model(model, download_root=download_root,
|
_model = whisper_load_model(model, download_root=download_root,
|
||||||
device=device, in_memory=in_memory)
|
device=device, in_memory=in_memory)
|
||||||
|
|
||||||
return cls(_model, model_name=model)
|
return cls(_model, model_name=model)
|
||||||
|
|
||||||
@@ -171,17 +254,177 @@ class Transcriber:
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Keyword arguments for whisper model.
|
dict: Keyword arguments for whisper model.
|
||||||
"""
|
"""
|
||||||
_possible_kwargs = Whisper.transcribe.__code__.co_varnames
|
# _possible_kwargs = WhisperModel.transcribe.__code__.co_varnames
|
||||||
|
_possible_kwargs = signature(Whisper.transcribe).parameters.keys()
|
||||||
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
|
|
||||||
|
whisper_kwargs = {k: v for k,
|
||||||
|
v in kwargs.items() if k in _possible_kwargs}
|
||||||
|
|
||||||
if (task := kwargs.get("task")):
|
if (task := kwargs.get("task")):
|
||||||
whisper_kwargs["task"] = task
|
whisper_kwargs["task"] = task
|
||||||
|
|
||||||
if (language := kwargs.get("language")):
|
if (language := kwargs.get("language")):
|
||||||
whisper_kwargs["language"] = language
|
whisper_kwargs["language"] = language
|
||||||
|
|
||||||
return whisper_kwargs
|
return whisper_kwargs
|
||||||
|
|
||||||
def __repr__(self) -> str:
    """Return a string showing the class label, model name and model object."""
    return f"WhisperTranscriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperXTranscriber(Transcriber):
    """
    Transcriber variant whose model is obtained through ``whisperx.load_model``.
    """

    def __init__(self, model: whisper, model_name: str) -> None:
        """
        Initialise the transcriber by delegating to the base-class initialiser.

        Args:
            model (whisper): The model object used for transcription.
            model_name (str): The name of the model.
        """
        super().__init__(model, model_name)

    def transcribe(self, audio: Union[str, Tensor, ndarray],
                   *args, **kwargs) -> str:
        """
        Transcribe an audio input and return the concatenated segment texts.

        Args:
            audio (Union[str, Tensor, ndarray]): The audio to transcribe.
                A Tensor is moved to the CPU and converted to a numpy array
                before it is handed to the model.
            *args: Positional arguments forwarded to the model's
                ``transcribe`` call.
            **kwargs: Keyword arguments; they are filtered through
                ``_get_whisper_kwargs`` before being forwarded.

        Returns:
            str: The ``'text'`` of every entry in the result's ``'segments'``,
            joined in order without a separator.
        """
        options = self._get_whisper_kwargs(**kwargs)

        samples = audio.cpu().numpy() if isinstance(audio, Tensor) else audio

        result = self.model.transcribe(samples, *args, **options)
        return "".join(segment['text'] for segment in result['segments'])

    @classmethod
    def load_model(cls,
                   model: str = "medium",
                   download_root: str = WHISPER_DEFAULT_PATH,
                   device: Optional[Union[str, device]] = None,
                   *args, **kwargs
                   ) -> 'WhisperXTranscriber':
        """
        Load a model through ``whisperx.load_model`` and wrap it.

        Args:
            model (str): Whisper model. Available models include:
                - 'tiny.en'
                - 'tiny'
                - 'base.en'
                - 'base'
                - 'small.en'
                - 'small'
                - 'medium.en'
                - 'medium'
                - 'large-v1'
                - 'large-v2'
                - 'large-v3'
                - 'large'
            download_root (str, optional): Path to download the model.
                Defaults to WHISPER_DEFAULT_PATH.
            device (Optional[Union[str, torch.device]], optional):
                Device to load the model on. When None, "cuda" is chosen if
                CUDA is available and "cpu" otherwise. Non-string values are
                converted with ``str()``.
            args: Accepted only to avoid errors; not used.
            kwargs: Only ``compute_type`` is read (default 'float16');
                every other keyword is ignored.

        Returns:
            WhisperXTranscriber: An instance wrapping the loaded model.
        """
        if device is None:
            device = "cuda" if cuda_is_available() else "cpu"
        elif not isinstance(device, str):
            # NOTE(review): an indexed device such as "cuda:0" is passed on
            # unchanged - confirm that the whisperx loader accepts that form.
            device = str(device)

        compute_type = kwargs.get('compute_type', 'float16')

        # float16 is swapped for int8 on the CPU, with a warning to the caller.
        if (device, compute_type) == ('cpu', 'float16'):
            warnings.warn(f'Compute type {compute_type} not compatible with '
                          f'device {device}! Changing compute type to int8.')
            compute_type = 'int8'

        loaded = whisperx_load_model(model, download_root=download_root,
                                     device=device, compute_type=compute_type)

        return cls(loaded, model_name=model)

    @staticmethod
    def _get_whisper_kwargs(**kwargs) -> dict:
        """
        Keep only the keyword arguments named in ``WhisperModel.transcribe``.

        Truthy ``task`` and ``language`` values are always kept, whether or
        not they appear in that signature.

        Returns:
            dict: Keyword arguments for the model's ``transcribe`` call.
        """
        # NOTE(review): the filter is built from WhisperModel.transcribe while
        # the actual call goes to self.model.transcribe - confirm both accept
        # the same parameter names.
        accepted = signature(WhisperModel.transcribe).parameters

        selected = {}
        for name, value in kwargs.items():
            if name in accepted:
                selected[name] = value

        for forced in ("task", "language"):
            value = kwargs.get(forced)
            if value:
                selected[forced] = value

        return selected

    def __repr__(self) -> str:
        """Return a string showing the class label, model name and model object."""
        return f"WhisperXTranscriber(model_name={self.model_name}, model={self.model})"
|
||||||
|
|
||||||
|
|
||||||
|
def load_transcriber(model: str = "medium",
                     whisper_type: str = 'whisper',
                     download_root: str = WHISPER_DEFAULT_PATH,
                     device: Optional[Union[str, device]] = None,
                     in_memory: bool = False,
                     *args, **kwargs
                     ) -> Union[WhisperTranscriber, WhisperXTranscriber]:
    """
    Load a whisper model wrapped in the matching Transcriber subclass.

    Args:
        model (str): Whisper model. Available models include:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large-v3'
            - 'large'
        whisper_type (str):
            Type of whisper model to load, "whisper" or "whisperx".
            Compared case-insensitively.
        download_root (str, optional): Path to download the model.
            Defaults to WHISPER_DEFAULT_PATH.
        device (Optional[Union[str, torch.device]], optional):
            Device to load model on. Defaults to None.
        in_memory (bool, optional): Whether to load model in memory.
            Only forwarded for the "whisper" type; the "whisperx" loader
            is called without it. Defaults to False.
        args: Additional positional arguments forwarded to the loader.
        kwargs: Additional keyword arguments forwarded to the loader.

    Returns:
        Union[WhisperTranscriber, WhisperXTranscriber]:
            One of the Whisper variants as Transcriber object initialized
            with the specified model.

    Raises:
        ValueError: If ``whisper_type`` is neither "whisper" nor "whisperx".
    """
    # Normalise once so both comparisons see the same lower-cased value.
    backend = whisper_type.lower()

    if backend == 'whisper':
        return WhisperTranscriber.load_model(
            model, download_root, device, in_memory, *args, **kwargs)

    if backend == 'whisperx':
        return WhisperXTranscriber.load_model(
            model, download_root, device, *args, **kwargs)

    raise ValueError(f'Model type not recognized, expected "whisper" '
                     f'or "whisperx", got {whisper_type}.')
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from json.decoder import JSONDecodeError
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -8,13 +9,12 @@ from .hallucinations import KNOWN_HALLUCINATIONS
|
|||||||
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
|
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Transcript:
|
class Transcript:
|
||||||
"""
|
"""
|
||||||
Class for storing transcript data, including speaker information and text segments,
|
Class for storing transcript data, including speaker information and text segments,
|
||||||
and exporting it to various file formats such as JSON, HTML, and LaTeX.
|
and exporting it to various file formats such as JSON, HTML, and LaTeX.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, transcript: dict) -> None:
|
def __init__(self, transcript: dict) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the Transcript object with the given transcript data.
|
Initializes the Transcript object with the given transcript data.
|
||||||
@@ -30,7 +30,7 @@ class Transcript:
|
|||||||
self.speakers = self._extract_speakers()
|
self.speakers = self._extract_speakers()
|
||||||
self.segments = self._extract_segments()
|
self.segments = self._extract_segments()
|
||||||
self.annotation = {}
|
self.annotation = {}
|
||||||
|
|
||||||
def annotate(self, *args, **kwargs) -> dict:
|
def annotate(self, *args, **kwargs) -> dict:
|
||||||
"""
|
"""
|
||||||
Annotates the transcript to associate specific names with speakers.
|
Annotates the transcript to associate specific names with speakers.
|
||||||
@@ -46,36 +46,41 @@ class Transcript:
|
|||||||
ValueError: If the number of speaker names does not match the number
|
ValueError: If the number of speaker names does not match the number
|
||||||
of speakers, or if an unknown speaker is found.
|
of speakers, or if an unknown speaker is found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
annotations = {}
|
annotations = {}
|
||||||
if args and len(args) != len(self.speakers):
|
if args and len(args) != len(self.speakers):
|
||||||
raise ValueError("Number of speaker names does not match number of speakers")
|
raise ValueError(
|
||||||
|
"Number of speaker names does not match number of speakers")
|
||||||
|
|
||||||
if args:
|
if args:
|
||||||
for arg, speaker in zip(args, sorted(self.speakers)):
|
for arg, speaker in zip(args, sorted(self.speakers)):
|
||||||
|
|
||||||
annotations[speaker] = arg
|
annotations[speaker] = arg
|
||||||
|
|
||||||
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
|
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
|
||||||
if invalid_speakers:
|
if invalid_speakers:
|
||||||
raise ValueError(f"These keys are not speakers: {', '.join(invalid_speakers)}")
|
raise ValueError(
|
||||||
|
f"These keys are not speakers: {', '.join(invalid_speakers)}")
|
||||||
|
|
||||||
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
|
annotations.update({key: kwargs[key]
|
||||||
|
for key in self.speakers if key in kwargs})
|
||||||
|
|
||||||
self.annotation = annotations
|
self.annotation = annotations
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _remove_hallucinations(self) -> None:
|
def _remove_hallucinations(self) -> None:
|
||||||
"""
|
"""
|
||||||
Removes all occurrences of known hallucinations from all segments of the transcript.
|
Removes all occurrences of known hallucinations from all segments of the transcript.
|
||||||
Segments that are identical to empty strings afterwards are removed from the transcript.
|
Segments that are identical to empty strings afterwards are removed from the transcript.
|
||||||
"""
|
"""
|
||||||
segments_to_drop=[]
|
segments_to_drop = []
|
||||||
for id in self.transcript:
|
for id in self.transcript:
|
||||||
for snippet in KNOWN_HALLUCINATIONS:
|
for snippet in KNOWN_HALLUCINATIONS:
|
||||||
self.transcript[id]['text']=self.transcript[id]['text'].replace(snippet,'')
|
self.transcript[id]['text'] = self.transcript[id]['text'].replace(
|
||||||
if self.transcript[id]['text'] == '': segments_to_drop.append(id)
|
snippet, '')
|
||||||
|
if self.transcript[id]['text'] == '':
|
||||||
|
segments_to_drop.append(id)
|
||||||
|
|
||||||
for id in segments_to_drop:
|
for id in segments_to_drop:
|
||||||
del self.transcript[id]
|
del self.transcript[id]
|
||||||
@@ -87,9 +92,9 @@ class Transcript:
|
|||||||
Returns:
|
Returns:
|
||||||
list: List of unique speaker names in the transcript.
|
list: List of unique speaker names in the transcript.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
|
return list(set([self.transcript[id]["speakers"] for id in self.transcript]))
|
||||||
|
|
||||||
def _extract_segments(self) -> list:
|
def _extract_segments(self) -> list:
|
||||||
"""
|
"""
|
||||||
Extracts all the text segments from the transcript.
|
Extracts all the text segments from the transcript.
|
||||||
@@ -109,23 +114,23 @@ class Transcript:
|
|||||||
time stamps for each segment.
|
time stamps for each segment.
|
||||||
"""
|
"""
|
||||||
fstring = ""
|
fstring = ""
|
||||||
|
|
||||||
for _id in self.transcript:
|
for _id in self.transcript:
|
||||||
seq = self.transcript[_id]
|
seq = self.transcript[_id]
|
||||||
|
|
||||||
if self.annotation:
|
if self.annotation:
|
||||||
speaker = self.annotation[seq["speakers"]]
|
speaker = self.annotation[seq["speakers"]]
|
||||||
else:
|
else:
|
||||||
speaker = seq["speakers"]
|
speaker = seq["speakers"]
|
||||||
|
|
||||||
segm = seq["segments"]
|
segm = seq["segments"]
|
||||||
sseg = time.strftime("%H:%M:%S",time.gmtime(segm[0]))
|
sseg = time.strftime("%H:%M:%S", time.gmtime(segm[0]))
|
||||||
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
|
eseg = time.strftime("%H:%M:%S", time.gmtime(segm[1]))
|
||||||
|
|
||||||
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
|
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
|
||||||
|
|
||||||
return fstring
|
return fstring
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Return a string representation of the Transcript object.
|
"""Return a string representation of the Transcript object.
|
||||||
|
|
||||||
@@ -133,8 +138,8 @@ class Transcript:
|
|||||||
str: A string that provides an informative description of the object.
|
str: A string that provides an informative description of the object.
|
||||||
"""
|
"""
|
||||||
return f"Transcript(speakers = {self.speakers},"\
|
return f"Transcript(speakers = {self.speakers},"\
|
||||||
f"segments = {self.segments}, annotation = {self.annotation})"
|
f"segments = {self.segments}, annotation = {self.annotation})"
|
||||||
|
|
||||||
def get_dict(self) -> dict:
|
def get_dict(self) -> dict:
|
||||||
"""
|
"""
|
||||||
Get transcript as dict
|
Get transcript as dict
|
||||||
@@ -142,10 +147,10 @@ class Transcript:
|
|||||||
:return: transcript as dict
|
:return: transcript as dict
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.transcript
|
return self.transcript
|
||||||
|
|
||||||
def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str:
|
def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Get transcript as json string
|
Get transcript as json string
|
||||||
:return: transcript as json string
|
:return: transcript as json string
|
||||||
@@ -153,14 +158,14 @@ class Transcript:
|
|||||||
"""
|
"""
|
||||||
if "indent" not in kwargs:
|
if "indent" not in kwargs:
|
||||||
kwargs["indent"] = 3
|
kwargs["indent"] = 3
|
||||||
|
|
||||||
if use_annotation and self.annotation:
|
if use_annotation and self.annotation:
|
||||||
for _id in self.transcript:
|
for _id in self.transcript:
|
||||||
seq = self.transcript[_id]
|
seq = self.transcript[_id]
|
||||||
seq["speakers"] = self.annotation[seq["speakers"]]
|
seq["speakers"] = self.annotation[seq["speakers"]]
|
||||||
|
|
||||||
return json.dumps(self.transcript, *args, **kwargs)
|
return json.dumps(self.transcript, *args, **kwargs)
|
||||||
|
|
||||||
def get_html(self) -> str:
|
def get_html(self) -> str:
|
||||||
"""
|
"""
|
||||||
Get transcript as html string
|
Get transcript as html string
|
||||||
@@ -171,9 +176,9 @@ class Transcript:
|
|||||||
html = "<p>" + self.__str__().replace("\n", "<br>") + "</p>"
|
html = "<p>" + self.__str__().replace("\n", "<br>") + "</p>"
|
||||||
html = "<html><body>" + html + "</body></html>"
|
html = "<html><body>" + html + "</body></html>"
|
||||||
html = html.replace("\t", " ")
|
html = html.replace("\t", " ")
|
||||||
|
|
||||||
return html
|
return html
|
||||||
|
|
||||||
def get_md(self) -> str:
|
def get_md(self) -> str:
|
||||||
"""Get transcript as Markdown string, using HTML formatting.
|
"""Get transcript as Markdown string, using HTML formatting.
|
||||||
|
|
||||||
@@ -181,7 +186,7 @@ class Transcript:
|
|||||||
str: Transcript as a Markdown string.
|
str: Transcript as a Markdown string.
|
||||||
"""
|
"""
|
||||||
return self.get_html()
|
return self.get_html()
|
||||||
|
|
||||||
def get_tex(self) -> str:
|
def get_tex(self) -> str:
|
||||||
"""Get transcript as LaTeX string. If no annotations are present, the speakers will
|
"""Get transcript as LaTeX string. If no annotations are present, the speakers will
|
||||||
be annotated with the first letters of the alphabet.
|
be annotated with the first letters of the alphabet.
|
||||||
@@ -192,43 +197,42 @@ class Transcript:
|
|||||||
if not self.annotation:
|
if not self.annotation:
|
||||||
|
|
||||||
self.annotate(*ALPHABET[:len(self.speakers)])
|
self.annotate(*ALPHABET[:len(self.speakers)])
|
||||||
|
|
||||||
fstring ="\\begin{drama}"
|
fstring = "\\begin{drama}"
|
||||||
|
|
||||||
for speaker in self.speakers:
|
for speaker in self.speakers:
|
||||||
|
|
||||||
fstring += "\n\t\\Character{"+ str(self.annotation[speaker]) + "}" \
|
fstring += "\n\t\\Character{" + str(self.annotation[speaker]) + "}" \
|
||||||
"{"+ str(self.annotation[speaker]) + "}"
|
"{" + str(self.annotation[speaker]) + "}"
|
||||||
|
|
||||||
for id in self.transcript:
|
for id in self.transcript:
|
||||||
seq = self.transcript[id]
|
seq = self.transcript[id]
|
||||||
speaker = self.annotation[seq["speakers"]]
|
speaker = self.annotation[seq["speakers"]]
|
||||||
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
|
fstring += f"\n\\{speaker}speaks:\n{seq['text']}"
|
||||||
|
|
||||||
fstring += "\n\\end{drama}"
|
fstring += "\n\\end{drama}"
|
||||||
|
|
||||||
return fstring
|
return fstring
|
||||||
|
|
||||||
|
def to_json(self, path, *args, **kwargs) -> None:
|
||||||
def to_json(self,path, *args, **kwargs) -> None:
|
|
||||||
"""Save transcript as json file
|
"""Save transcript as json file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): path to save file
|
path (str): path to save file
|
||||||
"""
|
"""
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
json.dump(self.transcript, f, *args, **kwargs)
|
json.dump(self.transcript, f, *args, **kwargs)
|
||||||
|
|
||||||
def to_txt(self, path: str) -> None:
|
def to_txt(self, path: str) -> None:
|
||||||
"""Save transcript as a plain-text file.
|
"""Save transcript as a plain-text file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): Path to save the text file.
|
path (str): Path to save the text file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(self.__str__())
|
f.write(self.__str__())
|
||||||
|
|
||||||
def to_md(self, path: str) -> None:
|
def to_md(self, path: str) -> None:
|
||||||
"""Get transcript as Markdown string, using HTML formatting.
|
"""Get transcript as Markdown string, using HTML formatting.
|
||||||
|
|
||||||
@@ -236,7 +240,7 @@ class Transcript:
|
|||||||
str: Transcript as a Markdown string.
|
str: Transcript as a Markdown string.
|
||||||
"""
|
"""
|
||||||
return self.to_html(path)
|
return self.to_html(path)
|
||||||
|
|
||||||
def to_html(self, path: str) -> None:
|
def to_html(self, path: str) -> None:
|
||||||
"""
|
"""
|
||||||
Save transcript as html file
|
Save transcript as html file
|
||||||
@@ -244,10 +248,10 @@ class Transcript:
|
|||||||
:param path: path to save file
|
:param path: path to save file
|
||||||
:type path: str
|
:type path: str
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with open(path, "w") as file:
|
with open(path, "w") as file:
|
||||||
file.write(self.get_html())
|
file.write(self.get_html())
|
||||||
|
|
||||||
def to_tex(self, path: str) -> None:
|
def to_tex(self, path: str) -> None:
|
||||||
"""Save transcript as a LaTeX file (placeholder function, implementation needed).
|
"""Save transcript as a LaTeX file (placeholder function, implementation needed).
|
||||||
|
|
||||||
@@ -255,7 +259,7 @@ class Transcript:
|
|||||||
path (str): Path to save the LaTeX file.
|
path (str): Path to save the LaTeX file.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def to_pdf(self, path: str) -> None:
|
def to_pdf(self, path: str) -> None:
|
||||||
"""Save transcript as a PDF file (placeholder function, implementation needed).
|
"""Save transcript as a PDF file (placeholder function, implementation needed).
|
||||||
|
|
||||||
@@ -263,7 +267,7 @@ class Transcript:
|
|||||||
path (str): Path to save the PDF file.
|
path (str): Path to save the PDF file.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save(self, path: str, *args, **kwargs) -> None:
|
def save(self, path: str, *args, **kwargs) -> None:
|
||||||
"""Save transcript to file with the given path and file format.
|
"""Save transcript to file with the given path and file format.
|
||||||
|
|
||||||
@@ -279,7 +283,7 @@ class Transcript:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the file format specified in the path is unknown.
|
ValueError: If the file format specified in the path is unknown.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if path.endswith(".json"):
|
if path.endswith(".json"):
|
||||||
self.to_json(path, *args, **kwargs)
|
self.to_json(path, *args, **kwargs)
|
||||||
elif path.endswith(".txt"):
|
elif path.endswith(".txt"):
|
||||||
@@ -294,7 +298,7 @@ class Transcript:
|
|||||||
self.to_pdf(path, *args, **kwargs)
|
self.to_pdf(path, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown file format")
|
raise ValueError("Unknown file format")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, json: Union[dict, str]) -> "Transcript":
|
def from_json(cls, json: Union[dict, str]) -> "Transcript":
|
||||||
"""Load transcript from json file
|
"""Load transcript from json file
|
||||||
@@ -310,10 +314,8 @@ class Transcript:
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
transcript = json.loads(json)
|
transcript = json.loads(json)
|
||||||
except:
|
except (TypeError, JSONDecodeError):
|
||||||
with open(json, "r") as f:
|
with open(json, "r") as f:
|
||||||
transcript = json.load(f)
|
transcript = json.load(f)
|
||||||
|
|
||||||
return cls(transcript)
|
|
||||||
|
|
||||||
|
return cls(transcript)
|
||||||
|
|||||||
+6
-4
@@ -10,6 +10,8 @@ VERSION = '%d.%d.%d.%d' % (MAJOR, MINOR, MICRO, NANO)
|
|||||||
|
|
||||||
# Return the git revision as a string
|
# Return the git revision as a string
|
||||||
# taken from numpy/numpy
|
# taken from numpy/numpy
|
||||||
|
|
||||||
|
|
||||||
def git_version():
|
def git_version():
|
||||||
def _minimal_ext_cmd(cmd):
|
def _minimal_ext_cmd(cmd):
|
||||||
# construct minimal environment
|
# construct minimal environment
|
||||||
@@ -24,7 +26,8 @@ def git_version():
|
|||||||
env['LANG'] = 'C'
|
env['LANG'] = 'C'
|
||||||
env['LC_ALL'] = 'C'
|
env['LC_ALL'] = 'C'
|
||||||
|
|
||||||
out = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE, env=env).communicate()[0]
|
out = sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE,
|
||||||
|
env=env).communicate()[0]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -35,6 +38,7 @@ def git_version():
|
|||||||
|
|
||||||
return GIT_REVISION
|
return GIT_REVISION
|
||||||
|
|
||||||
|
|
||||||
def _get_git_version():
|
def _get_git_version():
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
|
|
||||||
@@ -51,6 +55,7 @@ def _get_git_version():
|
|||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def get_version(build_version=False):
|
def get_version(build_version=False):
|
||||||
if ISRELEASED:
|
if ISRELEASED:
|
||||||
return VERSION
|
return VERSION
|
||||||
@@ -64,6 +69,3 @@ def get_version(build_version=False):
|
|||||||
return VERSION + ".dev" + date
|
return VERSION + ".dev" + date
|
||||||
else:
|
else:
|
||||||
return VERSION + ".dev0+" + GIT_REVISION[:7]
|
return VERSION + ".dev0+" + GIT_REVISION[:7]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+10
-10
@@ -31,16 +31,16 @@ release = '0.1.1'
|
|||||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
# ones.
|
# ones.
|
||||||
extensions = ['sphinx.ext.autodoc',
|
extensions = ['sphinx.ext.autodoc',
|
||||||
'sphinx.ext.doctest',
|
'sphinx.ext.doctest',
|
||||||
'sphinx.ext.intersphinx',
|
'sphinx.ext.intersphinx',
|
||||||
'sphinx.ext.todo',
|
'sphinx.ext.todo',
|
||||||
'sphinx.ext.coverage',
|
'sphinx.ext.coverage',
|
||||||
'sphinx.ext.mathjax',
|
'sphinx.ext.mathjax',
|
||||||
'sphinx.ext.ifconfig',
|
'sphinx.ext.ifconfig',
|
||||||
'sphinx.ext.viewcode',
|
'sphinx.ext.viewcode',
|
||||||
'sphinx.ext.githubpages',
|
'sphinx.ext.githubpages',
|
||||||
'sphinx.ext.napoleon',
|
'sphinx.ext.napoleon',
|
||||||
'myst_parser']
|
'myst_parser']
|
||||||
|
|
||||||
# Napoleon settings
|
# Napoleon settings
|
||||||
napoleon_google_docstring = True
|
napoleon_google_docstring = True
|
||||||
|
|||||||
+15
-46
@@ -3,7 +3,6 @@ from scraibe.audio import AudioProcessor
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
|
TEST_WAVEFORM = torch.sin(torch.randn(160000)).to(DEVICE)
|
||||||
TEST_SR = 16000
|
TEST_SR = 16000
|
||||||
@@ -14,21 +13,17 @@ NORMALIZATION_FACTOR = 32768
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def probe_audio_processor():
|
def probe_audio_processor():
|
||||||
"""Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate.
|
"""Fixture for creating an instance of the AudioProcessor class with test waveform and sample rate.
|
||||||
|
|
||||||
This fixture is used to create an instance of the AudioProcessor class with a predefined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor, which can be used as a
|
This fixture is used to create an instance of the AudioProcessor class with a predefined test waveform and sample rate (TEST_SR). It returns the instantiated AudioProcessor, which can be used as a
|
||||||
dependency in other test functions.
|
dependency in other test functions.
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate.
|
AudioProcessor (obj): An instance of the AudioProcessor class with the test waveform and sample rate.
|
||||||
"""
|
"""
|
||||||
return AudioProcessor(TEST_WAVEFORM, TEST_SR)
|
return AudioProcessor(TEST_WAVEFORM, TEST_SR)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_AudioProcessor_init(probe_audio_processor):
|
def test_AudioProcessor_init(probe_audio_processor):
|
||||||
"""
|
"""
|
||||||
Test the initialization of the AudioProcessor class.
|
Test the initialization of the AudioProcessor class.
|
||||||
@@ -43,20 +38,19 @@ def test_AudioProcessor_init(probe_audio_processor):
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
"""
|
||||||
assert isinstance(probe_audio_processor, AudioProcessor)
|
assert isinstance(probe_audio_processor, AudioProcessor)
|
||||||
assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device
|
assert probe_audio_processor.waveform.device == TEST_WAVEFORM.device
|
||||||
assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM)
|
assert torch.equal(probe_audio_processor.waveform, TEST_WAVEFORM)
|
||||||
assert probe_audio_processor.sr == TEST_SR
|
assert probe_audio_processor.sr == TEST_SR
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_cut(probe_audio_processor):
|
def test_cut(probe_audio_processor):
|
||||||
"""Test the cut function of the AudioProcessor class.
|
"""Test the cut function of the AudioProcessor class.
|
||||||
|
|
||||||
This test verifies that the cut function correctly extracts a segment of audio data from
|
This test verifies that the cut function correctly extracts a segment of audio data from
|
||||||
the waveform, given start and end indices. It checks whether the size of the extracted segment matches
|
the waveform, given start and end indices. It checks whether the size of the extracted segment matches
|
||||||
the expected size based on the provided start and end indices and the sample rate.
|
the expected size based on the provided start and end indices and the sample rate.
|
||||||
@@ -65,63 +59,38 @@ def test_cut(probe_audio_processor):
|
|||||||
None
|
None
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start = 4
|
start = 4
|
||||||
end = 7
|
end = 7
|
||||||
trimmed_waveform = probe_audio_processor.cut(start, end)
|
trimmed_waveform = probe_audio_processor.cut(start, end)
|
||||||
expected_size = int((end - start) * TEST_SR)
|
expected_size = int((end - start) * TEST_SR)
|
||||||
real_size = trimmed_waveform.size(0)
|
real_size = trimmed_waveform.size(0)
|
||||||
assert real_size == expected_size
|
assert real_size == expected_size
|
||||||
#assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
|
# assert AudioProcessor(TEST_WAVEFORM, TEST_SR).cut(start, end).size() == int((end - start) * TEST_SR)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_audio_processor_invalid_sr():
|
def test_audio_processor_invalid_sr():
|
||||||
"""Test the behavior of AudioProcessor when an invalid sample rate is provided.
|
"""Test the behavior of AudioProcessor when an invalid sample rate is provided.
|
||||||
|
|
||||||
This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an
|
This test verifies that the AudioProcessor constructor raises a ValueError when an invalid sample rate is provided. It uses the pytest.raises context manager to check if the ValueError is raised when initializing an
|
||||||
AudioProcessor object with an invalid sample rate.
|
AudioProcessor object with an invalid sample rate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
AudioProcessor(TEST_WAVEFORM, [44100,48000])
|
AudioProcessor(TEST_WAVEFORM, [44100, 48000])
|
||||||
|
|
||||||
|
|
||||||
def test_audio_processor_SAMPLE_RATE():
|
def test_audio_processor_SAMPLE_RATE():
|
||||||
"""Test the default sample rate of the AudioProcessor class.
|
"""Test the default sample rate of the AudioProcessor class.
|
||||||
|
|
||||||
This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform
|
This test verifies that the default sample rate of the AudioProcessor class matches the expected value defined by the constant SAMPLE_RATE. It instantiates an AudioProcessor object with a test waveform
|
||||||
and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE.
|
and checks whether the sample rate attribute (sr) of the AudioProcessor object equals the predefined constant SAMPLE_RATE.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
|
probe_audio_processor = AudioProcessor(TEST_WAVEFORM)
|
||||||
assert probe_audio_processor.sr == SAMPLE_RATE
|
assert probe_audio_processor.sr == SAMPLE_RATE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +1,14 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from scraibe import Scraibe, Diariser, Transcriber, Transcript
|
from scraibe import Scraibe, Diariser, Transcriber, Transcript
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def create_scraibe_instance():
|
def create_scraibe_instance():
|
||||||
if "HF_TOKEN" in os.environ:
|
if "HF_TOKEN" in os.environ:
|
||||||
return Scraibe(use_auth_token=os.environ["HF_TOKEN"] )
|
return Scraibe(use_auth_token=os.environ["HF_TOKEN"])
|
||||||
else:
|
else:
|
||||||
return Scraibe()
|
return Scraibe()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_scraibe_init(create_scraibe_instance):
|
def test_scraibe_init(create_scraibe_instance):
|
||||||
@@ -47,7 +41,7 @@ def test_scraibe_transcribe(create_scraibe_instance):
|
|||||||
model.remove_audio_file("non_existing_audio_file")
|
model.remove_audio_file("non_existing_audio_file")
|
||||||
|
|
||||||
model.remove_audio_file("audio_test_2.mp4")
|
model.remove_audio_file("audio_test_2.mp4")
|
||||||
assert not os.path.exists("audio_test_2.mp4") """
|
assert not os.path.exists("audio_test_2.mp4") """
|
||||||
|
|
||||||
|
|
||||||
""" def test_get_audio_file(create_scraibe_instance):
|
""" def test_get_audio_file(create_scraibe_instance):
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import os
|
from scraibe import Diariser
|
||||||
from unittest import mock
|
|
||||||
from scraibe import diarisation, Diariser
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -15,11 +12,10 @@ def diariser_instance():
|
|||||||
Returns:
|
Returns:
|
||||||
Diariser(Obj): An instance of the Diariser class with a mocked token.
|
Diariser(Obj): An instance of the Diariser class with a mocked token.
|
||||||
"""
|
"""
|
||||||
#with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
|
# with mock.patch.object(Diariser, '_get_token', return_value = 'HF_TOKEN' ):
|
||||||
return Diariser('pyannote')
|
return Diariser('pyannote')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_Diariser_init(diariser_instance):
|
def test_Diariser_init(diariser_instance):
|
||||||
"""Test the initialization of the Diariser class.
|
"""Test the initialization of the Diariser class.
|
||||||
|
|
||||||
@@ -30,18 +26,7 @@ def test_Diariser_init(diariser_instance):
|
|||||||
Args:
|
Args:
|
||||||
diariser_instance (obj): instance of the Diariser class
|
diariser_instance (obj): instance of the Diariser class
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
assert diariser_instance.model == 'pyannote'
|
assert diariser_instance.model == 'pyannote'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+46
-18
@@ -1,27 +1,26 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch
|
from scraibe import (Transcriber, WhisperTranscriber,
|
||||||
from scraibe import Transcriber
|
WhisperXTranscriber, load_transcriber)
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
TEST_WAVEFORM = "Hello World"
|
TEST_WAVEFORM = "Hello World"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] )
|
@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] )
|
||||||
@patch("scraibe.Transcriber.load_model")
|
@patch("scraibe.Transcriber.load_model")
|
||||||
|
|
||||||
def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mock_load_model (_type_): _description_
|
mock_load_model (_type_): _description_
|
||||||
audio_file (_type_): _description_
|
audio_file (_type_): _description_
|
||||||
expected_transcription (_type_): _description_
|
expected_transcription (_type_): _description_
|
||||||
|
|
||||||
mock_model = mock_load_model.return_value
|
mock_model = mock_load_model.return_value
|
||||||
mock_model.transcribe.return_value ={"text": expected_transcription}
|
mock_model.transcribe.return_value ={"text": expected_transcription}
|
||||||
|
|
||||||
transcriber = Transcriber.load_model(model="medium")
|
transcriber = Transcriber.load_model(model="medium")
|
||||||
|
|
||||||
@@ -29,24 +28,53 @@ def test_transcriber(mock_load_model, audio_file, expected_transcription):
|
|||||||
|
|
||||||
assert transcription_result == expected_transcription """
|
assert transcription_result == expected_transcription """
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def transcriber_instance():
|
|
||||||
return Transcriber.load_model('medium')
|
|
||||||
|
|
||||||
def test_transcriber_initialization(transcriber_instance):
|
@pytest.fixture
|
||||||
assert isinstance(transcriber_instance, Transcriber)
|
def whisper_instance():
|
||||||
|
return load_transcriber('medium', whisper_type='whisper')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def whisperx_instance():
|
||||||
|
return load_transcriber('medium', whisper_type='whisperx')
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisper_base_initialization(whisper_instance):
|
||||||
|
assert isinstance(whisper_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisperx_base_initialization(whisperx_instance):
|
||||||
|
assert isinstance(whisperx_instance, Transcriber)
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisper_transcriber_initialization(whisper_instance):
|
||||||
|
assert isinstance(whisper_instance, WhisperTranscriber)
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisperx_transcriber_initialization(whisperx_instance):
|
||||||
|
assert isinstance(whisperx_instance, WhisperXTranscriber)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_transcriber_initialization():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
load_transcriber('medium', whisper_type='wrong_whisper')
|
||||||
|
|
||||||
|
|
||||||
def test_get_whisper_kwargs():
|
def test_get_whisper_kwargs():
|
||||||
kwargs = {"arg1": 1, "arg3": 3}
|
kwargs = {"arg1": 1, "arg3": 3}
|
||||||
valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs)
|
valid_kwargs = Transcriber._get_whisper_kwargs(**kwargs)
|
||||||
assert not valid_kwargs == {"arg1": 1, "arg3": 3}
|
assert not valid_kwargs == {"arg1": 1, "arg3": 3}
|
||||||
|
|
||||||
|
|
||||||
def test_transcribe(transcriber_instance):
|
def test_whisper_transcribe(whisper_instance):
|
||||||
model = transcriber_instance
|
model = whisper_instance
|
||||||
#mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
||||||
transcript = model.transcribe('test/audio_test_2.mp4')
|
transcript = model.transcribe('test/audio_test_2.mp4')
|
||||||
assert isinstance(transcript, str)
|
assert isinstance(transcript, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_whisperx_transcribe(whisperx_instance):
|
||||||
|
model = whisperx_instance
|
||||||
|
# mocker.patch.object(transcriber_instance.model, 'transcribe', return_value={'Hello, World !'} )
|
||||||
|
transcript = model.transcribe('test/audio_test_2.mp4')
|
||||||
|
assert isinstance(transcript, str)
|
||||||
|
|||||||
Reference in New Issue
Block a user