diff --git a/.github/workflows/mirror_to_gitlab.yml b/.github/workflows/mirror_to_gitlab.yml new file mode 100644 index 0000000..b100359 --- /dev/null +++ b/.github/workflows/mirror_to_gitlab.yml @@ -0,0 +1,23 @@ +name: Mirror and run GitLab CI + +on: [push, delete] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Mirror + trigger CI + uses: SvanBoxel/gitlab-mirror-and-ci-action@master + with: + args: "https://git-dmz.thuenen.de/kida/i2-skills-beratungsstelle/scraibe" + env: + FOLLOW_TAGS: "true" + FORCE_PUSH: "true" + GITLAB_HOSTNAME: "git-dmz.thuenen.de" + GITLAB_USERNAME: ${{ secrets.GITLAB_USERNAME }} + GITLAB_PASSWORD: ${{ secrets.GITLAB_PASSWORD }} + GITLAB_PROJECT_ID: ${{ secrets.GITLAB_PROJECT_ID }} + GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/requirements.txt b/requirements.txt index aed43e8..8cf1782 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ -openai-whisper==20230314 +torch~=2.2.0 + +openai-whisper~=20231117 numpy~=1.23.5 -pyannote.audio~=2.1.1 -pyannote.core~=4.5 -pyannote.database~=4.1.3 +pyannote.audio~=3.1.1 +pyannote.core~=5.0.0 +pyannote.database~=5.0.1 pyannote.metrics~=3.2.1 -pyannote.pipeline~=2.3 +pyannote.pipeline~=3.0.1 setuptools~=65.6.3 setuptools-rust~=1.5.2 diff --git a/scraibe/diarisation.py b/scraibe/diarisation.py index e598d30..0f0e14a 100644 --- a/scraibe/diarisation.py +++ b/scraibe/diarisation.py @@ -36,6 +36,8 @@ from typing import TypeVar, Union from pyannote.audio import Pipeline from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization from torch import Tensor +from torch import device as torch_device +from torch.cuda import is_available, current_device from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG Annotation = TypeVar('Annotation') @@ -186,6 +188,7 @@ class Diariser: cache_token: bool = True, cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH, hparams_file: Union[str, Path] = None, + device: str = None, *args, **kwargs ) -> Pipeline: @@ -200,6 +203,7 @@ class Diariser: cache_token: Whether to cache the token locally for future use. cache_dir: Directory for caching models. hparams_file: Path to a YAML file containing hyperparameters. + device: Device to load the model on. args: Additional arguments only to avoid errors. kwargs: Additional keyword arguments only to avoid errors. @@ -207,6 +211,7 @@ class Diariser: Pipeline: A pyannote.audio Pipeline object, encapsulating the loaded model. """ + if cache_token and use_auth_token is not None: cls._save_token(use_auth_token) @@ -253,6 +258,12 @@ class Diariser: cache_dir = cache_dir, hparams_file = hparams_file,) + # try to move the model to the device + if device is None: + device = "cuda" if is_available() else "cpu" + + _model = _model.to(torch_device(device)) # torch_device is renamed from torch.device to avoid name conflict + if _model is None: raise ValueError('Unable to load model either from local cache' \ 'or from huggingface.co models. Please check your token' \