From ef864a6f386c6dc9f95aae7d1d86184207e00a73 Mon Sep 17 00:00:00 2001
From: eugene-kharitonov
Date: Fri, 20 Jun 2025 15:53:45 +0200
Subject: [PATCH] Moving over STT inference scripts (#7)

* Adding links to STT example scripts

One script for HF dataset inference; another for retrieving timestamps.

* Moving inference scripts to the delayed-streams-repo

---------

Co-authored-by: Eugene
---
 README.md                           |  16 ++
 scripts/streaming_stt.py            | 396 ++++++++++++++++++++++++++++
 scripts/streaming_stt_timestamps.py | 231 ++++++++++++++++
 3 files changed, 643 insertions(+)
 create mode 100644 scripts/streaming_stt.py
 create mode 100644 scripts/streaming_stt_timestamps.py

diff --git a/README.md b/README.md
index 02f750c..99ed1bd 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,22 @@ uvx --with moshi python -m moshi.run_inference --hf-repo kyutai/stt-2.6b-en bria
 ```
 It will install the moshi package in a temporary environment and run the speech-to-text.
 
+Additionally, we provide two scripts that highlight different usage scenarios. The first script illustrates how to extract word-level timestamps from the model's outputs:
+
+```bash
+uv run \
+    scripts/streaming_stt_timestamps.py \
+    --hf-repo kyutai/stt-2.6b-en \
+    --file bria.mp3
+```
+
+The second script can be used to run a model on an existing Hugging Face dataset and calculate its performance metrics:
+```bash
+uv run scripts/streaming_stt.py \
+    --dataset meanwhile \
+    --hf-repo kyutai/stt-2.6b-en
+```
+
 ### Rust server
 
 Hugging Face
diff --git a/scripts/streaming_stt.py b/scripts/streaming_stt.py
new file mode 100644
index 0000000..bffe294
--- /dev/null
+++ b/scripts/streaming_stt.py
@@ -0,0 +1,396 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "datasets",
+#     "jiwer==3.1.0",
+#     "julius",
+#     "librosa",
+#     "moshi",
+#     "openai-whisper",
+#     "soundfile",
+# ]
+# ///
+"""
+Example implementation of streaming STT inference. Here we group
+test utterances into batches (pre- and post-padded with silence)
+and then feed these batches into the streaming STT model frame-by-frame.
+ +Example command: +``` +uv run scripts/streaming_stt.py \ + --dataset meanwhile \ + --hf-repo kyutai/stt-2.6b-en +``` + +""" + +# The outputs I get on my H100 using this code with the 2.6B model, +# bsz 32: + +# LibriVox === cer: 4.09% wer: 7.33% corpus_wer: 6.78% RTF = 52.72 +# Ami === cer: 15.99% wer: 18.78% corpus_wer: 12.20% RTF = 28.37 +# LibriSpeech other === cer: 2.31% wer: 5.24% corpus_wer: 4.33% RTF = 44.76 +# LibriSpeech clean === cer: 0.67% wer: 1.95% corpus_wer: 1.69% RTF = 68.19 +# Tedlium (short) === cer: 2.15% wer: 3.65% corpus_wer: 3.33% RTF = 67.44 +# spgispeech === cer: 0.99% wer: 2.00% corpus_wer: 2.03% RTF = 78.64 +# gigaspeech === cer: 6.80% wer: 11.31% corpus_wer: 9.81% RTF = 64.04 +# earnings22 (short) === cer: 12.63% wer: 15.70% corpus_wer: 11.02% RTF = 50.13 + +# Meanwhile === cer: 2.02% wer: 5.50% corpus_wer: 5.60% RTF = 69.19 +# Tedlium (long) == cer: 1.53% wer: 2.56% corpus_wer: 2.97% RTF = 33.92 +# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34 +# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15 + +import dataclasses +import julius +import jiwer +from datasets import load_dataset, Dataset +from whisper.normalizers import EnglishTextNormalizer +import argparse + +import torch +import moshi.models +import tqdm +import time + + +_NORMALIZER = EnglishTextNormalizer() + + +def get_text(sample): + possible_keys = [ + "text", + "sentence", + "normalized_text", + "transcript", + "transcription", + ] + for key in possible_keys: + if key in sample: + return sample[key] + raise ValueError( + f"Expected transcript column of either {possible_keys}." + f"Got sample with keys: {', '.join(sample.keys())}. Ensure a text column name is present in the dataset." + ) + + +# The two functions below are adapted from https://github.com/huggingface/open_asr_leaderboard/blob/main/normalizer/data_utils.py + + +def normalize(batch): + batch["original_text"] = get_text(batch) + batch["norm_text"] = _NORMALIZER(batch["original_text"]) + return batch + + +def is_target_text_in_range(ref): + if ref.strip() == "ignore time segment in scoring": + return False + else: + return ref.strip() != "" + + +# End of the adapted part + + +class AsrMetrics: + def __init__(self): + self.cer_sum = 0.0 + self.wer_sum = 0.0 + self.errors_sum = 0.0 + self.total_words_sum = 0.0 + self.num_sequences = 0.0 + + def update(self, hyp: str, ref: str) -> None: + normalized_ref = _NORMALIZER(ref) + normalized_hyp = _NORMALIZER(hyp) + + this_wer = jiwer.wer(normalized_ref, normalized_hyp) + this_cer = jiwer.cer(normalized_ref, normalized_hyp) + measures = jiwer.compute_measures(normalized_ref, normalized_hyp) + + self.wer_sum += this_wer + self.cer_sum += this_cer + self.errors_sum += ( + measures["substitutions"] + measures["deletions"] + measures["insertions"] + ) + self.total_words_sum += ( + measures["substitutions"] + measures["deletions"] + measures["hits"] + ) + self.num_sequences += 1 + + def compute(self) -> dict: + assert ( + self.num_sequences > 0 + ), "Unable to compute with total number of comparisons <= 0" # type: ignore + return { + "cer": (self.cer_sum / self.num_sequences), + "wer": (self.wer_sum / self.num_sequences), + "corpus_wer": (self.errors_sum / self.total_words_sum), + } + + def __str__(self) -> str: + result = self.compute() + return " ".join(f"{k}: {100 * v:.2f}%" for k, v in result.items()) + + +class Timer: + def __init__(self): + self.total = 0 + self._start_time = None + + def __enter__(self): + self._start_time = time.perf_counter() + return self + + 
def __exit__(self, *_): + self.total += time.perf_counter() - self._start_time + self._start_time = None + + +@dataclasses.dataclass +class _DatasetInfo: + alias: str + + name: str + config: str + split: str = "test" + + +_DATASETS = [ + # Long-form datasets from distil-whisper + _DatasetInfo("rev16", "distil-whisper/rev16", "whisper_subset"), + _DatasetInfo("earnings21", "distil-whisper/earnings21", "full"), + _DatasetInfo("earnings22", "distil-whisper/earnings22", "full"), + _DatasetInfo("tedlium", "distil-whisper/tedlium-long-form", None), + _DatasetInfo("meanwhile", "distil-whisper/meanwhile", None), + # Short-form datasets from OpenASR leaderboard + _DatasetInfo("ami", "hf-audio/esb-datasets-test-only-sorted", "ami"), + _DatasetInfo( + "librispeech.clean", + "hf-audio/esb-datasets-test-only-sorted", + "librispeech", + split="test.clean", + ), + _DatasetInfo( + "librispeech.other", + "hf-audio/esb-datasets-test-only-sorted", + "librispeech", + split="test.other", + ), + _DatasetInfo("voxpopuli", "hf-audio/esb-datasets-test-only-sorted", "voxpopuli"), + _DatasetInfo("spgispeech", "hf-audio/esb-datasets-test-only-sorted", "spgispeech"), + _DatasetInfo("gigaspeech", "hf-audio/esb-datasets-test-only-sorted", "gigaspeech"), + _DatasetInfo("tedlium-short", "hf-audio/esb-datasets-test-only-sorted", "tedlium"), + _DatasetInfo( + "earnings22-short", "hf-audio/esb-datasets-test-only-sorted", "earnings22" + ), +] +DATASET_MAP = {dataset.alias: dataset for dataset in _DATASETS} + + +def get_dataset(args) -> Dataset: + if args.dataset not in DATASET_MAP: + raise RuntimeError(f"Unknown dataset: {args.dataset}") + + info = DATASET_MAP[args.dataset] + + dataset = load_dataset( + info.name, + info.config, + split=info.split, + cache_dir=args.hf_cache_dir, + streaming=False, + token=True, + ) + dataset = dataset.map(normalize) + dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"]) + + return dataset + + +@torch.no_grad +def get_padded_batch( + audios: list[tuple[torch.Tensor, int]], + before_padding: float, + after_padding: float, + audio_encoder, +): + sample_rate = audio_encoder.sample_rate + + max_len = 0 + batch = [] + durations = [] + for audio, sr in audios: + durations.append(audio.shape[-1] / sr) + audio = julius.resample_frac(audio, int(sr), int(sample_rate)) + audio = torch.nn.functional.pad( + audio, (int(before_padding * sample_rate), int(after_padding * sample_rate)) + ) + max_len = max(max_len, audio.shape[-1]) + batch.append(audio) + + target = max_len + if target % audio_encoder.frame_size != 0: + target = target + ( + audio_encoder.frame_size - max_len % audio_encoder.frame_size + ) + padded_batch = torch.stack( + [ + torch.nn.functional.pad(audio, (0, target - audio.shape[-1])) + for audio in batch + ] + ) + return padded_batch + + +@torch.no_grad +def streaming_transcribe( + padded_batch: torch.Tensor, + mimi, + lm_gen, +): + bsz = padded_batch.shape[0] + + text_tokens_acc = [] + + with mimi.streaming(bsz), lm_gen.streaming(bsz): + for offset in range(0, padded_batch.shape[-1], mimi.frame_size): + audio_chunk = padded_batch[:, offset : offset + mimi.frame_size] + audio_chunk = audio_chunk[:, None, :] + + audio_tokens = mimi.encode(audio_chunk) + text_tokens = lm_gen.step(audio_tokens) + if text_tokens is not None: + text_tokens_acc.append(text_tokens) + + return torch.concat(text_tokens_acc, axis=-1) + + +def run_inference( + dataset, + mimi, + lm_gen, + tokenizer, + padding_token_id, + before_padding_sec, + after_padding_sec, +): + metrics = AsrMetrics() + 
audio_time = 0.0 + inference_timer = Timer() + + for batch in tqdm.tqdm(dataset.iter(args.batch_size)): + audio_data = list( + zip( + [torch.tensor(x["array"]).float() for x in batch["audio"]], + [x["sampling_rate"] for x in batch["audio"]], + ) + ) + + audio_time += sum(audio.shape[-1] / sr for (audio, sr) in audio_data) + + gt_transcripts = batch["original_text"] + + padded_batch = get_padded_batch( + audio_data, + before_padding=before_padding_sec, + after_padding=after_padding_sec, + audio_encoder=mimi, + ) + padded_batch = padded_batch.cuda() + + with inference_timer: + text_tokens = streaming_transcribe( + padded_batch, + mimi=mimi, + lm_gen=lm_gen, + ) + + for batch_index in range(text_tokens.shape[0]): + utterance_tokens = text_tokens[batch_index, ...] + utterance_tokens = utterance_tokens[utterance_tokens > padding_token_id] + text = tokenizer.decode(utterance_tokens.cpu().numpy().tolist()) + metrics.update(hyp=text, ref=gt_transcripts[batch_index]) + + return metrics, inference_timer.total, audio_time + + +def main(args): + torch.set_float32_matmul_precision("high") + + info = moshi.models.loaders.CheckpointInfo.from_hf_repo( + args.hf_repo, + moshi_weights=args.moshi_weight, + mimi_weights=args.mimi_weight, + tokenizer=args.tokenizer, + config_path=args.config_path, + ) + + mimi = info.get_mimi(device=args.device) + tokenizer = info.get_text_tokenizer() + lm = info.get_moshi( + device=args.device, + dtype=torch.bfloat16, + ) + lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0) + dataset = get_dataset(args) + + padding_token_id = info.raw_config.get("text_padding_token_id", 3) + # Putting in some conservative defaults + audio_silence_prefix_seconds = info.stt_config.get( + "audio_silence_prefix_seconds", 1.0 + ) + audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0) + + wer_metric, inference_time, audio_time = run_inference( + dataset, + mimi, + lm_gen, + tokenizer, + padding_token_id, + audio_silence_prefix_seconds, + audio_delay_seconds + 0.5, + ) + + print(wer_metric, f"RTF = {audio_time / inference_time:.2f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Example streaming STT inference.") + parser.add_argument( + "--dataset", + required=True, + choices=DATASET_MAP.keys(), + help="Dataset to run inference on.", + ) + + parser.add_argument( + "--hf-repo", type=str, help="HF repo to load the STT model from. " + ) + parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") + parser.add_argument( + "--moshi-weight", type=str, help="Path to a local checkpoint file." + ) + parser.add_argument( + "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi." 
+    )
+    parser.add_argument(
+        "--config-path", type=str, help="Path to a local config file.", default=None
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        help="Batch size.",
+        default=32,
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device on which to run, defaults to 'cuda'.",
+    )
+    parser.add_argument("--hf-cache-dir", type=str, help="HuggingFace cache folder.")
+    args = parser.parse_args()
+
+    main(args)
diff --git a/scripts/streaming_stt_timestamps.py b/scripts/streaming_stt_timestamps.py
new file mode 100644
index 0000000..7e0970b
--- /dev/null
+++ b/scripts/streaming_stt_timestamps.py
@@ -0,0 +1,231 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "julius",
+#     "librosa",
+#     "soundfile",
+#     "moshi",
+# ]
+# ///
+
+"""An example script that illustrates how one can get per-word timestamps from
+Kyutai STT models.
+
+Usage:
+```
+uv run scripts/streaming_stt_timestamps.py \
+    --hf-repo kyutai/stt-2.6b-en \
+    --file bria.mp3
+```
+"""
+
+import itertools
+import dataclasses
+import julius
+import sphn
+import argparse
+import math
+
+import torch
+import moshi.models
+import tqdm
+
+
+@dataclasses.dataclass
+class TimestampedText:
+    text: str
+    timestamp: tuple[float, float]
+
+    def __str__(self):
+        return f"{self.text} ({self.timestamp[0]:.2f}:{self.timestamp[1]:.2f})"
+
+
+def tokens_to_timestamped_text(
+    text_tokens,
+    tokenizer,
+    frame_rate,
+    end_of_padding_id,
+    padding_token_id,
+    offset_seconds,
+) -> list[TimestampedText]:
+    text_tokens = text_tokens.cpu().view(-1)
+
+    # Normally, `end_of_padding` tokens indicate word boundaries.
+    # Everything between them should be a single word;
+    # the time offsets of those tokens correspond to word start and
+    # end timestamps (minus the silence prefix and audio delay).
+    #
+    # However, in rare cases some complexities arise. Firstly,
+    # for words that are spoken quickly but are represented by
+    # multiple tokens, the boundary might be omitted. Secondly,
+    # for the very last word the end boundary might be missing.
+    # The code below handles those situations a bit more
+    # carefully.
+
+    sequence_timestamps = []
+
+    def _tstmp(start_position, end_position):
+        return (
+            max(0, start_position / frame_rate - offset_seconds),
+            max(0, end_position / frame_rate - offset_seconds),
+        )
+
+    def _decode(t):
+        t = t[t > padding_token_id]
+        return tokenizer.decode(t.numpy().tolist())
+
+    def _decode_segment(start, end):
+        nonlocal text_tokens
+        nonlocal sequence_timestamps
+
+        text = _decode(text_tokens[start:end])
+        words_inside_segment = text.split()
+
+        if len(words_inside_segment) == 0:
+            return
+        if len(words_inside_segment) == 1:
+            # Single word within the boundaries, the general case
+            sequence_timestamps.append(
+                TimestampedText(text=text, timestamp=_tstmp(start, end))
+            )
+        else:
+            # We're in a rare situation where multiple words are so close that they are not separated by `end_of_padding`.
+            # We tokenize the words one by one; each word is assigned as many frames as it has tokens.
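+            # Note that this heuristic assumes re-encoding each word yields a token
+            # count consistent with what the model emitted; if it does not, the
+            # per-word boundaries within this segment are only approximate.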
+            for adjacent_word in words_inside_segment[:-1]:
+                n_tokens = len(tokenizer.encode(adjacent_word))
+                sequence_timestamps.append(
+                    TimestampedText(
+                        text=adjacent_word, timestamp=_tstmp(start, start + n_tokens)
+                    )
+                )
+                start += n_tokens
+
+            # The last word takes everything until the boundary
+            adjacent_word = words_inside_segment[-1]
+            sequence_timestamps.append(
+                TimestampedText(text=adjacent_word, timestamp=_tstmp(start, end))
+            )
+
+    (segment_boundaries,) = torch.where(text_tokens == end_of_padding_id)
+
+    if not segment_boundaries.numel():
+        return []
+
+    for i in range(len(segment_boundaries) - 1):
+        segment_start = int(segment_boundaries[i]) + 1
+        segment_end = int(segment_boundaries[i + 1])
+
+        _decode_segment(segment_start, segment_end)
+
+    last_segment_start = segment_boundaries[-1] + 1
+
+    boundary_token = torch.tensor([tokenizer.eos_id()])
+    (end_of_last_segment,) = torch.where(
+        torch.isin(text_tokens[last_segment_start:], boundary_token)
+    )
+
+    if not end_of_last_segment.numel():
+        # Upper bound: either the end of the audio or a one-second duration, whichever is smaller
+        last_segment_end = min(text_tokens.shape[-1], last_segment_start + frame_rate)
+    else:
+        last_segment_end = last_segment_start + end_of_last_segment[0]
+    _decode_segment(last_segment_start, last_segment_end)
+
+    return sequence_timestamps
+
+
+def main(args):
+    info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
+        args.hf_repo,
+        moshi_weights=args.moshi_weight,
+        mimi_weights=args.mimi_weight,
+        tokenizer=args.tokenizer,
+        config_path=args.config_path,
+    )
+
+    mimi = info.get_mimi(device=args.device)
+    tokenizer = info.get_text_tokenizer()
+    lm = info.get_moshi(
+        device=args.device,
+        dtype=torch.bfloat16,
+    )
+    lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
+
+    audio_silence_prefix_seconds = info.stt_config.get(
+        "audio_silence_prefix_seconds", 1.0
+    )
+    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)
+    padding_token_id = info.raw_config.get("text_padding_token_id", 3)
+
+    audio, input_sample_rate = sphn.read(args.file)
+    audio = torch.from_numpy(audio).to(args.device)
+    audio = julius.resample_frac(audio, input_sample_rate, mimi.sample_rate)
+    if audio.shape[-1] % mimi.frame_size != 0:
+        to_pad = mimi.frame_size - audio.shape[-1] % mimi.frame_size
+        audio = torch.nn.functional.pad(audio, (0, to_pad))
+
+    text_tokens_accum = []
+
+    n_prefix_chunks = math.ceil(audio_silence_prefix_seconds * mimi.frame_rate)
+    n_suffix_chunks = math.ceil(audio_delay_seconds * mimi.frame_rate)
+    silence_chunk = torch.zeros(
+        (1, 1, mimi.frame_size), dtype=torch.float32, device=args.device
+    )
+
+    chunks = itertools.chain(
+        itertools.repeat(silence_chunk, n_prefix_chunks),
+        torch.split(audio[:, None], mimi.frame_size, dim=-1),
+        itertools.repeat(silence_chunk, n_suffix_chunks),
+    )
+
+    with mimi.streaming(1), lm_gen.streaming(1):
+        for audio_chunk in tqdm.tqdm(chunks):
+            audio_tokens = mimi.encode(audio_chunk)
+            text_tokens = lm_gen.step(audio_tokens)
+            if text_tokens is not None:
+                text_tokens_accum.append(text_tokens)
+
+    utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
+    timed_text = tokens_to_timestamped_text(
+        utterance_tokens,
+        tokenizer,
+        mimi.frame_rate,
+        end_of_padding_id=0,
+        padding_token_id=padding_token_id,
+        offset_seconds=int(n_prefix_chunks / mimi.frame_rate) + audio_delay_seconds,
+    )
+
+    decoded = " ".join([str(t) for t in timed_text])
+    print(decoded)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Example streaming STT w/ timestamps.")
+    parser.add_argument(
+        "--file",
+        required=True,
+        help="File to transcribe.",
+    )
+
+    parser.add_argument(
+        "--hf-repo", type=str, help="HF repo to load the STT model from. "
+    )
+    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
+    parser.add_argument(
+        "--moshi-weight", type=str, help="Path to a local checkpoint file."
+    )
+    parser.add_argument(
+        "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
+    )
+    parser.add_argument(
+        "--config-path", type=str, help="Path to a local config file.", default=None
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device on which to run, defaults to 'cuda'.",
+    )
+    args = parser.parse_args()
+
+    main(args)
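
As a possible follow-up to `scripts/streaming_stt_timestamps.py`, the `TimestampedText` entries it produces can be turned into a standard subtitle format. Below is a minimal sketch, not part of the patch above, showing one way to write such a list to an SRT file; the `to_srt` helper and the assumption that the script is importable from the working directory are illustrative only.

```python
# Hypothetical post-processing sketch: convert TimestampedText entries into an SRT file.
# Assumes scripts/streaming_stt_timestamps.py is importable (e.g. run from the scripts/
# directory) and that `words` comes from tokens_to_timestamped_text(), as in the script above.
from streaming_stt_timestamps import TimestampedText


def to_srt(words: list[TimestampedText], path: str, words_per_cue: int = 7) -> None:
    """Group consecutive words into subtitle cues and write them in SRT format."""

    def fmt(seconds: float) -> str:
        # SRT timestamps use the HH:MM:SS,mmm format.
        ms = int(round(seconds * 1000))
        hours, rest = divmod(ms, 3_600_000)
        minutes, rest = divmod(rest, 60_000)
        secs, millis = divmod(rest, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    with open(path, "w") as f:
        for cue_id, i in enumerate(range(0, len(words), words_per_cue), start=1):
            cue = words[i : i + words_per_cue]
            start, end = cue[0].timestamp[0], cue[-1].timestamp[1]
            text = " ".join(w.text for w in cue)
            f.write(f"{cue_id}\n{fmt(start)} --> {fmt(end)}\n{text}\n\n")
```

For example, calling `to_srt(timed_text, "bria.srt")` right after the `print(decoded)` line in the script's `main` would produce subtitles aligned with the decoded words.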