Added kwargs to the load_model functions to avoid errors
This commit is contained in:
@@ -26,7 +26,6 @@ Usage:
|
||||
# Standard Library Imports
|
||||
import os
|
||||
from glob import iglob
|
||||
import re
|
||||
from subprocess import run
|
||||
from typing import TypeVar, Union
|
||||
from warnings import warn
|
||||
@@ -42,6 +41,7 @@ from .diarisation import Diariser
|
||||
from .transcriber import Transcriber, whisper
|
||||
from .transcript_exporter import Transcript
|
||||
|
||||
|
||||
DiarisationType = TypeVar('DiarisationType')
|
||||
|
||||
|
||||
@@ -77,15 +77,16 @@ class AutoTranscribe:
|
||||
and pyannote diarization models.
|
||||
"""
|
||||
|
||||
|
||||
if whisper_model is None:
|
||||
self.transcriber = Transcriber.load_model("medium")
|
||||
self.transcriber = Transcriber.load_model("medium", **kwargs)
|
||||
elif isinstance(whisper_model, str):
|
||||
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
||||
else:
|
||||
self.transcriber = whisper_model
|
||||
|
||||
if dia_model is None:
|
||||
self.diariser = Diariser.load_model()
|
||||
self.diariser = Diariser.load_model(**kwargs)
|
||||
elif isinstance(dia_model, str):
|
||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||
else:
|
||||
@@ -125,17 +126,18 @@ class AutoTranscribe:
|
||||
|
||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||
|
||||
|
||||
if not diarisation["segments"]:
|
||||
warn("No segments found. Try to run transcription without diarisation.")
|
||||
print("No segments found. Try to run transcription without diarisation.")
|
||||
|
||||
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||
|
||||
final_transcript= {"speakers" : ["speaker01"],
|
||||
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
|
||||
"segments" : [0, len(audio_file.waveform)],
|
||||
"text" : transcript}
|
||||
"text" : transcript}}
|
||||
|
||||
return Transcript(final_transcript)
|
||||
|
||||
|
||||
print("Diarisation finished. Starting transcription.")
|
||||
|
||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
||||
@@ -143,6 +145,8 @@ class AutoTranscribe:
|
||||
# Transcribe each segment and store the results
|
||||
final_transcript = dict()
|
||||
|
||||
|
||||
|
||||
for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
|
||||
|
||||
seg = diarisation["segments"][i]
|
||||
@@ -277,3 +281,6 @@ class AutoTranscribe:
|
||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||
f'not {type(audio_file)}')
|
||||
return audio_file
|
||||
|
||||
def __repr__(self):
|
||||
return f"AutoTranscribe(transcriber={self.transcriber}, diariser={self.diariser})"
|
||||
|
||||
@@ -177,10 +177,11 @@ class Diariser:
|
||||
@classmethod
|
||||
def load_model(cls,
|
||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||
token: str = None,
|
||||
use_auth_token: str = None,
|
||||
cache_token: bool = False,
|
||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||
hparams_file: Union[str, Path] = None
|
||||
hparams_file: Union[str, Path] = None,
|
||||
*args, **kwargs
|
||||
) -> Pipeline:
|
||||
|
||||
"""
|
||||
@@ -194,20 +195,22 @@ class Diariser:
|
||||
cache_token: Whether to cache the token locally for future use.
|
||||
cache_dir: Directory for caching models.
|
||||
hparams_file: Path to a YAML file containing hyperparameters.
|
||||
args: Additional arguments only to avoid errors.
|
||||
kwargs: Additional keyword arguments only to avoid errors.
|
||||
|
||||
Returns:
|
||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||
"""
|
||||
|
||||
if cache_token and token is not None:
|
||||
cls._save_token(token)
|
||||
if cache_token and use_auth_token is not None:
|
||||
cls._save_token(use_auth_token)
|
||||
|
||||
if not os.path.exists(model) and token is None:
|
||||
token = cls._get_token()
|
||||
if not os.path.exists(model) and use_auth_token is None:
|
||||
use_auth_token = cls._get_token()
|
||||
model = 'pyannote/speaker-diarization'
|
||||
|
||||
_model = Pipeline.from_pretrained(model,
|
||||
use_auth_token = token,
|
||||
use_auth_token = use_auth_token,
|
||||
cache_dir = cache_dir,
|
||||
hparams_file = hparams_file,)
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class Transcriber:
|
||||
download_root: str = WHISPER_DEFAULT_PATH,
|
||||
device: Optional[Union[str, device]] = None,
|
||||
in_memory: bool = False,
|
||||
*args, **kwargs
|
||||
) -> 'Transcriber':
|
||||
"""
|
||||
Load whisper model.
|
||||
@@ -145,6 +146,8 @@ class Transcriber:
|
||||
Device to load model on. Defaults to None.
|
||||
in_memory (bool, optional): Whether to load model in memory.
|
||||
Defaults to False.
|
||||
args: Additional arguments only to avoid errors.
|
||||
kwargs: Additional keyword arguments only to avoid errors.
|
||||
|
||||
Returns:
|
||||
Transcriber: A Transcriber object initialized with the specified model.
|
||||
|
||||
Reference in New Issue
Block a user