diff --git a/autotranscript/transcript_exporter.py b/autotranscript/transcript_exporter.py index c6bfa5c..999383d 100644 --- a/autotranscript/transcript_exporter.py +++ b/autotranscript/transcript_exporter.py @@ -1,6 +1,9 @@ import json import time + +from typing import Union + ALPHABET = [*"abcdefghijklmnopqrstuvwxyz"] @@ -19,6 +22,7 @@ class Transcript: Keys should correspond to segment IDs, and values should contain speaker and segment information. """ + self.transcript = transcript self.speakers = self._extract_speakers() self.segments = self._extract_segments() @@ -33,7 +37,7 @@ class Transcript: kwargs (dict): Dictionary with speaker names as keys and list of segments as values. Returns: - dict: Dictionary with speaker names as keys and the corresponding annotation as values. + dict: Dictionary with speaker names as keys and list of segments as values. Raises: ValueError: If the number of speaker names does not match the number @@ -45,7 +49,7 @@ class Transcript: raise ValueError("Number of speaker names does not match number of speakers") if args: - for arg, speaker in zip(args, self.speakers): + for arg, speaker in zip(args, sorted(self.speakers)): annotations[speaker] = arg invalid_speakers = set(kwargs.keys()) - set(self.speakers) @@ -55,7 +59,8 @@ class Transcript: annotations.update({key: kwargs[key] for key in self.speakers if key in kwargs}) self.annotation = annotations - return annotations + + return self def _extract_speakers(self) -> list: """ @@ -100,6 +105,7 @@ class Transcript: eseg = time.strftime("%H:%M:%S",time.gmtime(segm[1])) fstring += f"{speaker} ({sseg} ; {eseg}):\t{seq['text']}\n" + return fstring def __repr__(self) -> str: @@ -121,7 +127,7 @@ class Transcript: return self.transcript - def get_json(self, *args, **kwargs) -> str: + def get_json(self, *args, use_annotation : bool = True, **kwargs) -> str: """ Get transcript as json string :return: transcript as json string @@ -129,6 +135,12 @@ class Transcript: """ if "indent" not in kwargs: kwargs["indent"] = 3 + + if use_annotation and self.annotation: + for _id in self.transcript: + seq = self.transcript[_id] + seq["speakers"] = self.annotation[seq["speakers"]] + return json.dumps(self.transcript, *args, **kwargs) def get_html(self) -> str: @@ -264,5 +276,26 @@ class Transcript: self.to_pdf(path, *args, **kwargs) else: raise ValueError("Unknown file format") + + @classmethod + def from_json(cls, json: Union[dict, str]) -> "Transcript": + """Load transcript from json file + + Args: + path (str): path to json file + + Returns: + Transcript: Transcript object + """ + if isinstance(json, dict): + return cls(json) + else: + try: + transcript = json.loads(json) + except: + with open(json, "r") as f: + transcript = json.load(f) + + return cls(transcript) \ No newline at end of file