added kwargs parsing
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from pyannote.audio import Pipeline
|
||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||
from torch import Tensor
|
||||
import os
|
||||
from typing import TypeVar, Union
|
||||
@@ -23,6 +24,7 @@ class Diariser:
|
||||
:param kwargs: kwargs for diarization model
|
||||
:return: diarization
|
||||
"""
|
||||
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||
|
||||
diarization = self.model(audiofile,*args, **kwargs)
|
||||
|
||||
@@ -132,6 +134,24 @@ class Diariser:
|
||||
|
||||
return cls(diarization_model)
|
||||
|
||||
@staticmethod
|
||||
def _get_diarisation_kwargs(**kwargs) -> dict:
|
||||
"""
|
||||
Get kwargs for pyannote diarization model
|
||||
Ensure that kwargs are valid
|
||||
:return: kwargs for pyannote diarization model
|
||||
:rtype: dict
|
||||
"""
|
||||
_possible_kwargs = SpeakerDiarization.apply.__code__.co_varnames
|
||||
|
||||
diarisation_kwargs = dict()
|
||||
|
||||
for k in kwargs.keys():
|
||||
if k in _possible_kwargs:
|
||||
diarisation_kwargs[k] = kwargs[k]
|
||||
|
||||
return diarisation_kwargs
|
||||
|
||||
def __repr__(self):
|
||||
return f"Diarisation(model={self.model})"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user