From de9c81b3136652012cf92b345fe6b9621a670798 Mon Sep 17 00:00:00 2001 From: "Schmieder, Jacob" Date: Tue, 10 Sep 2024 09:01:59 +0000 Subject: [PATCH] added language to code support for faster whisper --- scraibe/transcriber.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/scraibe/transcriber.py b/scraibe/transcriber.py index cea7274..abf1ace 100644 --- a/scraibe/transcriber.py +++ b/scraibe/transcriber.py @@ -26,7 +26,9 @@ Usage: from whisper import Whisper from whisper import load_model as whisper_load_model +from whisper.tokenizer import TO_LANGUAGE_CODE from faster_whisper import WhisperModel as FasterWhisperModel +from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES from typing import TypeVar, Union, Optional from torch import Tensor, device from torch.cuda import is_available as cuda_is_available @@ -369,14 +371,44 @@ class FasterWhisperTranscriber(Transcriber): whisper_kwargs["task"] = task if (language := kwargs.get("language")): + language = FasterWhisperTranscriber.convert_to_language_code(language) whisper_kwargs["language"] = language return whisper_kwargs + @staticmethod + def convert_to_language_code(lang : str) -> str: + """ + Load whisper model. + + Args: + lang (str): language as code or language name + + Returns: + language (str) code of language + """ + + # If the input is already in FASTER_WHISPER_LANGUAGE_CODES, return it directly + if lang in FASTER_WHISPER_LANGUAGE_CODES: + return lang + + # Normalize the input to lowercase + lang = lang.lower() + + # Check if the language name is in the TO_LANGUAGE_CODE mapping + if lang in TO_LANGUAGE_CODE: + return TO_LANGUAGE_CODE[lang] + + # If the language is not recognized, raise a ValueError with the available options + available_codes = ', '.join(FASTER_WHISPER_LANGUAGE_CODES) + raise ValueError(f"Language '{lang}' is not a valid language code or name. " + f"Available language codes are: {available_codes}.") + def __repr__(self) -> str: return f"FasterWhisperTranscriber(model_name={self.model_name}, model={self.model})" + def load_transcriber(model: str = "medium", whisper_type: str = 'whisper', download_root: str = WHISPER_DEFAULT_PATH,