updated transcriber

This commit is contained in:
Schmieder, Jacob
2024-04-24 13:30:00 +00:00
parent 7171b02d64
commit 050555556d
+6 -3
View File
@@ -64,15 +64,18 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options.
"""
def __init__(self, model: whisper ) -> None:
def __init__(self, model: whisper , model_name: str ) -> None:
"""
Initialize the Transcriber class with a Whisper model.
Args:
model (whisper): The Whisper model to use for transcription.
"""
self.model = model
self.model_name = model_name
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
*args, **kwargs) -> str:
"""
@@ -156,7 +159,7 @@ class Transcriber:
_model = load_model(model, download_root=download_root,
device=device, in_memory=in_memory)
return cls(_model)
return cls(_model, model_name=model)
@staticmethod
def _get_whisper_kwargs(**kwargs) -> dict:
@@ -179,4 +182,4 @@ class Transcriber:
return whisper_kwargs
def __repr__(self) -> str:
return f"Transcriber(model={self.model})"
return f"Transcriber(model_name={self.model_name}, model={self.model})"