From c9c4e0a573b6698edfd4b6bd126ca564341e4e41 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 8 Jul 2025 11:28:54 +0200
Subject: [PATCH] Add some VAD to the pytorch speech-to-text example.

---
 scripts/stt_from_file_pytorch.py | 40 +++++++++++++++++++++++++++-----
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/scripts/stt_from_file_pytorch.py b/scripts/stt_from_file_pytorch.py
index fd10c67..46479ea 100644
--- a/scripts/stt_from_file_pytorch.py
+++ b/scripts/stt_from_file_pytorch.py
@@ -4,7 +4,7 @@
 # "julius",
 # "librosa",
 # "soundfile",
-# "moshi",
+# "moshi==0.2.9",
 # ]
 # ///
 
@@ -20,8 +20,8 @@
 import math
 import julius
 import moshi.models
 import sphn
+import time
 import torch
-import tqdm
 
 @dataclasses.dataclass
@@ -138,9 +138,13 @@
 
     mimi = info.get_mimi(device=args.device)
    tokenizer = info.get_text_tokenizer()
+    lm_kwargs_overrides = {}
+    if args.vad:
+        lm_kwargs_overrides = {"extra_heads_num_heads": 4}
     lm = info.get_moshi(
         device=args.device,
         dtype=torch.bfloat16,
+        lm_kwargs_overrides=lm_kwargs_overrides,
     )
     lm_gen = moshi.models.LMGen(lm, temp=0, temp_text=0.0)
 
@@ -171,14 +175,35 @@
         itertools.repeat(silence_chunk, n_suffix_chunks),
     )
 
+    start_time = time.time()
+    nchunks = 0
+    last_print_was_vad = False
     with mimi.streaming(1), lm_gen.streaming(1):
-        for audio_chunk in tqdm.tqdm(chunks):
+        for audio_chunk in chunks:
+            nchunks += 1
             audio_tokens = mimi.encode(audio_chunk)
-            text_tokens = lm_gen.step(audio_tokens)
-            if text_tokens is not None:
-                text_tokens_accum.append(text_tokens)
+            if args.vad:
+                text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
+                if vad_heads:
+                    pr_vad = vad_heads[2][0, 0, 0].cpu().item()
+                    if pr_vad > 0.5 and not last_print_was_vad:
+                        print(" [end of turn detected]")
+                        last_print_was_vad = True
+            else:
+                text_tokens = lm_gen.step(audio_tokens)
+            text_token = text_tokens[0, 0, 0].cpu().item()
+            if text_token not in (0, 3):
+                _text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item())  # type: ignore
+                _text = _text.replace("▁", " ")
+                print(_text, end="", flush=True)
+                last_print_was_vad = False
+            text_tokens_accum.append(text_tokens)
 
     utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
+    dt = time.time() - start_time
+    print(
+        f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
+    )
     timed_text = tokens_to_timestamped_text(
         utterance_tokens,
         tokenizer,
@@ -209,6 +234,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--config-path", type=str, help="Path to a local config file.", default=None
     )
+    parser.add_argument(
+        "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
+    )
     parser.add_argument(
         "--device",
         type=str,
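
Usage sketch (assumptions, not shown in the diff): the script carries a PEP 723
inline-dependency header (the "# /// ... # ///" block in the first hunk), so it
can be launched directly with `uv run`, which resolves "moshi==0.2.9" and the
other declared packages. Only the `--config-path`, `--vad`, and `--device`
flags are visible in this patch; the positional audio-file argument below is
hypothetical, so check the script's full argparse setup for the real name.

    # Enable the new end-of-turn detection; the audio file name is hypothetical.
    uv run scripts/stt_from_file_pytorch.py --vad --device cuda sample.wav

With `--vad` set, the model is loaded with four extra prediction heads, and the
loop prints " [end of turn detected]" whenever head 2's probability crosses 0.5,
interleaved with the incrementally decoded text.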