# kyutai/scripts/stt_from_file_pytorch.py

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "julius",
#     "librosa",
#     "soundfile",
#     "moshi",
# ]
# ///
"""An example script that illustrates how one can get per-word timestamps from
Kyutai STT models.
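
Example invocation (the audio path is a placeholder; any Kyutai STT
checkpoint on Hugging Face should work, e.g. kyutai/stt-2.6b-en):

    uv run stt_from_file_pytorch.py audio.wav --hf-repo kyutai/stt-2.6b-en

The PEP 723 header above lets `uv run` fetch the dependencies automatically.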
"""
import argparse
import dataclasses
import itertools
import math
import julius
import moshi.models
import sphn
import torch
import tqdm


@dataclasses.dataclass
class TimestampedText:
    text: str
    timestamp: tuple[float, float]

    def __str__(self):
        return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"


def tokens_to_timestamped_text(
text_tokens,
tokenizer,
frame_rate,
end_of_padding_id,
padding_token_id,
offset_seconds,
) -> list[TimestampedText]:
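    """Group text tokens into words with (start, end) timestamps in seconds.

    Frame indices are converted to seconds using `frame_rate`, and
    `offset_seconds` removes the silence prefix and audio delay.
    """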
text_tokens = text_tokens.cpu().view(-1)
    # Normally `end_of_padding` tokens indicate word boundaries.
    # Everything between them should be a single word;
    # the time offsets of those tokens correspond to word start and
    # end timestamps (minus the silence prefix and audio delay).
    #
    # However, in rare cases some complexities arise. Firstly,
    # for words that are said quickly but are represented with
    # multiple tokens, the boundary might be omitted. Secondly,
    # the end boundary of the very last word might be missing.
    # The code below handles those situations a bit more carefully.
sequence_timestamps = []
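
    # Convert frame positions to seconds, clamping at zero after removing
    # the offset (silence prefix + audio delay).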
def _tstmp(start_position, end_position):
return (
max(0, start_position / frame_rate - offset_seconds),
max(0, end_position / frame_rate - offset_seconds),
)
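
    # Decode token ids to text, dropping padding/special ids first.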
def _decode(t):
t = t[t > padding_token_id]
return tokenizer.decode(t.numpy().tolist())
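
    # Decode the tokens in [start, end) and append one TimestampedText
    # per word found inside the segment.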
def _decode_segment(start, end):
nonlocal text_tokens
nonlocal sequence_timestamps
text = _decode(text_tokens[start:end])
words_inside_segment = text.split()
if len(words_inside_segment) == 0:
return
if len(words_inside_segment) == 1:
# Single word within the boundaries, the general case
sequence_timestamps.append(
TimestampedText(text=text, timestamp=_tstmp(start, end))
)
else:
            # Rare case: multiple words are so close together that no
            # `end_of_padding` token separates them. We tokenize the words
            # one by one; each word gets as many frames as it has tokens.
for adjacent_word in words_inside_segment[:-1]:
n_tokens = len(tokenizer.encode(adjacent_word))
sequence_timestamps.append(
TimestampedText(
text=adjacent_word, timestamp=_tstmp(start, start + n_tokens)
)
)
start += n_tokens
# The last word takes everything until the boundary
adjacent_word = words_inside_segment[-1]
sequence_timestamps.append(
TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end))
)
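
    # Positions of `end_of_padding` tokens mark the word boundaries.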
(segment_boundaries,) = torch.where(text_tokens == end_of_padding_id)
if not segment_boundaries.numel():
return []
for i in range(len(segment_boundaries) - 1):
segment_start = int(segment_boundaries[i]) + 1
segment_end = int(segment_boundaries[i + 1])
_decode_segment(segment_start, segment_end)
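
    # The final word may end at an EOS token rather than an
    # `end_of_padding` boundary, so handle it separately.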
    last_segment_start = int(segment_boundaries[-1]) + 1
boundary_token = torch.tensor([tokenizer.eos_id()])
(end_of_last_segment,) = torch.where(
torch.isin(text_tokens[last_segment_start:], boundary_token)
)
    if not end_of_last_segment.numel():
        # Upper bound: the end of the audio or one second past the last
        # boundary, whichever comes first.
        last_segment_end = min(
            text_tokens.shape[-1], last_segment_start + math.ceil(frame_rate)
        )
else:
        last_segment_end = last_segment_start + int(end_of_last_segment[0])
_decode_segment(last_segment_start, last_segment_end)
return sequence_timestamps


def main(args):
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
mimi_weights=args.mimi_weight,
tokenizer=args.tokenizer,
config_path=args.config_path,
)
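
    # Instantiate the Mimi audio codec, the text tokenizer, and the STT LM.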
mimi = info.get_mimi(device=args.device)
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(
device=args.device,
dtype=torch.bfloat16,
)
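
    # Temperature 0 means greedy decoding: a deterministic transcript.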
lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
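
    # The model expects leading silence before the audio and produces text
    # with a fixed delay relative to it; both offsets come from the config.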
audio_silence_prefix_seconds = info.stt_config.get(
"audio_silence_prefix_seconds", 1.0
)
audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
padding_token_id = info.raw_config.get("text_padding_token_id", 3)
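
    # Load the input audio and resample it to Mimi's sample rate.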
    audio, input_sample_rate = sphn.read(args.in_file)
audio = torch.from_numpy(audio).to(args.device)
audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
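
    # Zero-pad the end so the waveform splits into whole Mimi frames.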
if audio.shape[-1] % mimi.frame_size != 0:
to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
audio = torch.nn.functional.pad(audio, (0, to_pad))
text_tokens_accum = []
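
    # Feed leading silence, then the audio, then enough trailing silence to
    # flush the delayed text stream.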
n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
silence_chunk = torch.zeros(
(1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
)
chunks = itertools.chain(
itertools.repeat(silence_chunk, n_prefix_chunks),
torch.split(audio[:, None], mimi.frame_size, dim=-1),
itertools.repeat(silence_chunk, n_suffix_chunks),
)
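
    # Stream frame by frame: encode each chunk with Mimi and step the LM to
    # obtain the next text token(s).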
with mimi.streaming(1), lm_gen.streaming(1):
        for audio_chunk in tqdm.tqdm(chunks):
            audio_tokens = mimi.encode(audio_chunk)
            text_tokens = lm_gen.step(audio_tokens)
            if text_tokens is not None:
                text_tokens_accum.append(text_tokens)
                # Decode on CPU, skipping padding/special ids, so partial
                # text can be printed as it is produced.
                token_ids = text_tokens.cpu().view(-1)
                token_ids = token_ids[token_ids > padding_token_id]
                if token_ids.numel():
                    print(tokenizer.decode(token_ids.tolist()))
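
    # All per-step text tokens, concatenated along the time axis.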
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
timed_text = tokens_to_timestamped_text(
utterance_tokens,
tokenizer,
mimi.frame_rate,
end_of_padding_id=0,
padding_token_id=padding_token_id,
offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
)
decoded = " ".join([str(t) for t in timed_text])
print(decoded)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.")
parser.add_argument("in_file", help="The file to transcribe.")
    parser.add_argument(
        "--hf-repo", type=str, help="HF repo to load the STT model from."
    )
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument(
"--moshi-weight", type=str, help="Path to a local checkpoint file."
)
parser.add_argument(
"--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
)
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
main(args)