updated transcriber
This commit is contained in:
@@ -64,15 +64,18 @@ class Transcriber:
|
||||
The class supports various sizes and versions of Whisper models. Please refer to
|
||||
the load_model method for available options.
|
||||
"""
|
||||
def __init__(self, model: whisper ) -> None:
|
||||
def __init__(self, model: whisper , model_name: str ) -> None:
|
||||
"""
|
||||
Initialize the Transcriber class with a Whisper model.
|
||||
|
||||
Args:
|
||||
model (whisper): The Whisper model to use for transcription.
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
|
||||
*args, **kwargs) -> str:
|
||||
"""
|
||||
@@ -156,7 +159,7 @@ class Transcriber:
|
||||
_model = load_model(model, download_root=download_root,
|
||||
device=device, in_memory=in_memory)
|
||||
|
||||
return cls(_model)
|
||||
return cls(_model, model_name=model)
|
||||
|
||||
@staticmethod
|
||||
def _get_whisper_kwargs(**kwargs) -> dict:
|
||||
@@ -179,4 +182,4 @@ class Transcriber:
|
||||
return whisper_kwargs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Transcriber(model={self.model})"
|
||||
return f"Transcriber(model_name={self.model_name}, model={self.model})"
|
||||
Reference in New Issue
Block a user