unified docstrings

This commit is contained in:
Jaikinator
2023-08-23 15:32:54 +02:00
parent 9e00b13524
commit cab50cba70
+112 -58
View File
@@ -1,33 +1,91 @@
import os """
Transcriber Module
------------------
This module provides the Transcriber class, a comprehensive tool for working with Whisper models.
The Transcriber class offers functionalities such as loading different Whisper models, transcribing audio files,
and saving transcriptions to text files. It acts as an interface between various Whisper models and the user,
simplifying the process of audio transcription.
Main Features:
- Loading different sizes and versions of Whisper models.
- Transcribing audio in various formats including str, Tensor, and nparray.
- Saving the transcriptions to the specified paths.
- Adaptable to various language specifications.
- Options to control the verbosity of the transcription process.
Constants:
WHISPER_DEFAULT_PATH: Default path for downloading and loading Whisper models.
Usage:
>>> from your_package import Transcriber
>>> transcriber = Transcriber.load_model(model="medium")
>>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
"""
from whisper import Whisper, load_model from whisper import Whisper, load_model
from typing import TypeVar , Union , Optional from typing import TypeVar , Union , Optional
import torch from torch import Tensor, device
from glob import glob from numpy import ndarray
from .misc import WHISPER_DEFAULT_PATH from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
Tensor = TypeVar('Tensor')
nparray = TypeVar('nparray')
class Transcriber: class Transcriber:
"""
Transcriber Class
-----------------
The Transcriber class serves as a wrapper around Whisper models for efficient audio
transcription. By encapsulating the intricacies of loading models, processing audio,
and saving transcripts, it offers an easy-to-use interface
for users to transcribe audio files.
Attributes:
model (whisper): The Whisper model used for transcription.
Methods:
transcribe: Transcribes the given audio file.
save_transcript: Saves the transcript to a file.
load_model: Loads a specific Whisper model.
_get_whisper_kwargs: Private method to get valid keyword arguments for the whisper model.
Examples:
>>> transcriber = Transcriber.load_model(model="medium")
>>> transcript = transcriber.transcribe(audio="path/to/audio.wav")
>>> transcriber.save_transcript(transcript, "path/to/save.txt")
Note:
The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options.
"""
def __init__(self, model: whisper ) -> None: def __init__(self, model: whisper ) -> None:
""" """
Initialize Transcriber class with a whisper model Initialize the Transcriber class with a Whisper model.
:param model: whisper model
Args:
model (whisper): The Whisper model to use for transcription.
""" """
self.model = model self.model = model
def transcribe(self, audio : Union[str, Tensor, nparray] , def transcribe(self, audio : Union[str, Tensor, ndarray] ,
*args, **kwargs) -> str: *args, **kwargs) -> str:
""" """
transcribe audio file Transcribe an audio file.
:param file: audio file to transcribe
:param args: additional arguments Args:
:param kwargs: additional keyword arguments audio (Union[str, Tensor, nparray]): The audio file to transcribe.
example: *args: Additional arguments.
- language: language of the audio file **kwargs: Additional keyword arguments,
:return: transcript as string such as the language of the audio file.
Returns:
str: The transcript as a string.
""" """
kwargs = self._get_whisper_kwargs(**kwargs) kwargs = self._get_whisper_kwargs(**kwargs)
@@ -41,15 +99,18 @@ class Transcriber:
@staticmethod @staticmethod
def save_transcript(transcript : str , save_path : str) -> None: def save_transcript(transcript : str , save_path : str) -> None:
""" """
Save transcript to file Save a transcript to a file.
:param transcript: transcript as string
:param savepath: path to save the transcript Args:
:return: None transcript (str): The transcript as a string.
save_path (str): The path to save the transcript.
Returns:
None
""" """
with open(save_path, 'w') as f: with open(save_path, 'w') as f:
f.write(transcript) f.write(transcript)
f.close()
print(f'Transcript saved to {save_path}') print(f'Transcript saved to {save_path}')
@@ -57,44 +118,38 @@ class Transcriber:
def load_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
download_root: str = WHISPER_DEFAULT_PATH, download_root: str = WHISPER_DEFAULT_PATH,
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, device]] = None,
in_memory: bool = False, in_memory: bool = False,
) -> 'Transcriber': ) -> 'Transcriber':
""" """
Load whisper module Load whisper model.
Parameters Args:
---------- model (str): Whisper model. Available models include:
whisper : str - 'tiny.en'
whisper model - 'tiny'
available models: - 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large'
- 'tiny.en' download_root (str, optional): Path to download the model.
- 'tiny' Defaults to WHISPER_DEFAULT_PATH.
- 'base.en'
- 'base'
- 'small.en'
- 'small'
- 'medium.en'
- 'medium'
- 'large-v1'
- 'large-v2'
- 'large'
local : bool device (Optional[Union[str, torch.device]], optional):
If true, load from local cache Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
Defaults to False.
download_root : str Returns:
Path to download the model Transcriber: A Transcriber object initialized with the specified model.
default: /models/whisper
Returns
-------
Whisper Object
""" """
_model = load_model(model, download_root=download_root, _model = load_model(model, download_root=download_root,
device=device, in_memory=in_memory) device=device, in_memory=in_memory)
@@ -103,17 +158,16 @@ class Transcriber:
@staticmethod @staticmethod
def _get_whisper_kwargs(**kwargs) -> dict: def _get_whisper_kwargs(**kwargs) -> dict:
""" """
Get kwargs for whisper model. Get kwargs for whisper model. Ensure that kwargs are valid.
Ensure that kwargs are valid.
:return: kwargs for whisper model Returns:
:rtype: dict dict: Keyword arguments for whisper model.
""" """
_possible_kwargs = Whisper.transcribe.__code__.co_varnames _possible_kwargs = Whisper.transcribe.__code__.co_varnames
whisper_kwargs = dict() whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
for k in kwargs.keys():
if k in _possible_kwargs:
whisper_kwargs[k] = kwargs[k]
return whisper_kwargs return whisper_kwargs
def __repr__(self) -> str:
return f"Transcriber(model={self.model})"