updated diarisation file to better handle tokens
This commit is contained in:
@@ -1,13 +1,21 @@
|
|||||||
from pyannote.audio import Pipeline
|
"""
|
||||||
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
Diarisation class.
|
||||||
from torch import Tensor
|
This class is used to diarize an audio file using a pretrained model
|
||||||
|
"""
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TypeVar, Union
|
from typing import TypeVar, Union
|
||||||
import json
|
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
|
from .misc import PYANNOTE_DEFAULT_CONFIG, PYANNOTE_DEFAULT_PATH
|
||||||
Annotation = TypeVar('Annotation')
|
Annotation = TypeVar('Annotation')
|
||||||
|
|
||||||
|
TOKEN_PATH = os.path.join(os.path.dirname(
|
||||||
|
os.path.realpath(__file__)), '.pyannotetoken')
|
||||||
|
|
||||||
class Diariser:
|
class Diariser:
|
||||||
"""
|
"""
|
||||||
Diarisation class
|
Diarisation class
|
||||||
@@ -15,7 +23,7 @@ class Diariser:
|
|||||||
from pyannote.audio.
|
from pyannote.audio.
|
||||||
:param model: model to use for diarization
|
:param model: model to use for diarization
|
||||||
"""
|
"""
|
||||||
def __init__(self, model,*args,**kwargs) -> None:
|
def __init__(self, model) -> None:
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
@@ -92,36 +100,40 @@ class Diariser:
|
|||||||
diarization_output["speakers"].append(outp[2])
|
diarization_output["speakers"].append(outp[2])
|
||||||
return diarization_output
|
return diarization_output
|
||||||
|
|
||||||
def save(self, path: str, *args, **kwargs) -> None:
    """
    Save diarization output to a JSON file

    Serialises ``self.transcript`` with ``json.dump``. Any extra positional
    or keyword arguments are forwarded to ``json.dump`` unchanged.

    :param path: path to save file
    :type path: str
    """
    # Pin the encoding so the written bytes do not depend on the platform
    # locale; this matters when the caller passes ``ensure_ascii=False``.
    with open(path, "w", encoding="utf-8") as file:
        json.dump(self.transcript, file, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
def _get_token():
    """
    Get the cached Huggingface token

    Reads the token from the file located at ``TOKEN_PATH``.

    :raises ValueError: No token found (file missing or empty)
    :return: Huggingface token
    :rtype: str
    """
    try:
        with open(TOKEN_PATH, 'r', encoding="utf-8") as file:
            # Drop surrounding whitespace: a trailing newline left by an
            # editor would otherwise become part of the token.
            token = file.read().strip()
    except FileNotFoundError:
        token = ''

    if not token:
        # Each fragment ends with a space so the concatenated message reads
        # as sentences and the URL is not fused with the following word.
        raise ValueError('No token found. '
                         'Please create a token at https://huggingface.co/settings/token '
                         f'and save it in a file called {TOKEN_PATH}')
    return token
|
||||||
|
|
||||||
|
@staticmethod
def _save_token(token):
    """
    Save token to the file located at ``TOKEN_PATH``

    Any existing content of that file is replaced.

    :param token: Huggingface token
    :type token: str
    """
    # Must be opened for writing: in 'r' mode the write below raises
    # (and opening fails outright when the file does not exist yet).
    with open(TOKEN_PATH, 'w', encoding="utf-8") as file:
        file.write(token)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_model(cls,
|
def load_model(cls,
|
||||||
model: str = PYANNOTE_DEFAULT_CONFIG,
|
model: str = PYANNOTE_DEFAULT_CONFIG,
|
||||||
token: str = None,
|
token: str = None,
|
||||||
|
cache_token: bool = False,
|
||||||
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
|
||||||
hparams_file: Union[str, Path] = None
|
hparams_file: Union[str, Path] = None
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
@@ -142,14 +154,23 @@ class Diariser:
|
|||||||
-------
|
-------
|
||||||
Pipeline Object
|
Pipeline Object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if cache_token and token is not None:
|
||||||
|
cls._save_token(token)
|
||||||
|
|
||||||
if not os.path.exists(model) and token is None:
|
if not os.path.exists(model) and token is None:
|
||||||
token = cls._get_token()
|
token = cls._get_token()
|
||||||
|
model = 'pyannote/speaker-diarization'
|
||||||
|
|
||||||
_model = Pipeline.from_pretrained(model,
|
_model = Pipeline.from_pretrained(model,
|
||||||
use_auth_token = token,
|
use_auth_token = token,
|
||||||
cache_dir = cache_dir,
|
cache_dir = cache_dir,
|
||||||
hparams_file = hparams_file,)
|
hparams_file = hparams_file,)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise ValueError('Unable to load model either from local cache' \
|
||||||
|
'or from huggingface.co models. Please check your token' \
|
||||||
|
'or your local model path')
|
||||||
return cls(_model)
|
return cls(_model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user