added kwargs parsing

This commit is contained in:
Jaikinator
2023-06-14 16:31:25 +02:00
parent 002c7b5189
commit 67e4e4585d
2 changed files with 49 additions and 2 deletions
+20
View File
@@ -1,4 +1,5 @@
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch import Tensor
import os
from typing import TypeVar, Union
@@ -23,6 +24,7 @@ class Diariser:
:param kwargs: kwargs for diarization model
:return: diarization
"""
kwargs = self._get_diarisation_kwargs(**kwargs)
diarization = self.model(audiofile,*args, **kwargs)
@@ -132,6 +134,24 @@ class Diariser:
return cls(diarization_model)
@staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict:
    """
    Get kwargs for pyannote diarization model.
    Keep only the entries whose name is a keyword-capable parameter of
    ``SpeakerDiarization.apply``; every other entry is silently dropped.

    :param kwargs: arbitrary keyword arguments supplied by the caller
    :return: kwargs for pyannote diarization model
    :rtype: dict
    """
    # Local import keeps this helper self-contained.
    import inspect

    # Inspect the real signature instead of ``__code__.co_varnames``:
    # co_varnames also lists every local variable of the function body, so
    # names that are not parameters at all would slip through the filter.
    parameters = inspect.signature(SpeakerDiarization.apply).parameters
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {
        name
        for name, parameter in parameters.items()
        # ``self`` is bound by the call itself and must never be forwarded.
        if parameter.kind in keyword_kinds and name != "self"
    }
    return {key: value for key, value in kwargs.items() if key in accepted}
def __repr__(self):
    """Return a short textual representation that shows the wrapped model."""
    model = self.model
    return f"Diarisation(model={model})"
+29 -2
View File
@@ -1,5 +1,5 @@
import os
from whisper import Whisper
from typing import TypeVar , Union
from whisper import load_model
from glob import glob
@@ -43,8 +43,17 @@ class Transcriber:
:return: transcript as string
"""
result = self.model.transcribe(audio, *args, **kwargs)
kwargs = self._get_whisper_kwargs(**kwargs)
if kwargs or args:
result = self.model.transcribe(audio, *args, **kwargs)
else:
# if kwargs is empty but parsed anyway whisper
# will not use the default kwargs
print("No kwargs parsed. Using default kwargs.")
result = self.model.transcribe(audio)
return result["text"]
@staticmethod
@@ -117,3 +126,21 @@ class Transcriber:
_model = load_model(model, download_root=download_root)
return cls(_model)
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
    """
    Get kwargs for whisper model.
    Keep only the entries that ``Whisper.transcribe`` can receive by keyword;
    every other entry is silently dropped.

    :param kwargs: arbitrary keyword arguments supplied by the caller
    :return: kwargs for whisper model
    :rtype: dict
    """
    # Local imports keep this helper self-contained.
    import dataclasses
    import inspect

    from whisper import DecodingOptions

    # Inspect the real signature instead of ``__code__.co_varnames``:
    # co_varnames also lists every local variable of the function body, so
    # non-parameters would pass the filter while genuine decode options that
    # are not locals would be dropped.
    parameters = list(inspect.signature(Whisper.transcribe).parameters.values())
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # The first parameter is the model instance (bound by the method call).
    accepted = {
        parameter.name
        for parameter in parameters[1:]
        if parameter.kind in keyword_kinds
    }
    # NOTE(review): per the whisper documentation the ``**decode_options``
    # catch-all is used to build ``DecodingOptions``, so its fields are the
    # remaining valid keys -- confirm against the installed whisper version.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters):
        accepted.update(field.name for field in dataclasses.fields(DecodingOptions))
    return {key: value for key, value in kwargs.items() if key in accepted}