update annotation
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
from typing import Union
|
||||
|
||||
ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"]
|
||||
|
||||
|
||||
@@ -19,6 +22,7 @@ class Transcript:
|
||||
Keys should correspond to segment IDs, and values should
|
||||
contain speaker and segment information.
|
||||
"""
|
||||
|
||||
self.transcript = transcript
|
||||
self.speakers = self._extract_speakers()
|
||||
self.segments = self._extract_segments()
|
||||
@@ -33,7 +37,7 @@ class Transcript:
|
||||
kwargs (dict): Dictionary with speaker names as keys and list of segments as values.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with speaker names as keys and the corresponding annotation as values.
|
||||
dict: Dictionary with speaker names as keys and list of segments as values.
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of speaker names does not match the number
|
||||
@@ -45,7 +49,7 @@ class Transcript:
|
||||
raise ValueError("Number of speaker names does not match number of speakers")
|
||||
|
||||
if args:
|
||||
for arg, speaker in zip(args, self.speakers):
|
||||
for arg, speaker in zip(args, sorted(self.speakers)):
|
||||
annotations[speaker] = arg
|
||||
|
||||
invalid_speakers = set(kwargs.keys()) - set(self.speakers)
|
||||
@@ -55,7 +59,8 @@ class Transcript:
|
||||
annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs})
|
||||
|
||||
self.annotation = annotations
|
||||
return annotations
|
||||
|
||||
return self
|
||||
|
||||
def _extract_speakers(self) -> list:
|
||||
"""
|
||||
@@ -100,6 +105,7 @@ class Transcript:
|
||||
eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1]))
|
||||
|
||||
fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n"
|
||||
|
||||
return fstring
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -121,7 +127,7 @@ class Transcript:
|
||||
|
||||
return self.transcript
|
||||
|
||||
def get_json(self, *args, use_annotation: bool = True, **kwargs) -> str:
    """
    Get transcript as json string.

    The stored transcript is never modified: when annotations are applied,
    the speaker labels are replaced on a copy, so repeated calls give the
    same result.

    Args:
        *args: Positional arguments forwarded to ``json.dumps``.
        use_annotation (bool): If True and a non-empty ``annotation``
            mapping is set on the instance, each segment's ``"speakers"``
            value is replaced by its annotated name. Speakers without an
            annotation keep their original label. Defaults to True.
        **kwargs: Keyword arguments forwarded to ``json.dumps``;
            ``indent`` defaults to 3 when not given.

    Returns:
        str: transcript as json string
    """
    kwargs.setdefault("indent", 3)

    # The attribute may not exist if no annotation was ever assigned.
    annotation = getattr(self, "annotation", None)
    if not (use_annotation and annotation):
        return json.dumps(self.transcript, *args, **kwargs)

    # Relabel on a per-segment copy; writing into self.transcript would make
    # a second call look up already-annotated names and fail.
    relabelled = {
        seg_id: {
            **seg,
            "speakers": annotation.get(seg["speakers"], seg["speakers"]),
        }
        for seg_id, seg in self.transcript.items()
    }
    return json.dumps(relabelled, *args, **kwargs)
|
||||
|
||||
def get_html(self) -> str:
|
||||
@@ -264,5 +276,26 @@ class Transcript:
|
||||
self.to_pdf(path, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError("Unknown file format")
|
||||
|
||||
@classmethod
def from_json(cls, json: Union[dict, str]) -> "Transcript":
    """Load transcript from a dict, a json string or a json file

    Args:
        json (Union[dict, str]): Either an already parsed transcript
            dictionary, a json-encoded string, or the path to a json file.

    Returns:
        Transcript: Transcript object
    """
    # The parameter is named ``json`` (kept for callers using the keyword),
    # which hides the module of the same name inside this function, so the
    # module is imported under a local alias.
    import json as json_module

    if isinstance(json, dict):
        return cls(json)

    try:
        transcript = json_module.loads(json)
    except (TypeError, ValueError):
        # Not valid json text (JSONDecodeError is a ValueError) or not a
        # str/bytes object at all: treat the argument as a file path.
        with open(json, "r", encoding="utf-8") as f:
            transcript = json_module.load(f)

    return cls(transcript)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user