Auto fixes from PEP8, fixes from flake8.

This commit is contained in:
Marko Henning
2024-05-15 15:18:17 +02:00
parent 9f526a8f3b
commit 4bcd28d0ea
15 changed files with 391 additions and 417 deletions
+11 -6
View File
@@ -14,8 +14,9 @@ if CACHE_DIR != PYANNOTE_CACHE_DIR:
WHISPER_DEFAULT_PATH = os.path.join(CACHE_DIR, "whisper")
PYANNOTE_DEFAULT_PATH = os.path.join(CACHE_DIR, "pyannote")
PYANNOTE_DEFAULT_CONFIG = os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml") \
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
else ('jaikinator/scraibe', 'pyannote/speaker-diarization-3.1')
def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
"""Configure diarization pipeline from a YAML file.
@@ -33,25 +34,29 @@ def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) ->
with open(file_path, "r") as stream:
yml = yaml.safe_load(stream)
segmentation_path = path_to_segmentation or os.path.join(PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
segmentation_path = path_to_segmentation or os.path.join(
PYANNOTE_DEFAULT_PATH, "pytorch_model.bin")
yml["pipeline"]["params"]["segmentation"] = segmentation_path
if not os.path.exists(segmentation_path):
raise FileNotFoundError(f"Segmentation model not found at {segmentation_path}")
raise FileNotFoundError(
f"Segmentation model not found at {segmentation_path}")
with open(file_path, "w") as stream:
yaml.dump(yml, stream)
class ParseKwargs(Action):
"""
Custom argparse action to parse keyword arguments.
"""
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, dict())
for value in values:
key, value = value.split('=')
try:
value = eval(value)
value = ast.literal_eval(value)
except:
pass
getattr(namespace, self.dest)[key] = value
getattr(namespace, self.dest)[key] = value