Support for outputting .srt format

This commit is contained in:
Christopher Oezbek 2025-07-04 14:24:02 +02:00
parent 6966635499
commit fb5395a523

View File

@ -32,6 +32,29 @@ class TimestampedText:
def __str__(self):
return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"
def _seconds_to_srt_time(seconds: float) -> str:
"""Converts seconds to SRT time format HH:MM:SS,mmm or HH:MM:SS.mmm."""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
seconds = seconds % 60
milliseconds = int((seconds - int(seconds)) * 1000)
return f"{hours:02}:{minutes:02}:{int(seconds):02},{milliseconds:03}"
@dataclasses.dataclass
class TimestampedTranscript:
segments: list[TimestampedText]
def __str__(self) -> str:
"""Formats the transcript into an SRT (SubRip) file content."""
srt_chunks = []
for i, segment in enumerate(self.segments, 1):
start_time = _seconds_to_srt_time(segment.timestamp[0])
end_time = _seconds_to_srt_time(segment.timestamp[1])
text = segment.text.strip()
srt_chunks.append(f"{i}\n{start_time} --> {end_time}\n{text}\n")
return "\n".join(srt_chunks)
def tokens_to_timestamped_text(
text_tokens,
@ -189,6 +212,10 @@ def main(args):
offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
)
if args.srt:
transcript = TimestampedTranscript(segments=timed_text)
print(transcript)
else:
decoded = " ".join([str(t) for t in timed_text])
print(decoded)
@ -216,6 +243,11 @@ if __name__ == "__main__":
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
parser.add_argument(
"--srt",
action="store_true",
help="Prints the transcript in SRT format.",
)
args = parser.parse_args()
main(args)