added kwargs to load model functionts to avoid errors
This commit is contained in:
@@ -26,7 +26,6 @@ Usage:
|
|||||||
# Standard Library Imports
|
# Standard Library Imports
|
||||||
import os
|
import os
|
||||||
from glob import iglob
|
from glob import iglob
|
||||||
import re
|
|
||||||
from subprocess import run
|
from subprocess import run
|
||||||
from typing import TypeVar, Union
|
from typing import TypeVar, Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
@@ -42,6 +41,7 @@ from .diarisation import Diariser
|
|||||||
from .transcriber import Transcriber, whisper
|
from .transcriber import Transcriber, whisper
|
||||||
from .transcript_exporter import Transcript
|
from .transcript_exporter import Transcript
|
||||||
|
|
||||||
|
|
||||||
DiarisationType = TypeVar('DiarisationType')
|
DiarisationType = TypeVar('DiarisationType')
|
||||||
|
|
||||||
|
|
||||||
@@ -77,15 +77,16 @@ class AutoTranscribe:
|
|||||||
and pyannote diarization models.
|
and pyannote diarization models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
self.transcriber = Transcriber.load_model("medium")
|
self.transcriber = Transcriber.load_model("medium", **kwargs)
|
||||||
elif isinstance(whisper_model, str):
|
elif isinstance(whisper_model, str):
|
||||||
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.transcriber = whisper_model
|
self.transcriber = whisper_model
|
||||||
|
|
||||||
if dia_model is None:
|
if dia_model is None:
|
||||||
self.diariser = Diariser.load_model()
|
self.diariser = Diariser.load_model(**kwargs)
|
||||||
elif isinstance(dia_model, str):
|
elif isinstance(dia_model, str):
|
||||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -125,17 +126,18 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
diarisation = self.diariser.diarization(dia_audio, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
if not diarisation["segments"]:
|
if not diarisation["segments"]:
|
||||||
warn("No segments found. Try to run transcription without diarisation.")
|
print("No segments found. Try to run transcription without diarisation.")
|
||||||
|
|
||||||
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
transcript = self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||||
|
|
||||||
final_transcript= {"speakers" : ["speaker01"],
|
final_transcript= {0 : {"speakers" : 'SPEAKER_01',
|
||||||
"segments" : [0, len(audio_file.waveform)],
|
"segments" : [0, len(audio_file.waveform)],
|
||||||
"text" : transcript}
|
"text" : transcript}}
|
||||||
|
|
||||||
return Transcript(final_transcript)
|
return Transcript(final_transcript)
|
||||||
|
|
||||||
|
|
||||||
print("Diarisation finished. Starting transcription.")
|
print("Diarisation finished. Starting transcription.")
|
||||||
|
|
||||||
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
audio_file.sr = torch.Tensor([audio_file.sr]).to(audio_file.waveform.device)
|
||||||
@@ -143,6 +145,8 @@ class AutoTranscribe:
|
|||||||
# Transcribe each segment and store the results
|
# Transcribe each segment and store the results
|
||||||
final_transcript = dict()
|
final_transcript = dict()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
|
for i in trange(len(diarisation["segments"]), desc= "Transcribing"):
|
||||||
|
|
||||||
seg = diarisation["segments"][i]
|
seg = diarisation["segments"][i]
|
||||||
@@ -277,3 +281,6 @@ class AutoTranscribe:
|
|||||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||||
f'not {type(audio_file)}')
|
f'not {type(audio_file)}')
|
||||||
return audio_file
|
return audio_file
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"AutoTranscribe(transcriber={self.transcriber}, diariser={self.diariser})"
|
||||||
|
|||||||
@@ -177,10 +177,11 @@ class Diariser:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||||
token: str = None,
|
use_auth_token: str = None,
|
||||||
cache_token: bool = False,
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None
|
hparams_file: Union[str, Path] = None,
|
||||||
|
*args, **kwargs
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -194,20 +195,22 @@ class Diariser:
|
|||||||
cache_token: Whether to cache the token locally for future use.
|
cache_token: Whether to cache the token locally for future use.
|
||||||
cache_dir: Directory for caching models.
|
cache_dir: Directory for caching models.
|
||||||
hparams_file: Path to a YAML file containing hyperparameters.
|
hparams_file: Path to a YAML file containing hyperparameters.
|
||||||
|
args: Additional arguments only to avoid errors.
|
||||||
|
kwargs: Additional keyword arguments only to avoid errors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if cache_token and token is not None:
|
if cache_token and use_auth_token is not None:
|
||||||
cls._save_token(token)
|
cls._save_token(use_auth_token)
|
||||||
|
|
||||||
if not os.path.exists(model) and token is None:
|
if not os.path.exists(model) and use_auth_token is None:
|
||||||
token = cls._get_token()
|
use_auth_token = cls._get_token()
|
||||||
model = 'pyannote/speaker-diarization'
|
model = 'pyannote/speaker-diarization'
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token = token,
|
use_auth_token = use_auth_token,
|
||||||
cache_dir = cache_dir,
|
cache_dir = cache_dir,
|
||||||
hparams_file = hparams_file,)
|
hparams_file = hparams_file,)
|
||||||
|
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class Transcriber:
|
|||||||
download_root: str = WHISPER_DEFAULT_PATH,
|
download_root: str = WHISPER_DEFAULT_PATH,
|
||||||
device: Optional[Union[str, device]] = None,
|
device: Optional[Union[str, device]] = None,
|
||||||
in_memory: bool = False,
|
in_memory: bool = False,
|
||||||
|
*args, **kwargs
|
||||||
) -> 'Transcriber':
|
) -> 'Transcriber':
|
||||||
"""
|
"""
|
||||||
Load whisper model.
|
Load whisper model.
|
||||||
@@ -145,6 +146,8 @@ class Transcriber:
|
|||||||
Device to load model on. Defaults to None.
|
Device to load model on. Defaults to None.
|
||||||
in_memory (bool, optional): Whether to load model in memory.
|
in_memory (bool, optional): Whether to load model in memory.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
|
args: Additional arguments only to avoid errors.
|
||||||
|
kwargs: Additional keyword arguments only to avoid errors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Transcriber: A Transcriber object initialized with the specified model.
|
Transcriber: A Transcriber object initialized with the specified model.
|
||||||
|
|||||||
Reference in New Issue
Block a user