# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "datasets",
#     "jiwer==3.1.0",
#     "julius",
#     "librosa",
#     "moshi",
#     "openai-whisper",
#     "soundfile",
# ]
# ///
"""Batched evaluation of the streaming STT model.

Test utterances are grouped into batches (pre- and post-padded with silence)
and then fed into the streaming STT model frame-by-frame.
"""

# The outputs I get on my H100 using this code with the 2.6B model, bsz 32:
# LibriVox === cer: 4.09% wer: 7.33% corpus_wer: 6.78% RTF = 52.72
# Ami === cer: 15.99% wer: 18.78% corpus_wer: 12.20% RTF = 28.37
# LibriSpeech other === cer: 2.31% wer: 5.24% corpus_wer: 4.33% RTF = 44.76
# LibriSpeech clean === cer: 0.67% wer: 1.95% corpus_wer: 1.69% RTF = 68.19
# Tedlium (short) === cer: 2.15% wer: 3.65% corpus_wer: 3.33% RTF = 67.44
# spgispeech === cer: 0.99% wer: 2.00% corpus_wer: 2.03% RTF = 78.64
# gigaspeech === cer: 6.80% wer: 11.31% corpus_wer: 9.81% RTF = 64.04
# earnings22 (short) === cer: 12.63% wer: 15.70% corpus_wer: 11.02% RTF = 50.13
# Meanwhile === cer: 2.02% wer: 5.50% corpus_wer: 5.60% RTF = 69.19
# Tedlium (long) === cer: 1.53% wer: 2.56% corpus_wer: 2.97% RTF = 33.92
# Rev16 === cer: 6.57% wer: 10.08% corpus_wer: 11.43% RTF = 40.34
# Earnings21 === cer: 5.73% wer: 9.84% corpus_wer: 10.38% RTF = 73.15

import argparse
import dataclasses
import time

import jiwer
import julius
import moshi.models
import torch
import tqdm
from datasets import Dataset, load_dataset
from whisper.normalizers import EnglishTextNormalizer

_NORMALIZER = EnglishTextNormalizer()


def get_text(sample):
    possible_keys = [
        "text",
        "sentence",
        "normalized_text",
        "transcript",
        "transcription",
    ]
    for key in possible_keys:
        if key in sample:
            return sample[key]
    raise ValueError(
        f"Expected a transcript column named one of {possible_keys}. "
        f"Got sample with keys: {', '.join(sample.keys())}. "
        "Ensure a text column name is present in the dataset."
    )
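
# For reference, a minimal sketch of the sample layout these helpers expect
# from the datasets used below (hypothetical values; the actual transcript
# column varies per dataset, hence the key search in get_text above):
#
#   sample = {
#       "audio": {"array": [0.0] * 16000, "sampling_rate": 16000},
#       "text": "hello world",
#   }
#   get_text(sample)  # -> "hello world"
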
# The two functions below are adapted from
# https://github.com/huggingface/open_asr_leaderboard/blob/main/normalizer/data_utils.py
def normalize(batch):
    batch["original_text"] = get_text(batch)
    batch["norm_text"] = _NORMALIZER(batch["original_text"])
    return batch


def is_target_text_in_range(ref):
    if ref.strip() == "ignore time segment in scoring":
        return False
    else:
        return ref.strip() != ""


# End of the adapted part.


class AsrMetrics:
    def __init__(self):
        self.cer_sum = 0.0
        self.wer_sum = 0.0
        self.errors_sum = 0.0
        self.total_words_sum = 0.0
        self.num_sequences = 0.0

    def update(self, hyp: str, ref: str) -> None:
        normalized_ref = _NORMALIZER(ref)
        normalized_hyp = _NORMALIZER(hyp)
        this_wer = jiwer.wer(normalized_ref, normalized_hyp)
        this_cer = jiwer.cer(normalized_ref, normalized_hyp)
        measures = jiwer.compute_measures(normalized_ref, normalized_hyp)
        self.wer_sum += this_wer
        self.cer_sum += this_cer
        self.errors_sum += (
            measures["substitutions"] + measures["deletions"] + measures["insertions"]
        )
        self.total_words_sum += (
            measures["substitutions"] + measures["deletions"] + measures["hits"]
        )
        self.num_sequences += 1

    def compute(self) -> dict:
        assert self.num_sequences > 0, (
            "Unable to compute with total number of comparisons <= 0"
        )
        return {
            "cer": (self.cer_sum / self.num_sequences),
            "wer": (self.wer_sum / self.num_sequences),
            "corpus_wer": (self.errors_sum / self.total_words_sum),
        }

    def __str__(self) -> str:
        result = self.compute()
        return " ".join(f"{k}: {100 * v:.2f}%" for k, v in result.items())


class Timer:
    def __init__(self):
        self.total = 0.0
        self._start_time = None

    def __enter__(self):
        self._start_time = time.perf_counter()
        return self

    def __exit__(self, *_):
        self.total += time.perf_counter() - self._start_time
        self._start_time = None


@dataclasses.dataclass
class _DatasetInfo:
    alias: str
    name: str
    config: str | None
    split: str = "test"


_DATASETS = [
    # Long-form datasets from distil-whisper.
    _DatasetInfo("rev16", "distil-whisper/rev16", "whisper_subset"),
    _DatasetInfo("earnings21", "distil-whisper/earnings21", "full"),
    _DatasetInfo("earnings22", "distil-whisper/earnings22", "full"),
    _DatasetInfo("tedlium", "distil-whisper/tedlium-long-form", None),
    _DatasetInfo("meanwhile", "distil-whisper/meanwhile", None),
    # Short-form datasets from the OpenASR leaderboard.
    _DatasetInfo("ami", "hf-audio/esb-datasets-test-only-sorted", "ami"),
    _DatasetInfo(
        "librispeech.clean",
        "hf-audio/esb-datasets-test-only-sorted",
        "librispeech",
        split="test.clean",
    ),
    _DatasetInfo(
        "librispeech.other",
        "hf-audio/esb-datasets-test-only-sorted",
        "librispeech",
        split="test.other",
    ),
    _DatasetInfo("voxpopuli", "hf-audio/esb-datasets-test-only-sorted", "voxpopuli"),
    _DatasetInfo("spgispeech", "hf-audio/esb-datasets-test-only-sorted", "spgispeech"),
    _DatasetInfo("gigaspeech", "hf-audio/esb-datasets-test-only-sorted", "gigaspeech"),
    _DatasetInfo("tedlium-short", "hf-audio/esb-datasets-test-only-sorted", "tedlium"),
    _DatasetInfo(
        "earnings22-short", "hf-audio/esb-datasets-test-only-sorted", "earnings22"
    ),
]

DATASET_MAP = {dataset.alias: dataset for dataset in _DATASETS}


def get_dataset(args) -> Dataset:
    if args.dataset not in DATASET_MAP:
        raise RuntimeError(f"Unknown dataset: {args.dataset}")
    info = DATASET_MAP[args.dataset]
    dataset = load_dataset(
        info.name,
        info.config,
        split=info.split,
        cache_dir=args.hf_cache_dir,
        streaming=False,
        token=True,
    )
    dataset = dataset.map(normalize)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])
    return dataset
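
# Mimi consumes fixed-size frames of audio (mimi.frame_size samples at
# mimi.sample_rate; for the released checkpoints this is 1920 samples at
# 24 kHz, i.e. 80 ms per frame). get_padded_batch below therefore rounds the
# batch length up to a whole number of frames. A minimal sketch of the
# rounding rule it implements, with a hypothetical frame size:
#
#   frame_size = 1920
#   target = max_len + (-max_len % frame_size)  # next multiple of frame_size
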
@torch.no_grad
def get_padded_batch(
    audios: list[tuple[torch.Tensor, int]],
    before_padding: float,
    after_padding: float,
    audio_encoder,
):
    sample_rate = audio_encoder.sample_rate
    max_len = 0
    batch = []
    for audio, sr in audios:
        # Resample to the encoder's rate and add leading/trailing silence.
        audio = julius.resample_frac(audio, int(sr), int(sample_rate))
        audio = torch.nn.functional.pad(
            audio, (int(before_padding * sample_rate), int(after_padding * sample_rate))
        )
        max_len = max(max_len, audio.shape[-1])
        batch.append(audio)

    # Round the batch length up to a whole number of encoder frames so the
    # streaming loop never sees a partial frame.
    target = max_len
    if target % audio_encoder.frame_size != 0:
        target = target + (
            audio_encoder.frame_size - max_len % audio_encoder.frame_size
        )
    padded_batch = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, target - audio.shape[-1]))
            for audio in batch
        ]
    )
    return padded_batch


@torch.no_grad
def streaming_transcribe(
    padded_batch: torch.Tensor,
    mimi,
    lm_gen,
):
    bsz = padded_batch.shape[0]
    text_tokens_acc = []

    with mimi.streaming(bsz), lm_gen.streaming(bsz):
        # Feed the audio one encoder frame at a time, mirroring how the model
        # would consume a live stream.
        for offset in range(0, padded_batch.shape[-1], mimi.frame_size):
            audio_chunk = padded_batch[:, offset : offset + mimi.frame_size]
            audio_chunk = audio_chunk[:, None, :]  # [B, T] -> [B, C=1, T]
            audio_tokens = mimi.encode(audio_chunk)
            text_tokens = lm_gen.step(audio_tokens)
            if text_tokens is not None:
                text_tokens_acc.append(text_tokens)

    return torch.cat(text_tokens_acc, dim=-1)


def run_inference(
    dataset,
    batch_size,
    mimi,
    lm_gen,
    tokenizer,
    padding_token_id,
    before_padding_sec,
    after_padding_sec,
    device,
):
    metrics = AsrMetrics()
    audio_time = 0.0
    inference_timer = Timer()

    for batch in tqdm.tqdm(dataset.iter(batch_size)):
        audio_data = [
            (torch.tensor(x["array"]).float(), x["sampling_rate"])
            for x in batch["audio"]
        ]
        audio_time += sum(audio.shape[-1] / sr for (audio, sr) in audio_data)
        gt_transcripts = batch["original_text"]

        padded_batch = get_padded_batch(
            audio_data,
            before_padding=before_padding_sec,
            after_padding=after_padding_sec,
            audio_encoder=mimi,
        )
        padded_batch = padded_batch.to(device)

        # Only the streaming forward pass is timed; resampling, padding, and
        # decoding are excluded from the RTF measurement.
        with inference_timer:
            text_tokens = streaming_transcribe(
                padded_batch,
                mimi=mimi,
                lm_gen=lm_gen,
            )

        for batch_index in range(text_tokens.shape[0]):
            utterance_tokens = text_tokens[batch_index, ...]
            # Drop padding/special tokens (ids at or below the padding id).
            utterance_tokens = utterance_tokens[utterance_tokens > padding_token_id]
            text = tokenizer.decode(utterance_tokens.cpu().numpy().tolist())
            metrics.update(hyp=text, ref=gt_transcripts[batch_index])

    return metrics, inference_timer.total, audio_time
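
# Note on the reported numbers: "wer"/"cer" average per-utterance error rates,
# while "corpus_wer" pools edit operations over the whole dataset (a
# corpus-level WER). RTF (real-time factor) is total audio duration divided by
# wall-clock time spent in streaming_transcribe, so RTF = 50 means one hour of
# audio is transcribed in roughly 72 seconds of compute.
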
def main(args):
    torch.set_float32_matmul_precision("high")

    info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
        args.hf_repo,
        moshi_weights=args.moshi_weight,
        mimi_weights=args.mimi_weight,
        tokenizer=args.tokenizer,
        config_path=args.config_path,
    )

    mimi = info.get_mimi(device=args.device)
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device=args.device,
        dtype=torch.bfloat16,
    )
    lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)

    dataset = get_dataset(args)

    padding_token_id = info.raw_config.get("text_padding_token_id", 3)
    # Conservative defaults in case the checkpoint config does not specify them.
    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)

    wer_metric, inference_time, audio_time = run_inference(
        dataset,
        args.batch_size,
        mimi,
        lm_gen,
        tokenizer,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds + 0.5,
        device=args.device,
    )
    print(wer_metric, f"RTF = {audio_time / inference_time:.2f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Example streaming STT inference.")
    parser.add_argument(
        "--dataset",
        required=True,
        choices=DATASET_MAP.keys(),
        help="Dataset to run inference on.",
    )
    parser.add_argument(
        "--hf-repo", type=str, help="HF repo to load the STT model from."
    )
    parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
    parser.add_argument(
        "--moshi-weight", type=str, help="Path to a local checkpoint file."
    )
    parser.add_argument(
        "--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi."
    )
    parser.add_argument(
        "--config-path", type=str, help="Path to a local config file.", default=None
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
        default=32,
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device on which to run, defaults to 'cuda'.",
    )
    parser.add_argument("--hf-cache-dir", type=str, help="HuggingFace cache folder.")
    args = parser.parse_args()
    main(args)
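
# Example invocation via uv, which resolves the inline script dependencies
# declared at the top of this file. The checkpoint repo below is an
# assumption; substitute whichever STT repo or local weights you evaluate:
#
#   uv run <this_script.py> \
#       --dataset librispeech.clean \
#       --hf-repo kyutai/stt-2.6b-en \
#       --batch-size 32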