added kwargs parsing
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from pyannote.audio import Pipeline
|
from pyannote.audio import Pipeline
|
||||||
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import os
|
import os
|
||||||
from typing import TypeVar, Union
|
from typing import TypeVar, Union
|
||||||
@@ -23,6 +24,7 @@ class Diariser:
|
|||||||
:param kwargs: kwargs for diarization model
|
:param kwargs: kwargs for diarization model
|
||||||
:return: diarization
|
:return: diarization
|
||||||
"""
|
"""
|
||||||
|
kwargs = self._get_diarisation_kwargs(**kwargs)
|
||||||
|
|
||||||
diarization = self.model(audiofile,*args, **kwargs)
|
diarization = self.model(audiofile,*args, **kwargs)
|
||||||
|
|
||||||
@@ -132,6 +134,24 @@ class Diariser:
|
|||||||
|
|
||||||
return cls(diarization_model)
|
return cls(diarization_model)
|
||||||
|
|
||||||
|
@staticmethod
def _get_diarisation_kwargs(**kwargs) -> dict:
    """
    Filter keyword arguments down to those the pyannote diarization
    pipeline can receive.

    The accepted names are read from the signature of
    ``SpeakerDiarization.apply``. Keys that do not match a keyword-capable
    parameter are dropped silently, so callers may hand over a mixed bag
    of options.

    :param kwargs: candidate keyword arguments
    :return: kwargs for pyannote diarization model
    :rtype: dict
    """
    import inspect

    parameters = inspect.signature(SpeakerDiarization.apply).parameters

    # ``__code__.co_varnames`` also lists the function's local variables,
    # which would let invalid keys through; the signature lists only the
    # real parameters.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    # A ``**`` catch-all means the signature cannot enumerate what is
    # valid, so nothing can safely be filtered out.
    if any(
        parameter.kind is inspect.Parameter.VAR_KEYWORD
        for parameter in parameters.values()
    ):
        return dict(kwargs)

    # The first parameter is the instance the method is bound to and can
    # never be supplied as a keyword by the caller.
    accepted = {
        name
        for name, parameter in list(parameters.items())[1:]
        if parameter.kind in keyword_kinds
    }

    return {key: value for key, value in kwargs.items() if key in accepted}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Diarisation(model={self.model})"
|
return f"Diarisation(model={self.model})"
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
from whisper import Whisper
|
||||||
from typing import TypeVar , Union
|
from typing import TypeVar , Union
|
||||||
from whisper import load_model
|
from whisper import load_model
|
||||||
from glob import glob
|
from glob import glob
|
||||||
@@ -43,8 +43,17 @@ class Transcriber:
|
|||||||
:return: transcript as string
|
:return: transcript as string
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = self.model.transcribe(audio, *args, **kwargs)
|
kwargs = self._get_whisper_kwargs(**kwargs)
|
||||||
|
|
||||||
|
if kwargs or args:
|
||||||
|
result = self.model.transcribe(audio, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
# if kwargs is empty but passed anyway, whisper
|
||||||
|
# will not use the default kwargs
|
||||||
|
|
||||||
|
print("No kwargs parsed. Using default kwargs.")
|
||||||
|
result = self.model.transcribe(audio)
|
||||||
|
|
||||||
return result["text"]
|
return result["text"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -117,3 +126,21 @@ class Transcriber:
|
|||||||
_model = load_model(model, download_root=download_root)
|
_model = load_model(model, download_root=download_root)
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
|
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
    """
    Filter keyword arguments down to those the whisper model can receive.

    The accepted names are the keyword-capable parameters in the signature
    of ``Whisper.transcribe``. When that signature also has a ``**``
    catch-all, the field names of ``whisper.DecodingOptions`` are accepted
    as well. Keys matching neither are dropped silently.

    NOTE(review): this assumes the catch-all of ``Whisper.transcribe`` is
    forwarded to ``DecodingOptions`` -- confirm against the installed
    whisper version.

    :param kwargs: candidate keyword arguments
    :return: kwargs for whisper model
    :rtype: dict
    """
    import dataclasses
    import inspect

    from whisper import DecodingOptions

    parameters = inspect.signature(Whisper.transcribe).parameters

    # ``__code__.co_varnames`` mixes parameters with local variables and
    # knows nothing about options taken through ``**``; the signature is
    # the reliable source for parameter names.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    # The first parameter is the model instance the method is bound to
    # and can never be supplied as a keyword by the caller.
    accepted = {
        name
        for name, parameter in list(parameters.items())[1:]
        if parameter.kind in keyword_kinds
    }

    takes_extra_options = any(
        parameter.kind is inspect.Parameter.VAR_KEYWORD
        for parameter in parameters.values()
    )
    if takes_extra_options:
        # Options such as ``fp16`` or ``beam_size`` are not named
        # parameters, so they have to be whitelisted explicitly.
        accepted.update(
            field.name for field in dataclasses.fields(DecodingOptions)
        )

    return {key: value for key, value in kwargs.items() if key in accepted}
|
||||||
Reference in New Issue
Block a user