update annotation

This commit is contained in:
Jaikinator
2023-09-11 15:57:19 +02:00
parent 1750e551f6
commit db843f9e99
+37 -4
View File
@@ -1,6 +1,9 @@
import json
import time
from typing import Union
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
@@ -19,6 +22,7 @@ class Transcript:
Keys should correspond to segment IDs, and values should
contain speaker and segment information.
"""
self.transcript = transcript
self.speakers = self._extract_speakers()
self.segments = self._extract_segments()
@@ -33,7 +37,7 @@ class Transcript:
kwargs (dict): Dictionary with speaker names as keys and list of segments as values.
Returns:
dict: Dictionary with speaker names as keys and the corresponding annotation as values.
dict: Dictionary with speaker names as keys and list of segments as values.
Raises:
ValueError: If the number of speaker names does not match the number
@@ -45,7 +49,7 @@ class Transcript:
raise ValueError("Number of speaker names does not match number of speakers")
if args:
for arg, speaker in zip(args, self.speakers):
for arg, speaker in zip(args, sorted(self.speakers)):
annotations[speaker] = arg
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
@@ -55,7 +59,8 @@ class Transcript:
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
self.annotation = annotations
return annotations
return self
def _extract_speakers(self) -> list:
"""
@@ -100,6 +105,7 @@ class Transcript:
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
return fstring
def __repr__(self) -> str:
@@ -121,7 +127,7 @@ class Transcript:
return self.transcript
def get_json(self, *args, **kwargs) -> str:
def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str:
    """
    Get transcript as json string

    Extra positional and keyword arguments are forwarded to ``json.dumps``;
    ``indent`` defaults to 3 when the caller does not supply it.

    :param use_annotation: if True and an annotation is set, replace each
        segment's ``"speakers"`` value with its annotation in the output.
        Speakers without an annotation entry keep their original value.
    :return: transcript as json string
    """
    kwargs.setdefault("indent", 3)

    transcript = self.transcript
    # Tolerate instances on which no annotation has been stored yet.
    annotation = getattr(self, "annotation", None)
    if use_annotation and annotation:
        # Build an annotated copy instead of writing into self.transcript:
        # mutating the stored segments would corrupt the transcript and make
        # a second call look up already-replaced speaker labels.
        transcript = {
            _id: {
                **seq,
                "speakers": annotation.get(seq["speakers"], seq["speakers"]),
            }
            for _id, seq in self.transcript.items()
        }

    return json.dumps(transcript, *args, **kwargs)
def get_html(self) -> str:
@@ -264,5 +276,26 @@ class Transcript:
self.to_pdf(path, *args, **kwargs)
else:
raise ValueError("Unknown file format")
@classmethod
def from_json(cls, json: Union[dict, str]) -> "Transcript":
    """Load a transcript from a dict, a JSON string, or a JSON file.

    Args:
        json (Union[dict, str]): Either an already parsed transcript
            dictionary, a JSON document as a string, or the path to a
            JSON file.

    Returns:
        Transcript: Transcript object

    Raises:
        OSError: If ``json`` is not a parseable JSON document and cannot
            be opened as a file either.
        ValueError: If the file exists but does not contain valid JSON.
    """
    # The parameter name ``json`` is part of the public signature, but it
    # hides the ``json`` module inside this function. Import the module
    # under a private alias so the parser is actually reachable.
    import json as _json

    if isinstance(json, dict):
        return cls(json)

    try:
        transcript = _json.loads(json)
    except (TypeError, ValueError):
        # Not a JSON document (JSONDecodeError is a ValueError) or not a
        # str/bytes value such as a path object: treat it as a file path.
        with open(json, "r", encoding="utf-8") as f:
            transcript = _json.load(f)

    return cls(transcript)