better readability

This commit is contained in:
Jaikinator
2023-06-12 15:56:52 +02:00
parent b5dab23dd4
commit 6870d03f6b
+18 -11
View File
@@ -1,10 +1,12 @@
import os import os
from typing import TypeVar from typing import TypeVar , Union
from whisper import load_model from whisper import load_model
from glob import glob from glob import glob
whisper = TypeVar('whisper') whisper = TypeVar('whisper')
Tensor = TypeVar('Tensor')
nparray = TypeVar('nparray')
Transcriber = TypeVar('Transcriber') Transcriber = TypeVar('Transcriber')
def get_whisper_default_path() -> str: def get_whisper_default_path() -> str:
@@ -29,20 +31,24 @@ class Transcriber:
""" """
self.model = model self.model = model
def transcribe(self, audio : Union[str, Tensor, nparray] ,
def transcribe(self, file : str, language:str = "German"): *args, **kwargs) -> str:
""" """
transcribe audio file transcribe audio file
:param file: audio file to transcribe :param file: audio file to transcribe
:param language: language of the audio file :param args: additional arguments
:param kwargs: additional keyword arguments
example:
- language: language of the audio file
:return: transcript as string :return: transcript as string
""" """
result = self.model.transcribe(file, language = language)
result = self.model.transcribe(audio, *args, **kwargs)
return result["text"] return result["text"]
@staticmethod @staticmethod
def save_transcript(transcript:str , save_path : str) -> None: def save_transcript(transcript : str , save_path : str) -> None:
""" """
Save transcript to file Save transcript to file
:param transcript: transcript as string :param transcript: transcript as string
@@ -57,10 +63,10 @@ class Transcriber:
print(f'Transcript saved to {save_path}') print(f'Transcript saved to {save_path}')
@classmethod @classmethod
def load_whisper_model(cls, def load_model(cls,
model: str = "medium", model: str = "medium",
local : bool = True, local : bool = True,
download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber: download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber:
""" """
Load whisper module Load whisper module
@@ -97,7 +103,8 @@ class Transcriber:
if local: if local:
available_models = [os.path.basename(x) for x in glob(os.path.join(download_root, "*"))] available_models = [os.path.basename(x) for x in
glob(os.path.join(download_root, "*"))]
for i, module in enumerate(available_models): for i, module in enumerate(available_models):
available_models[i] = module.split(".")[0] available_models[i] = module.split(".")[0]