Merge pull request #77 from JSchmie/enhance_capability_with_webui
Enhance capability with webui
This commit is contained in:
@@ -95,7 +95,7 @@ class Scraibe:
|
||||
elif isinstance(dia_model, str):
|
||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||
else:
|
||||
self.diariser = dia_model
|
||||
self.diariser : Diariser = dia_model
|
||||
|
||||
if kwargs.get("verbose"):
|
||||
print("Scraibe initialized all models successfully loaded.")
|
||||
@@ -133,7 +133,7 @@ class Scraibe:
|
||||
if kwargs.get("verbose"):
|
||||
self.verbose = kwargs.get("verbose")
|
||||
# Get audio file as an AudioProcessor object
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
@@ -203,7 +203,7 @@ class Scraibe:
|
||||
"""
|
||||
|
||||
# Get audio file as an AudioProcessor object
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
||||
|
||||
# Prepare waveform and sample rate for diarization
|
||||
dia_audio = {
|
||||
@@ -232,9 +232,56 @@ class Scraibe:
|
||||
str:
|
||||
The transcribed text from the audio source.
|
||||
"""
|
||||
audio_file = self.get_audio_file(audio_file)
|
||||
audio_file : AudioProcessor = self.get_audio_file(audio_file)
|
||||
|
||||
return self.transcriber.transcribe(audio_file.waveform, **kwargs)
|
||||
|
||||
def update_transcriber(self, whisper_model : Union[str, whisper], **kwargs) -> None:
|
||||
"""
|
||||
Update the transcriber model.
|
||||
|
||||
Args:
|
||||
whisper_model (Union[str, whisper]):
|
||||
The new whisper model to use for transcription.
|
||||
**kwargs:
|
||||
Additional keyword arguments for the transcriber model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
_old_model = self.transcriber.model_name
|
||||
|
||||
if isinstance(whisper_model, str):
|
||||
self.transcriber = Transcriber.load_model(whisper_model, **kwargs)
|
||||
elif isinstance(whisper_model, Transcriber):
|
||||
self.transcriber = whisper_model
|
||||
else:
|
||||
warn(f"Invalid model type. Please provide a valid model. Fallback to old {_old_model} Model.", RuntimeWarning)
|
||||
|
||||
return None
|
||||
|
||||
def update_diariser(self, dia_model : Union[str, DiarisationType], **kwargs) -> None:
|
||||
"""
|
||||
Update the diariser model.
|
||||
|
||||
Args:
|
||||
dia_model (Union[str, DiarisationType]):
|
||||
The new diariser model to use for diarization.
|
||||
**kwargs:
|
||||
Additional keyword arguments for the diariser model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if isinstance(dia_model, str):
|
||||
self.diariser = Diariser.load_model(dia_model, **kwargs)
|
||||
elif isinstance(dia_model, Diariser):
|
||||
self.diariser = dia_model
|
||||
else:
|
||||
warn(f"Invalid model type. Please provide a valid model. Fallback to old Model.", RuntimeWarning)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def remove_audio_file(audio_file : str,
|
||||
shred : bool = False) -> None:
|
||||
@@ -269,7 +316,6 @@ class Scraibe:
|
||||
print(f"Audiofile {audio_file} removed.")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_audio_file(audio_file : Union[str, torch.Tensor, ndarray],
|
||||
*args, **kwargs) -> AudioProcessor:
|
||||
@@ -298,6 +344,7 @@ class Scraibe:
|
||||
if not isinstance(audio_file, AudioProcessor):
|
||||
raise ValueError(f'Audiofile must be of type AudioProcessor,' \
|
||||
f'not {type(audio_file)}')
|
||||
|
||||
return audio_file
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -64,14 +64,18 @@ class Transcriber:
|
||||
The class supports various sizes and versions of Whisper models. Please refer to
|
||||
the load_model method for available options.
|
||||
"""
|
||||
def __init__(self, model: whisper ) -> None:
|
||||
def __init__(self, model: whisper , model_name: str ) -> None:
|
||||
"""
|
||||
Initialize the Transcriber class with a Whisper model.
|
||||
|
||||
Args:
|
||||
model (whisper): The Whisper model to use for transcription.
|
||||
model_name (str): The name of the model.
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
def transcribe(self, audio : Union[str, Tensor, ndarray] ,
|
||||
*args, **kwargs) -> str:
|
||||
@@ -137,6 +141,7 @@ class Transcriber:
|
||||
- 'medium'
|
||||
- 'large-v1'
|
||||
- 'large-v2'
|
||||
- 'large-v3'
|
||||
- 'large'
|
||||
|
||||
download_root (str, optional): Path to download the model.
|
||||
@@ -156,7 +161,7 @@ class Transcriber:
|
||||
_model = load_model(model, download_root=download_root,
|
||||
device=device, in_memory=in_memory)
|
||||
|
||||
return cls(_model)
|
||||
return cls(_model, model_name=model)
|
||||
|
||||
@staticmethod
|
||||
def _get_whisper_kwargs(**kwargs) -> dict:
|
||||
@@ -179,4 +184,4 @@ class Transcriber:
|
||||
return whisper_kwargs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Transcriber(model={self.model})"
|
||||
return f"Transcriber(model_name={self.model_name}, model={self.model})"
|
||||
Reference in New Issue
Block a user