better readability
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import TypeVar
|
from typing import TypeVar , Union
|
||||||
from whisper import load_model
|
from whisper import load_model
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
|
||||||
whisper = TypeVar('whisper')
|
whisper = TypeVar('whisper')
|
||||||
|
Tensor = TypeVar('Tensor')
|
||||||
|
nparray = TypeVar('nparray')
|
||||||
Transcriber = TypeVar('Transcriber')
|
Transcriber = TypeVar('Transcriber')
|
||||||
|
|
||||||
def get_whisper_default_path() -> str:
|
def get_whisper_default_path() -> str:
|
||||||
@@ -29,20 +31,24 @@ class Transcriber:
|
|||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
|
def transcribe(self, audio : Union[str, Tensor, nparray] ,
|
||||||
def transcribe(self, file : str, language:str = "German"):
|
*args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
transcribe audio file
|
transcribe audio file
|
||||||
:param file: audio file to transcribe
|
:param file: audio file to transcribe
|
||||||
:param language: language of the audio file
|
:param args: additional arguments
|
||||||
|
:param kwargs: additional keyword arguments
|
||||||
|
example:
|
||||||
|
- language: language of the audio file
|
||||||
:return: transcript as string
|
:return: transcript as string
|
||||||
"""
|
"""
|
||||||
result = self.model.transcribe(file, language = language)
|
|
||||||
|
result = self.model.transcribe(audio, *args, **kwargs)
|
||||||
|
|
||||||
return result["text"]
|
return result["text"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_transcript(transcript:str , save_path : str) -> None:
|
def save_transcript(transcript : str , save_path : str) -> None:
|
||||||
"""
|
"""
|
||||||
Save transcript to file
|
Save transcript to file
|
||||||
:param transcript: transcript as string
|
:param transcript: transcript as string
|
||||||
@@ -57,10 +63,10 @@ class Transcriber:
|
|||||||
print(f'Transcript saved to {save_path}')
|
print(f'Transcript saved to {save_path}')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_whisper_model(cls,
|
def load_model(cls,
|
||||||
model: str = "medium",
|
model: str = "medium",
|
||||||
local : bool = True,
|
local : bool = True,
|
||||||
download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber:
|
download_root: str = WHISPER_DEFAULT_PATH) -> Transcriber:
|
||||||
"""
|
"""
|
||||||
Load whisper module
|
Load whisper module
|
||||||
|
|
||||||
@@ -97,7 +103,8 @@ class Transcriber:
|
|||||||
|
|
||||||
if local:
|
if local:
|
||||||
|
|
||||||
available_models = [os.path.basename(x) for x in glob(os.path.join(download_root, "*"))]
|
available_models = [os.path.basename(x) for x in
|
||||||
|
glob(os.path.join(download_root, "*"))]
|
||||||
|
|
||||||
for i, module in enumerate(available_models):
|
for i, module in enumerate(available_models):
|
||||||
available_models[i] = module.split(".")[0]
|
available_models[i] = module.split(".")[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user