Added functionality to select the diarisation model via the CLI
This commit is contained in:
@@ -39,7 +39,6 @@ class AutoTranscribe:
|
|||||||
|
|
||||||
if whisper_model is None:
|
if whisper_model is None:
|
||||||
self.transcriber = Transcriber.load_model("medium", local=True)
|
self.transcriber = Transcriber.load_model("medium", local=True)
|
||||||
|
|
||||||
elif isinstance(whisper_model, str):
|
elif isinstance(whisper_model, str):
|
||||||
self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs)
|
self.transcriber = Transcriber.load_model(whisper_model, **whisper_kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -170,6 +169,7 @@ def cli():
|
|||||||
from whisper.utils import get_writer
|
from whisper.utils import get_writer
|
||||||
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
from whisper.tokenizer import LANGUAGES , TO_LANGUAGE_CODE
|
||||||
from .transcriber import WHISPER_DEFAULT_PATH
|
from .transcriber import WHISPER_DEFAULT_PATH
|
||||||
|
from .diarisation import PYANNOTE_DEFAULT_PATH
|
||||||
def str2bool(string):
|
def str2bool(string):
|
||||||
str2val = {"True": True, "False": False}
|
str2val = {"True": True, "False": False}
|
||||||
if string in str2val:
|
if string in str2val:
|
||||||
@@ -190,6 +190,10 @@ def cli():
|
|||||||
parser.add_argument("--wmodel_dir", type=str, default= WHISPER_DEFAULT_PATH,
|
parser.add_argument("--wmodel_dir", type=str, default= WHISPER_DEFAULT_PATH,
|
||||||
help="the path to save model files; uses ./models/whisper by default")
|
help="the path to save model files; uses ./models/whisper by default")
|
||||||
|
|
||||||
|
parser.add_argument("--dia_model", type=str, default = PYANNOTE_DEFAULT_PATH)
|
||||||
|
|
||||||
|
parser.add_argument("--allow_download", type= bool, default=True,
|
||||||
|
help="whether to allow model download if model is not found locally")
|
||||||
parser.add_argument("--device",
|
parser.add_argument("--device",
|
||||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
help="device to use for PyTorch inference")
|
help="device to use for PyTorch inference")
|
||||||
@@ -219,6 +223,7 @@ def cli():
|
|||||||
model_dir: str = args.pop("wmodel_dir")
|
model_dir: str = args.pop("wmodel_dir")
|
||||||
output_dir: str = args.pop("output_dir")
|
output_dir: str = args.pop("output_dir")
|
||||||
output_format: str = args.pop("output_format")
|
output_format: str = args.pop("output_format")
|
||||||
|
local :str = args.pop("allow_download")
|
||||||
task = args.pop("task")
|
task = args.pop("task")
|
||||||
device: str = args.pop("device")
|
device: str = args.pop("device")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
@@ -227,14 +232,17 @@ def cli():
|
|||||||
torch.set_num_threads(threads)
|
torch.set_num_threads(threads)
|
||||||
|
|
||||||
wkwargs = {"download_root": model_dir,
|
wkwargs = {"download_root": model_dir,
|
||||||
"device": device,
|
"local": local,
|
||||||
"language" : args.pop("language")}
|
"device": device}
|
||||||
|
diarisation_kwargs = {"local": local}
|
||||||
model = AutoTranscribe(whisper_model= model_name, whisper_kwargs= wkwargs)
|
model = AutoTranscribe(whisper_model= model_name,
|
||||||
|
whisper_kwargs= wkwargs,
|
||||||
|
dia_model= args.pop("dia_model"),
|
||||||
|
dia_kwargs_kwargs= diarisation_kwargs,)
|
||||||
|
|
||||||
if task == "transcribe":
|
if task == "transcribe":
|
||||||
for audio in args.pop("audio"):
|
for audio in args.pop("audio"):
|
||||||
out = model.transcribe(audio)
|
out = model.transcribe(audio, language = args.pop("language"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
spath = f"{output_dir}/{basename}.{output_format}"
|
spath = f"{output_dir}/{basename}.{output_format}"
|
||||||
out.save(spath)
|
out.save(spath)
|
||||||
@@ -257,7 +265,7 @@ def cli():
|
|||||||
"It is recommendet to use the whisper cli directly",
|
"It is recommendet to use the whisper cli directly",
|
||||||
RuntimeWarning)
|
RuntimeWarning)
|
||||||
for audio in args.pop("audio"):
|
for audio in args.pop("audio"):
|
||||||
out = model.transcriber.transcribe(audio, diarisation=True)
|
out = model.transcriber.transcribe(audio, language = args.pop("language"))
|
||||||
basename = audio.split("/")[-1].split(".")[0]
|
basename = audio.split("/")[-1].split(".")[0]
|
||||||
writer(out, audio)
|
writer(out, audio)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user