Files
scribe/autotranscript/transcriber.py
T
2023-06-12 11:48:47 +02:00

113 lines
3.0 KiB
Python

import os
from typing import TypeVar
from whisper import load_model
from glob import glob
# Placeholder type variables used only as annotation labels in this module.
# ``whisper`` labels a loaded whisper model object (see ``Transcriber.__init__``).
# ``Transcriber`` is rebound by the ``class Transcriber`` statement further
# down, so only annotations evaluated before that binding completes (i.e.
# inside the class body) resolve to this TypeVar.
# NOTE(review): a TypeVar is not a type alias; ``typing.Any`` or a string
# forward reference would express the intent more accurately.
whisper, Transcriber = TypeVar('whisper'), TypeVar('Transcriber')
def get_whisper_default_path() -> str:
    """
    Build the default directory for whisper model files.

    Starting from this module's file, go up two directory levels (the
    directory containing this file, then its parent) and append
    ``models/whisper``.

    Returns
    -------
    str
        path to the default whisper model directory
    """
    base_dir = __file__
    for _ in range(2):
        # each pass strips one trailing path component
        base_dir = os.path.dirname(base_dir)
    return os.path.join(base_dir, "models", "whisper")


# Computed once at import time; used as the default ``download_root``.
WHISPER_DEFAULT_PATH = get_whisper_default_path()
class Transcriber:
    """Wrap a loaded whisper model to transcribe audio files and save the text."""

    def __init__(self, model: whisper) -> None:
        """
        Initialize Transcriber class with a whisper model.

        :param model: loaded whisper model; it must provide a ``transcribe``
            method, which :meth:`transcribe` calls
        """
        self.model = model

    def transcribe(self, file: str, language: str = "German") -> str:
        """
        Transcribe an audio file.

        :param file: path of the audio file to transcribe
        :param language: spoken language of the audio; passed unchanged to
            the model's ``transcribe`` call
        :return: the ``"text"`` entry of the model's transcription result
        """
        result = self.model.transcribe(file, language=language)
        return result["text"]

    @staticmethod
    def save_transcript(transcript: str, save_path: str) -> None:
        """
        Save a transcript to a text file, overwriting any existing file.

        :param transcript: transcript as string
        :param save_path: path of the file to write
        :return: None
        """
        # Explicit UTF-8 so non-ASCII transcripts (umlauts, other scripts)
        # do not depend on the platform's locale encoding. The ``with``
        # block closes the file, so no explicit close() is needed.
        with open(save_path, 'w', encoding="utf-8") as f:
            f.write(transcript)
        print(f'Transcript saved to {save_path}')

    @classmethod
    def load_whisper_model(cls,
                           model: str = "medium",
                           local: bool = True,
                           download_root: str = WHISPER_DEFAULT_PATH) -> "Transcriber":
        """
        Load a whisper model and wrap it in a Transcriber.

        Parameters
        ----------
        model : str
            whisper model name
            available models:
            - 'tiny.en'
            - 'tiny'
            - 'base.en'
            - 'base'
            - 'small.en'
            - 'small'
            - 'medium.en'
            - 'medium'
            - 'large-v1'
            - 'large-v2'
            - 'large'
        local : bool
            If true, require the model file to already exist in
            ``download_root`` before loading it
        download_root : str
            Directory the model is loaded from / downloaded to
            default: ``WHISPER_DEFAULT_PATH``

        Returns
        -------
        Transcriber
            a new instance wrapping the loaded model

        Raises
        ------
        RuntimeError
            if ``local`` is true and no file for ``model`` is found in
            ``download_root``
        """
        if local:
            # os.listdir (rather than glob) so glob metacharacters such as
            # '[' in download_root are not interpreted; a missing directory
            # simply means no local models.
            entries = os.listdir(download_root) if os.path.isdir(download_root) else []
            # Strip only the final extension (presumably ".pt") so model
            # names that themselves contain dots, e.g. "tiny.en", still match.
            available_models = {os.path.splitext(entry)[0] for entry in entries}
            # NOTE(review): if whisper stores an alias such as "large" under a
            # versioned filename, the alias will not match here — confirm.
            if model not in available_models:
                raise RuntimeError("Model not found. Consider downloading the "
                                   "model first. By deactivating the local flag, "
                                   "the model will be downloaded automatically.")
        _model = load_model(model, download_root=download_root)
        return cls(_model)