Apply automatic PEP 8 formatting fixes and resolve flake8 warnings.

This commit is contained in:
Marko Henning
2024-05-15 15:18:17 +02:00
parent 9f526a8f3b
commit 4bcd28d0ea
15 changed files with 391 additions and 417 deletions
+22 -20
View File
@@ -28,11 +28,11 @@ from whisper import Whisper
from whisper import load_model as whisper_load_model
from whisperx.asr import WhisperModel
from whisperx import load_model as whisperx_load_model
from typing import TypeVar , Union , Optional
from typing import TypeVar, Union, Optional
from torch import Tensor, device
from numpy import ndarray
from inspect import getfullargspec
from abc import ABC, abstractmethod
from abc import abstractmethod
from .misc import WHISPER_DEFAULT_PATH
whisper = TypeVar('whisper')
@@ -66,6 +66,7 @@ class Transcriber:
The class supports various sizes and versions of Whisper models. Please refer to
the load_model method for available options.
"""
def __init__(self, model: whisper, model_name: str) -> None:
"""
Initialize the Transcriber class with a Whisper model.
@@ -74,13 +75,13 @@ class Transcriber:
model (whisper): The Whisper model to use for transcription.
model_name (str): The name of the model.
"""
self.model = model
self.model_name = model_name
@abstractmethod
def transcribe(self, audio: Union[str, Tensor, ndarray] ,
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
Transcribe an audio file.
@@ -95,9 +96,9 @@ class Transcriber:
str: The transcript as a string.
"""
pass
@staticmethod
def save_transcript(transcript : str , save_path : str) -> None:
def save_transcript(transcript: str, save_path: str) -> None:
"""
Save a transcript to a file.
@@ -111,7 +112,7 @@ class Transcriber:
with open(save_path, 'w') as f:
f.write(transcript)
print(f'Transcript saved to {save_path}')
@classmethod
@@ -176,10 +177,10 @@ class Transcriber:
dict: Keyword arguments for whisper model.
"""
pass
def __repr__(self) -> str:
    """Return a string showing this transcriber's model name and model."""
    # The class name in the text is hard-coded, so a subclass that does not
    # override this method is also rendered as "Transcriber(...)".
    return f"Transcriber(model_name={self.model_name}, model={self.model})"
class WhisperTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
@@ -233,10 +234,10 @@ class WhisperTranscriber(Transcriber):
- 'large-v2'
- 'large-v3'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
@@ -266,7 +267,8 @@ class WhisperTranscriber(Transcriber):
_kwargs = getfullargspec(Whisper.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")):
whisper_kwargs["task"] = task
@@ -280,7 +282,7 @@ class WhisperTranscriber(Transcriber):
class WhisperXTranscriber(Transcriber):
def __init__(self, model: whisper, model_name: str) -> None:
    """
    Initialize the transcriber by delegating to the parent initializer.

    Args:
        model (whisper): The model object, passed through unchanged.
        model_name (str): The name of the model, passed through unchanged.
    """
    super().__init__(model, model_name)
def transcribe(self, audio: Union[str, Tensor, ndarray],
*args, **kwargs) -> str:
"""
@@ -296,7 +298,7 @@ class WhisperXTranscriber(Transcriber):
str: The transcript as a string.
"""
kwargs = self._get_whisper_kwargs(**kwargs)
if isinstance(audio, Tensor):
audio = audio.cpu().numpy()
result = self.model.transcribe(audio, *args, **kwargs)
@@ -304,8 +306,7 @@ class WhisperXTranscriber(Transcriber):
for seg in result['segments']:
text += seg['text']
return text
@classmethod
def load_model(cls,
model: str = "medium",
@@ -330,10 +331,10 @@ class WhisperXTranscriber(Transcriber):
- 'large-v2'
- 'large-v3'
- 'large'
download_root (str, optional): Path to download the model.
Defaults to WHISPER_DEFAULT_PATH.
device (Optional[Union[str, torch.device]], optional):
Device to load model on. Defaults to None.
in_memory (bool, optional): Whether to load model in memory.
@@ -364,7 +365,8 @@ class WhisperXTranscriber(Transcriber):
_kwargs = getfullargspec(WhisperModel.transcribe).kwonlyargs
_possible_kwargs = _args + _kwargs
whisper_kwargs = {k: v for k, v in kwargs.items() if k in _possible_kwargs}
whisper_kwargs = {k: v for k,
v in kwargs.items() if k in _possible_kwargs}
if (task := kwargs.get("task")):
whisper_kwargs["task"] = task