From 67e4e4585da3be40190a265bcf7b12e446f2ee69 Mon Sep 17 00:00:00 2001 From: Jaikinator Date: Wed, 14 Jun 2023 16:31:25 +0200 Subject: [PATCH] added kwargs parsing --- autotranscript/diarisation.py | 20 ++++++++++++++++++++ autotranscript/transcriber.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/autotranscript/diarisation.py b/autotranscript/diarisation.py index 55fd0cb..3b64fac 100644 --- a/autotranscript/diarisation.py +++ b/autotranscript/diarisation.py @@ -1,4 +1,5 @@ from pyannote.audio import Pipeline +from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor import os from typing import TypeVar, Union @@ -23,6 +24,7 @@ class Diariser: :param kwargs: kwargs for diarization model :return: diarization """ + kwargs = self._get_diarisation_kwargs(**kwargs) diarization = self.model(audiofile,*args, **kwargs) @@ -132,6 +134,24 @@ class Diariser: return cls(diarization_model) + @staticmethod + def _get_diarisation_kwargs(**kwargs) -> dict: + """ + Get kwargs for pyannote diarization model + Ensure that kwargs are valid + :return: kwargs for pyannote diarization model + :rtype: dict + """ + _possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames + + diarisation_kwargs = dict() + + for k in kwargs.keys(): + if k in _possible_kwargs: + diarisation_kwargs[k] = kwargs[k] + + return diarisation_kwargs + def __repr__(self): return f"Diarisation(model={self.model})" diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index 069866a..57a3423 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -1,5 +1,5 @@ - import os +from whisper import Whisper from typing import TypeVar , Union from whisper import load_model from glob import glob @@ -43,8 +43,17 @@ class Transcriber: :return: transcript as string """ - result = self.model.transcribe(audio, *args, **kwargs) + kwargs = self._get_whisper_kwargs(**kwargs) + if kwargs or args: + result = self.model.transcribe(audio, *args, **kwargs) + else: + # if kwargs is empty but parsed anyway whisper + # will not use the default kwargs + + print("No kwargs parsed. Using default kwargs.") + result = self.model.transcribe(audio) + return result["text"] @staticmethod @@ -117,3 +126,21 @@ class Transcriber: _model = load_model(model, download_root=download_root) return cls(_model) + + @staticmethod + def _get_whisper_kwargs(**kwargs) -> dict: + """ + Get kwargs for whisper model. + Ensure that kwargs are valid. + :return: kwargs for whisper model + :rtype: dict + """ + _possible_kwargs = Whisper.transcribe.__code__.co_varnames + + whisper_kwargs = dict() + + for k in kwargs.keys(): + if k in _possible_kwargs: + whisper_kwargs[k] = kwargs[k] + + return whisper_kwargs \ No newline at end of file