From baf0c75bba89608e921cb26e03c959981df2ad5f Mon Sep 17 00:00:00 2001 From: Laurent Date: Tue, 8 Jul 2025 16:08:32 +0200 Subject: [PATCH] VAD support in the mlx-stt example that uses the microphone. --- scripts/stt_from_mic_mlx.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/scripts/stt_from_mic_mlx.py b/scripts/stt_from_mic_mlx.py index 8f82af6..003ea7c 100644 --- a/scripts/stt_from_mic_mlx.py +++ b/scripts/stt_from_mic_mlx.py @@ -2,7 +2,7 @@ # requires-python = ">=3.12" # dependencies = [ # "huggingface_hub", -# "moshi_mlx", +# "moshi_mlx==0.2.10", # "numpy", # "rustymimi", # "sentencepiece", @@ -26,6 +26,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--max-steps", default=4096) parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx") + parser.add_argument( + "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)." + ) args = parser.parse_args() lm_config = hf_hub_download(args.hf_repo, "config.json") @@ -45,7 +48,10 @@ if __name__ == "__main__": nn.quantize(model, bits=8, group_size=64) print(f"loading model weights from {moshi_weights}") - model.load_weights(moshi_weights, strict=True) + if args.hf_repo.endswith("-candle"): + model.load_pytorch_weights(moshi_weights, lm_config, strict=True) + else: + model.load_weights(moshi_weights, strict=True) print(f"loading the text tokenizer from {tokenizer}") text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore @@ -71,6 +77,7 @@ if __name__ == "__main__": block_queue.put(indata.copy()) print("recording audio from microphone, speak to get your words transcribed") + last_print_was_vad = False with sd.InputStream( channels=1, dtype="float32", @@ -85,7 +92,15 @@ if __name__ == "__main__": other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[ :, :, :other_codebooks ] - text_token = gen.step(other_audio_tokens[0]) + if args.vad: + text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0]) + if vad_heads: + pr_vad = vad_heads[2][0, 0, 0].item() + if pr_vad > 0.5 and not last_print_was_vad: + print(" [end of turn detected]") + last_print_was_vad = True + else: + text_token = gen.step(other_audio_tokens[0]) text_token = text_token[0].item() audio_tokens = gen.last_audio_tokens() _text = None @@ -93,3 +108,4 @@ if __name__ == "__main__": _text = text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") print(_text, end="", flush=True) + last_print_was_vad = False