update annotation

This commit is contained in:
Jaikinator
2023-09-11 15:57:19 +02:00
parent 1750e551f6
commit db843f9e99
+37 -4
View File
@@ -1,6 +1,9 @@
import json import json
import time import time
from typing import Union
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
@@ -19,6 +22,7 @@ class Transcript:
Keys should correspond to segment IDs, and values should Keys should correspond to segment IDs, and values should
contain speaker and segment information. contain speaker and segment information.
""" """
self.transcript = transcript self.transcript = transcript
self.speakers = self._extract_speakers() self.speakers = self._extract_speakers()
self.segments = self._extract_segments() self.segments = self._extract_segments()
@@ -33,7 +37,7 @@ class Transcript:
kwargs (dict): Dictionary with speaker names as keys and list of segments as values. kwargs (dict): Dictionary with speaker names as keys and list of segments as values.
Returns: Returns:
dict: Dictionary with speaker names as keys and the corresponding annotation as values. dict: Dictionary with speaker names as keys and list of segments as values.
Raises: Raises:
ValueError: If the number of speaker names does not match the number ValueError: If the number of speaker names does not match the number
@@ -45,7 +49,7 @@ class Transcript:
raise ValueError("Number of speaker names does not match number of speakers") raise ValueError("Number of speaker names does not match number of speakers")
if args: if args:
for arg, speaker in zip(args, self.speakers): for arg, speaker in zip(args, sorted(self.speakers)):
annotations[speaker] = arg annotations[speaker] = arg
invalid_speakers = set(kwargs.keys()) - set(self.speakers) invalid_speakers = set(kwargs.keys()) - set(self.speakers)
@@ -55,7 +59,8 @@ class Transcript:
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
self.annotation = annotations self.annotation = annotations
return annotations
return self
def _extract_speakers(self) -> list: def _extract_speakers(self) -> list:
""" """
@@ -100,6 +105,7 @@ class Transcript:
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n" fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
return fstring return fstring
def __repr__(self) -> str: def __repr__(self) -> str:
@@ -121,7 +127,7 @@ class Transcript:
return self.transcript return self.transcript
def get_json(self, *args, **kwargs) -> str: def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str:
""" """
Get transcript as json string Get transcript as json string
:return: transcript as json string :return: transcript as json string
@@ -129,6 +135,12 @@ class Transcript:
""" """
if "indent" not in kwargs: if "indent" not in kwargs:
kwargs["indent"] = 3 kwargs["indent"] = 3
if use_annotation and self.annotation:
for _id in self.transcript:
seq = self.transcript[_id]
seq["speakers"] = self.annotation[seq["speakers"]]
return json.dumps(self.transcript, *args, **kwargs) return json.dumps(self.transcript, *args, **kwargs)
def get_html(self) -> str: def get_html(self) -> str:
@@ -265,4 +277,25 @@ class Transcript:
else: else:
raise ValueError("Unknown file format") raise ValueError("Unknown file format")
@classmethod
def from_json(cls, json: Union[dict, str]) -> "Transcript":
    """Build a transcript from a dict, a JSON string, or a JSON file path.

    Args:
        json (dict | str): Either an already-parsed transcript dictionary,
            a JSON document given as a string, or the path of a file
            containing a JSON document.

    Returns:
        Transcript: Object constructed from the parsed transcript data.

    Raises:
        OSError: If the string is not valid JSON and cannot be opened
            as a file either.
        ValueError: If the file was opened but its content is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    # The parameter is named ``json`` (kept for caller compatibility), which
    # hides the standard-library module inside this function. Import the
    # module under an alias so the parser is actually reachable.
    import json as _json

    if isinstance(json, dict):
        return cls(json)

    try:
        # First interpret the string as a JSON document.
        transcript = _json.loads(json)
    except ValueError:
        # Not parseable as JSON text: treat the string as a file path.
        # JSON files are UTF-8 by specification, so do not rely on the
        # platform's default encoding.
        with open(json, "r", encoding="utf-8") as f:
            transcript = _json.load(f)
    return cls(transcript)