Files
scribe/autotranscript/transcriber.py
T
2023-06-30 18:44:10 +02:00

119 lines
3.2 KiB
Python

import os
from whisper import Whisper, load_model
from typing import TypeVar , Union , Optional
import torch
from glob import glob
from .misc import WHISPER_DEFAULT_PATH
# Placeholder type names used purely as annotations on Transcriber
# (``whisper`` for the model argument, ``Tensor``/``nparray`` for audio input).
# NOTE(review): these are TypeVars, not real type aliases -- presumably they
# stand in for whisper.Whisper, torch.Tensor and numpy.ndarray; confirm.
whisper = TypeVar('whisper')
Tensor = TypeVar('Tensor')
nparray = TypeVar('nparray')
class Transcriber:
    """Thin wrapper around a loaded whisper model that returns plain text."""

    def __init__(self, model: whisper) -> None:
        """
        Initialize Transcriber class with a whisper model

        :param model: whisper model; must expose a ``transcribe`` method
        """
        self.model = model

    def transcribe(self, audio: Union[str, Tensor, nparray],
                   *args, **kwargs) -> str:
        """
        Transcribe audio and return the transcript text

        :param audio: audio to transcribe, passed unchanged to the model's
            ``transcribe`` method
        :param args: additional positional arguments, forwarded unchanged
        :param kwargs: additional keyword arguments; only names accepted by
            ``_get_whisper_kwargs`` are forwarded, the rest are silently
            dropped. ``verbose`` defaults to ``False`` when not given.
            example:
            - language: language of the audio file
        :return: transcript as string (the ``"text"`` entry of the result)
        """
        kwargs = self._get_whisper_kwargs(**kwargs)
        # Keep the model quiet unless the caller explicitly asks otherwise.
        kwargs.setdefault("verbose", False)
        result = self.model.transcribe(audio, *args, **kwargs)
        return result["text"]

    @staticmethod
    def save_transcript(transcript: str, save_path: str) -> None:
        """
        Save transcript to file

        The file is written as UTF-8 so that transcripts containing
        non-ASCII characters are stored identically on every platform.

        :param transcript: transcript as string
        :param save_path: path to save the transcript
        :return: None
        """
        # The context manager closes the file; no explicit close() is needed.
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(transcript)
        print(f'Transcript saved to {save_path}')

    @classmethod
    def load_model(cls,
                   model: str = "medium",
                   download_root: str = WHISPER_DEFAULT_PATH,
                   device: Optional[Union[str, torch.device]] = None,
                   in_memory: bool = False,
                   ) -> 'Transcriber':
        """
        Load a whisper model and wrap it in a Transcriber

        Parameters
        ----------
        model : str
            whisper model
            available models:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large'
        download_root : str
            Path to download the model
            default: ``WHISPER_DEFAULT_PATH``
        device : str or torch.device, optional
            Forwarded unchanged to whisper's ``load_model``
        in_memory : bool
            Forwarded unchanged to whisper's ``load_model``

        Returns
        -------
        Transcriber
            New instance wrapping the loaded whisper model
        """
        _model = load_model(model, download_root=download_root,
                            device=device, in_memory=in_memory)
        return cls(_model)

    @staticmethod
    def _get_whisper_kwargs(**kwargs) -> dict:
        """
        Get kwargs for whisper model.
        Keep only the keyword arguments whose name appears in
        ``Whisper.transcribe.__code__.co_varnames``.

        :return: kwargs for whisper model
        :rtype: dict
        """
        # NOTE(review): co_varnames holds the function's argument names AND
        # its local variable names, so a kwarg that merely matches a local of
        # Whisper.transcribe also passes this filter, while options that are
        # only consumed through its **kwargs are dropped -- confirm intended.
        _possible_kwargs = Whisper.transcribe.__code__.co_varnames
        return {k: v for k, v in kwargs.items() if k in _possible_kwargs}