diff --git a/autotranscript/transcriber.py b/autotranscript/transcriber.py index 0cd42bf..81787da 100644 --- a/autotranscript/transcriber.py +++ b/autotranscript/transcriber.py @@ -1,33 +1,91 @@ -import os +""" +Transcriber Module +------------------ + +This module provides the Transcriber class, a comprehensive tool for working with Whisper models. +The Transcriber class offers functionalities such as loading different Whisper models, transcribing audio files, +and saving transcriptions to text files. It acts as an interface between various Whisper models and the user, +simplifying the process of audio transcription. + +Main Features: + - Loading different sizes and versions of Whisper models. + - Transcribing audio in various formats including str, Tensor, and ndarray. + - Saving the transcriptions to the specified paths. + - Adaptable to various language specifications. + - Options to control the verbosity of the transcription process. + +Constants: + WHISPER_DEFAULT_PATH: Default path for downloading and loading Whisper models. + +Usage: + >>> from autotranscript.transcriber import Transcriber + >>> transcriber = Transcriber.load_model(model="medium") + >>> transcript = transcriber.transcribe(audio="path/to/audio.wav") + >>> transcriber.save_transcript(transcript, "path/to/save.txt") +""" + from whisper import Whisper, load_model from typing import TypeVar , Union , Optional -import torch -from glob import glob +from torch import Tensor, device +from numpy import ndarray + + from .misc import WHISPER_DEFAULT_PATH whisper = TypeVar('whisper') -Tensor = TypeVar('Tensor') -nparray = TypeVar('nparray') + class Transcriber: + """ + Transcriber Class + ----------------- + + The Transcriber class serves as a wrapper around Whisper models for efficient audio + transcription. By encapsulating the intricacies of loading models, processing audio, + and saving transcripts, it offers an easy-to-use interface + for users to transcribe audio files. 
+ + Attributes: + model (whisper): The Whisper model used for transcription. + + Methods: + transcribe: Transcribes the given audio file. + save_transcript: Saves the transcript to a file. + load_model: Loads a specific Whisper model. + _get_whisper_kwargs: Private method to get valid keyword arguments for the whisper model. + + Examples: + >>> transcriber = Transcriber.load_model(model="medium") + >>> transcript = transcriber.transcribe(audio="path/to/audio.wav") + >>> transcriber.save_transcript(transcript, "path/to/save.txt") + + Note: + The class supports various sizes and versions of Whisper models. Please refer to + the load_model method for available options. + """ def __init__(self, model: whisper ) -> None: """ - Initialize Transcriber class with a whisper model - :param model: whisper model + Initialize the Transcriber class with a Whisper model. + + Args: + model (whisper): The Whisper model to use for transcription. """ self.model = model - def transcribe(self, audio : Union[str, Tensor, nparray] , + def transcribe(self, audio : Union[str, Tensor, ndarray] , *args, **kwargs) -> str: """ - transcribe audio file - :param file: audio file to transcribe - :param args: additional arguments - :param kwargs: additional keyword arguments - example: - - language: language of the audio file - :return: transcript as string + Transcribe an audio file. + + Args: + audio (Union[str, Tensor, ndarray]): The audio file to transcribe. + *args: Additional arguments. + **kwargs: Additional keyword arguments, + such as the language of the audio file. + + Returns: + str: The transcript as a string. """ kwargs = self._get_whisper_kwargs(**kwargs) @@ -41,15 +99,18 @@ class Transcriber: @staticmethod def save_transcript(transcript : str , save_path : str) -> None: """ - Save transcript to file - :param transcript: transcript as string - :param savepath: path to save the transcript - :return: None + Save a transcript to a file. 
+ + Args: + transcript (str): The transcript as a string. + save_path (str): The path to save the transcript. + + Returns: + None """ with open(save_path, 'w') as f: f.write(transcript) - f.close() print(f'Transcript saved to {save_path}') @@ -57,44 +118,38 @@ class Transcriber: def load_model(cls, model: str = "medium", download_root: str = WHISPER_DEFAULT_PATH, - device: Optional[Union[str, torch.device]] = None, + device: Optional[Union[str, device]] = None, in_memory: bool = False, ) -> 'Transcriber': """ - Load whisper module + Load whisper model. - Parameters - ---------- - whisper : str - whisper model - available models: + Args: + model (str): Whisper model. Available models include: + - 'tiny.en' + - 'tiny' + - 'base.en' + - 'base' + - 'small.en' + - 'small' + - 'medium.en' + - 'medium' + - 'large-v1' + - 'large-v2' + - 'large' + + download_root (str, optional): Path to download the model. + Defaults to WHISPER_DEFAULT_PATH. + + device (Optional[Union[str, torch.device]], optional): + Device to load model on. Defaults to None. + in_memory (bool, optional): Whether to load model in memory. + Defaults to False. - - 'tiny.en' - - 'tiny' - - 'base.en' - - 'base' - - 'small.en' - - 'small' - - 'medium.en' - - 'medium' - - 'large-v1' - - 'large-v2' - - 'large' - - local : bool - If true, load from local cache - - download_root : str - Path to download the model - - default: /models/whisper - - Returns - ------- - Whisper Object + Returns: + Transcriber: A Transcriber object initialized with the specified model. """ - _model = load_model(model, download_root=download_root, device=device, in_memory=in_memory) @@ -103,17 +158,16 @@ class Transcriber: @staticmethod def _get_whisper_kwargs(**kwargs) -> dict: """ - Get kwargs for whisper model. - Ensure that kwargs are valid. - :return: kwargs for whisper model - :rtype: dict + Get kwargs for whisper model. Ensure that kwargs are valid. + + Returns: + dict: Keyword arguments for whisper model. 
""" _possible_kwargs = Whisper.transcribe.__code__.co_varnames - whisper_kwargs = dict() - - for k in kwargs.keys(): - if k in _possible_kwargs: - whisper_kwargs[k] = kwargs[k] + whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs} - return whisper_kwargs \ No newline at end of file + return whisper_kwargs + + def __repr__(self) -> str: + return f"Transcriber(model={self.model})" \ No newline at end of file