diff --git a/autotranscript/autotranscript.py b/autotranscript/autotranscript.py index 26563de..d27dba8 100644 --- a/autotranscript/autotranscript.py +++ b/autotranscript/autotranscript.py @@ -26,7 +26,6 @@ Usage: # Standard Library Imports import os from glob import iglob -import re from subprocess import run from typing import TypeVar, Union from warnings import warn @@ -42,6 +41,7 @@ from .diarisation import Diariser from .transcriber import Transcriber, whisper from .transcript_exporter import Transcript + DiarisationType = TypeVar('DiarisationType') @@ -77,15 +77,16 @@ class AutoTranscribe: and pyannote diarization models. """ + if whisper_model is None: - self.transcriber = Transcriber.load_model("medium") + self.transcriber = Transcriber.load_model("medium", **kwargs) elif isinstance(whisper_model, str): self.transcriber = Transcriber.load_model(whisper_model, **kwargs) else: self.transcriber = whisper_model if dia_model is None: - self.diariser = Diariser.load_model() + self.diariser = Diariser.load_model(**kwargs) elif isinstance(dia_model, str): self.diariser = Diariser.load_model(dia_model, **kwargs) else: @@ -125,16 +126,17 @@ class AutoTranscribe: diarisation = self.diariser.diarization(dia_audio, **kwargs) + if not diarisation["segments"]: - warn("No segments found. Try to run transcription without diarisation.") + print("No segments found. Try to run transcription without diarisation.") + transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) - final_transcript= {"speakers" : ["speaker01"], + final_transcript= {0 : {"speakers" : 'SPEAKER_01', "segments" : [0, len(audio_file.waveform)], - "text" : transcript} + "text" : transcript}} return Transcript(final_transcript) - print("Diarisation finished. Starting transcription.") @@ -143,6 +145,8 @@ class AutoTranscribe: # Transcribe each segment and store the results final_transcript = dict() + + for i in trange(len(diarisation["segments"]), desc= "Transcribing"): seg = diarisation["segments"][i] @@ -276,4 +280,7 @@ class AutoTranscribe: if not isinstance(audio_file, AudioProcessor): raise ValueError(f'Audiofile must be of type AudioProcessor,' \ f'not {type(audio_file)}') - return audio_file \ No newline at end of file + return audio_file + + def __repr__(self): + return f"AutoTranscribe(transcriber={self.transcriber}, diariser={self.diariser})" diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index 5cf60ce..44964e0 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -177,10 +177,11 @@ class Diariser: @classmethod def load_model(cls, model: str = PYANNOTE_DEFAULT_CONFIG, - token: str = None, + use_auth_token: str = None, cache_token: bool = False, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, - hparams_file: Union[str, Path] = None + hparams_file: Union[str, Path] = None, + *args, **kwargs ) -> Pipeline: """ @@ -194,20 +195,22 @@ class Diariser: cache_token: Whether to cache the token locally for future use. cache_dir: Directory for caching models. hparams_file: Path to a YAML file containing hyperparameters. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. Returns: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ - if cache_token and token is not None: - cls._save_token(token) + if cache_token and use_auth_token is not None: + cls._save_token(use_auth_token) - if not os.path.exists(model) and token is None: - token = cls._get_token() + if not os.path.exists(model) and use_auth_token is None: + use_auth_token = cls._get_token() model = 'pyannote/speaker-diarization' _model = Pipeline.from_pretrained(model, - use_auth_token = token, + use_auth_token = use_auth_token, cache_dir = cache_dir, hparams_file = hparams_file,) diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index e319372..63174a4 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -120,6 +120,7 @@ class Transcriber: download_root: str = WHISPER_DEFAULT_PATH, device: Optional[Union[str, device]] = None, in_memory: bool = False, + *args, **kwargs ) -> 'Transcriber': """ Load whisper model. @@ -145,6 +146,8 @@ class Transcriber: Device to load model on. Defaults to None. in_memory (bool, optional): Whether to load model in memory. Defaults to False. + args: Additional arguments only to avoid errors. + kwargs: Additional keyword arguments only to avoid errors. Returns: Transcriber: A Transcriber object initialized with the specified model.