updated transcriber
This commit is contained in:
@@ -64,15 +64,18 @@ class Transcriber:
|
|||||||
The class supports various sizes and versions of Whisper models. Please refer to
|
The class supports various sizes and versions of Whisper models. Please refer to
|
||||||
the load_model method for available options.
|
the load_model method for available options.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model: whisper ) -> None:
|
def __init__(self, model: whisper , model_name: str ) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the Transcriber class with a Whisper model.
|
Initialize the Transcriber class with a Whisper model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (whisper): The Whisper model to use for transcription.
|
model (whisper): The Whisper model to use for transcription.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
|
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
|
||||||
*args, **kwargs) -> str:
|
*args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -156,7 +159,7 @@ class Transcriber:
|
|||||||
_model = load_model(model, download_root=download_root,
|
_model = load_model(model, download_root=download_root,
|
||||||
device=device, in_memory=in_memory)
|
device=device, in_memory=in_memory)
|
||||||
|
|
||||||
return cls(_model)
|
return cls(_model, model_name=model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_whisper_kwargs(**kwargs) -> dict:
|
def _get_whisper_kwargs(**kwargs) -> dict:
|
||||||
@@ -179,4 +182,4 @@ class Transcriber:
|
|||||||
return whisper_kwargs
|
return whisper_kwargs
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Transcriber(model={self.model})"
|
return f"Transcriber(model_name={self.model_name}, model={self.model})"
|
||||||
Reference in New Issue
Block a user