added kwargs to load model functionts to avoid errors

This commit is contained in:
Jaikinator
2023-09-01 12:57:09 +02:00
parent 66b5c8fd3e
commit edb666a98c
3 changed files with 28 additions and 15 deletions
+14 -7
View File
@@ -26,7 +26,6 @@ Usage:
# Standard Library Imports # Standard Library Imports
import os import os
from glob import iglob from glob import iglob
import re
from subprocess import run from subprocess import run
from typing import TypeVar, Union from typing import TypeVar, Union
from warnings import warn from warnings import warn
@@ -42,6 +41,7 @@ from .diarisation import Diariser
from .transcriber import Transcriber, whisper from .transcriber import Transcriber, whisper
from .transcript_exporter import Transcript from .transcript_exporter import Transcript
DiarisationType = TypeVar('DiarisationType') DiarisationType = TypeVar('DiarisationType')
@@ -77,15 +77,16 @@ class AutoTranscribe:
and pyannote diarization models. and pyannote diarization models.
""" """
if whisper_model is None: if whisper_model is None:
self.transcriber = Transcriber.load_model("medium") self.transcriber = Transcriber.load_model("medium", **kwargs)
elif isinstance(whisper_model, str): elif isinstance(whisper_model, str):
self.transcriber = Transcriber.load_model(whisper_model, **kwargs) self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
else: else:
self.transcriber = whisper_model self.transcriber = whisper_model
if dia_model is None: if dia_model is None:
self.diariser = Diariser.load_model() self.diariser = Diariser.load_model(**kwargs)
elif isinstance(dia_model, str): elif isinstance(dia_model, str):
self.diariser = Diariser.load_model(dia_model, **kwargs) self.diariser = Diariser.load_model(dia_model, **kwargs)
else: else:
@@ -125,17 +126,18 @@ class AutoTranscribe:
diarisation = self.diariser.diarization(dia_audio, **kwargs) diarisation = self.diariser.diarization(dia_audio, **kwargs)
if not diarisation["segments"]: if not diarisation["segments"]:
warn("No segments found. Try to run transcription without diarisation.") print("No segments found. Try to run transcription without diarisation.")
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs) transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
final_transcript= {"speakers" : ["speaker01"], final_transcript= {0 : {"speakers" : 'SPEAKER_01',
"segments" : [0, len(audio_file.waveform)], "segments" : [0, len(audio_file.waveform)],
"text" : transcript} "text" : transcript}}
return Transcript(final_transcript) return Transcript(final_transcript)
print("Diarisation finished. Starting transcription.") print("Diarisation finished. Starting transcription.")
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device) audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
@@ -143,6 +145,8 @@ class AutoTranscribe:
# Transcribe each segment and store the results # Transcribe each segment and store the results
final_transcript = dict() final_transcript = dict()
for i in trange(len(diarisation["segments"]), desc= "Transcribing"): for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
seg = diarisation["segments"][i] seg = diarisation["segments"][i]
@@ -277,3 +281,6 @@ class AutoTranscribe:
raise ValueError(f'Audiofile must be of type AudioProcessor,' \ raise ValueError(f'Audiofile must be of type AudioProcessor,' \
f'not {type(audio_file)}') f'not {type(audio_file)}')
return audio_file return audio_file
def __repr__(self):
return f"AutoTranscribe(transcriber={self.transcriber}, diariser={self.diariser})"
+10 -7
View File
@@ -177,10 +177,11 @@ class Diariser:
@classmethod @classmethod
def load_model(cls, def load_model(cls,
model: str = PYANNOTE_DEFAULT_CONFIG, model: str = PYANNOTE_DEFAULT_CONFIG,
token: str = None, use_auth_token: str = None,
cache_token: bool = False, cache_token: bool = False,
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
hparams_file: Union[str, Path] = None hparams_file: Union[str, Path] = None,
*args, **kwargs
) -> Pipeline: ) -> Pipeline:
""" """
@@ -194,20 +195,22 @@ class Diariser:
cache_token: Whether to cache the token locally for future use. cache_token: Whether to cache the token locally for future use.
cache_dir: Directory for caching models. cache_dir: Directory for caching models.
hparams_file: Path to a YAML file containing hyperparameters. hparams_file: Path to a YAML file containing hyperparameters.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns: Returns:
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
""" """
if cache_token and token is not None: if cache_token and use_auth_token is not None:
cls._save_token(token) cls._save_token(use_auth_token)
if not os.path.exists(model) and token is None: if not os.path.exists(model) and use_auth_token is None:
token = cls._get_token() use_auth_token = cls._get_token()
model = 'pyannote/speaker-diarization' model = 'pyannote/speaker-diarization'
_model = Pipeline.from_pretrained(model, _model = Pipeline.from_pretrained(model,
use_auth_token = token, use_auth_token = use_auth_token,
cache_dir = cache_dir, cache_dir = cache_dir,
hparams_file = hparams_file,) hparams_file = hparams_file,)
+3
View File
@@ -120,6 +120,7 @@ class Transcriber:
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, device]] = None, device: Optional[Union[str, device]] = None,
in_memory: bool = False, in_memory: bool = False,
*args, **kwargs
) -> 'Transcriber': ) -> 'Transcriber':
""" """
Load whisper model. Load whisper model.
@@ -145,6 +146,8 @@ class Transcriber:
Device to load model on. Defaults to None. Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory. in_memory (bool, optional): Whether to load model in memory.
Defaults to False. Defaults to False.
args: Additional arguments only to avoid errors.
kwargs: Additional keyword arguments only to avoid errors.
Returns: Returns:
Transcriber: A Transcriber object initialized with the specified model. Transcriber: A Transcriber object initialized with the specified model.