From dc6f552a47a65ce5b16d03073a0fcaecc2aa14a8 Mon Sep 17 00:00:00 2001 From: Laurent Date: Tue, 8 Jul 2025 15:43:38 +0200 Subject: [PATCH] VAD support. --- scripts/stt_from_file_mlx.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/scripts/stt_from_file_mlx.py b/scripts/stt_from_file_mlx.py index 1b1da29..1584bf0 100644 --- a/scripts/stt_from_file_mlx.py +++ b/scripts/stt_from_file_mlx.py @@ -25,6 +25,9 @@ if __name__ == "__main__": parser.add_argument("in_file", help="The file to transcribe.") parser.add_argument("--max-steps", default=4096) parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx") + parser.add_argument( + "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)." + ) args = parser.parse_args() audio, _ = sphn.read(args.in_file, sample_rate=24000) @@ -65,11 +68,19 @@ if __name__ == "__main__": print(f"starting inference {audio.shape}") audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1) + last_print_was_vad = False for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920): - block = audio[:, None, start_idx:start_idx + 1920] - other_audio_tokens = audio_tokenizer.encode_step(block) - other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1) - text_token = gen.step(other_audio_tokens[0]) + block = audio[:, None, start_idx : start_idx + 1920] + other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1) + if args.vad: + text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0]) + if vad_heads: + pr_vad = vad_heads[2][0, 0, 0].item() + if pr_vad > 0.5 and not last_print_was_vad: + print(" [end of turn detected]") + last_print_was_vad = True + else: + text_token = gen.step(other_audio_tokens[0]) text_token = text_token[0].item() audio_tokens = gen.last_audio_tokens() _text = None @@ -77,5 +88,5 @@ if __name__ == "__main__": _text = text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") print(_text, end="", flush=True) + last_print_was_vad = False print() -